author     avijit-nervana <avijit.chakraborty@intel.com>  2018-09-14 09:21:08 -0700
committer  avijit-nervana <avijit.chakraborty@intel.com>  2018-09-14 09:21:08 -0700
commit     41aaed7751690b0b3137dad2620656a698b3ceae (patch)
tree       00fc1a7f6be0c3968f3e674a65ca4907110ddf2d
parent     c26c5e1217944448f1f4c2b97626fc4d7d6406d3 (diff)
parent     95338704198205c1bdec1e344e103f1daf05df68 (diff)
Merge branch 'master' into avijit/add-cpu-backend
-rw-r--r--RELEASE.md6
-rw-r--r--configure.py2
-rw-r--r--tensorflow/c/c_api.cc1
-rw-r--r--tensorflow/c/c_api_experimental.cc1
-rw-r--r--tensorflow/c/c_api_function.cc1
-rw-r--r--tensorflow/cc/framework/ops.h2
-rw-r--r--tensorflow/compiler/aot/embedded_protocol_buffers.h1
-rw-r--r--tensorflow/compiler/aot/tests/BUILD1
-rw-r--r--tensorflow/compiler/aot/tests/tfcompile_test.cc23
-rw-r--r--tensorflow/compiler/aot/tfcompile_main.cc6
-rw-r--r--tensorflow/compiler/jit/BUILD7
-rw-r--r--tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc17
-rw-r--r--tensorflow/compiler/jit/encapsulate_subgraphs_pass.h6
-rw-r--r--tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc360
-rw-r--r--tensorflow/compiler/jit/encapsulate_xla_computations_pass.h60
-rw-r--r--tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc346
-rw-r--r--tensorflow/compiler/jit/jit_compilation_pass_registration.cc19
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass.cc2
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass_test.cc49
-rw-r--r--tensorflow/compiler/jit/ops/xla_ops.cc19
-rw-r--r--tensorflow/compiler/jit/xla_cluster_util.h1
-rw-r--r--tensorflow/compiler/jit/xla_device_context.cc6
-rw-r--r--tensorflow/compiler/jit/xla_device_context.h8
-rw-r--r--tensorflow/compiler/tests/BUILD6
-rw-r--r--tensorflow/compiler/tests/adam_test.py6
-rw-r--r--tensorflow/compiler/tests/build_defs.bzl165
-rw-r--r--tensorflow/compiler/tests/concat_ops_test.py35
-rw-r--r--tensorflow/compiler/tests/matrix_band_part_test.py190
-rw-r--r--tensorflow/compiler/tests/reshape_op_test.py2
-rw-r--r--tensorflow/compiler/tests/xla_ops_test.py39
-rw-r--r--tensorflow/compiler/tf2xla/BUILD21
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_cond.cc74
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_cond.h13
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow.cc147
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow.h9
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc25
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc43
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_while.cc17
-rw-r--r--tensorflow/compiler/tf2xla/graph_compiler.cc3
-rw-r--r--tensorflow/compiler/tf2xla/graph_compiler.h13
-rw-r--r--tensorflow/compiler/tf2xla/kernels/binary_ops.cc3
-rw-r--r--tensorflow/compiler/tf2xla/kernels/concat_op.cc33
-rw-r--r--tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc60
-rw-r--r--tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc6
-rw-r--r--tensorflow/compiler/tf2xla/lib/util.cc26
-rw-r--r--tensorflow/compiler/tf2xla/literal_util_test.cc88
-rw-r--r--tensorflow/compiler/tf2xla/ops/xla_ops.cc29
-rw-r--r--tensorflow/compiler/tf2xla/python/xla.py8
-rw-r--r--tensorflow/compiler/tf2xla/resource_operation_table.cc18
-rw-r--r--tensorflow/compiler/tf2xla/test_util.cc8
-rw-r--r--tensorflow/compiler/tf2xla/test_util.h16
-rw-r--r--tensorflow/compiler/tf2xla/tf2xla.cc8
-rw-r--r--tensorflow/compiler/tf2xla/tf2xla_test.cc8
-rw-r--r--tensorflow/compiler/tf2xla/tf2xla_util.cc102
-rw-r--r--tensorflow/compiler/tf2xla/tf2xla_util.h63
-rw-r--r--tensorflow/compiler/tf2xla/type_util.cc11
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.cc16
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler_test.cc226
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_kernel.cc24
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_kernel.h5
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_registry.h1
-rw-r--r--tensorflow/compiler/xla/BUILD1
-rw-r--r--tensorflow/compiler/xla/client/client.cc12
-rw-r--r--tensorflow/compiler/xla/client/client.h10
-rw-r--r--tensorflow/compiler/xla/client/lib/testing.cc4
-rw-r--r--tensorflow/compiler/xla/client/local_client.cc20
-rw-r--r--tensorflow/compiler/xla/client/local_client.h10
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.cc46
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.h45
-rw-r--r--tensorflow/compiler/xla/literal.cc149
-rw-r--r--tensorflow/compiler/xla/literal.h58
-rw-r--r--tensorflow/compiler/xla/literal_test.cc913
-rw-r--r--tensorflow/compiler/xla/literal_util.cc273
-rw-r--r--tensorflow/compiler/xla/literal_util.h228
-rw-r--r--tensorflow/compiler/xla/packed_literal_reader.cc15
-rw-r--r--tensorflow/compiler/xla/packed_literal_reader.h3
-rw-r--r--tensorflow/compiler/xla/protobuf_util.cc29
-rw-r--r--tensorflow/compiler/xla/protobuf_util.h4
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.cc20
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.h8
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.i18
-rw-r--r--tensorflow/compiler/xla/python/numpy_bridge.cc7
-rw-r--r--tensorflow/compiler/xla/python/numpy_bridge.h2
-rw-r--r--tensorflow/compiler/xla/reference_util.cc75
-rw-r--r--tensorflow/compiler/xla/reference_util.h50
-rw-r--r--tensorflow/compiler/xla/reference_util_test.cc50
-rw-r--r--tensorflow/compiler/xla/rpc/grpc_client_test.cc5
-rw-r--r--tensorflow/compiler/xla/service/BUILD74
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc323
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc6
-rw-r--r--tensorflow/compiler/xla/service/batchnorm_expander.cc12
-rw-r--r--tensorflow/compiler/xla/service/batchnorm_expander_test.cc14
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc18
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_normalization_test.cc22
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_propagation.cc6
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_propagation_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment.cc1
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment_test.cc7
-rw-r--r--tensorflow/compiler/xla/service/buffer_liveness_test.cc14
-rw-r--r--tensorflow/compiler/xla/service/call_graph_test.cc26
-rw-r--r--tensorflow/compiler/xla/service/call_inliner_test.cc12
-rw-r--r--tensorflow/compiler/xla/service/compile_only_service.cc2
-rw-r--r--tensorflow/compiler/xla/service/convolution_feature_group_converter.cc4
-rw-r--r--tensorflow/compiler/xla/service/cpu/BUILD6
-rw-r--r--tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc8
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.cc2
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc8
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc8
-rw-r--r--tensorflow/compiler/xla/service/cpu/sample_harness.cc30
-rw-r--r--tensorflow/compiler/xla/service/cpu/shape_partition_test.cc8
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/BUILD1
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc35
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc66
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc3
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc6
-rw-r--r--tensorflow/compiler/xla/service/flatten_call_graph_test.cc22
-rw-r--r--tensorflow/compiler/xla/service/generic_transfer_manager.cc4
-rw-r--r--tensorflow/compiler/xla/service/gpu/BUILD9
-rw-r--r--tensorflow/compiler/xla/service/gpu/convolution_thunk.cc53
-rw-r--r--tensorflow/compiler/xla/service/gpu/convolution_thunk.h55
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc129
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h6
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc167
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc81
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h44
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc8
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc38
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emission_utils.h7
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc61
-rw-r--r--tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc7
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc5
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_insertion.cc16
-rw-r--r--tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc3
-rw-r--r--tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc32
-rw-r--r--tensorflow/compiler/xla/service/gpu/while_transformer_test.cc3
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator_test.cc8
-rw-r--r--tensorflow/compiler/xla/service/hlo.proto7
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc8
-rw-r--r--tensorflow/compiler/xla/service/hlo_constant_folding.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_constant_folding_test.cc37
-rw-r--r--tensorflow/compiler/xla/service/hlo_creation_utils.cc11
-rw-r--r--tensorflow/compiler/xla/service/hlo_creation_utils_test.cc68
-rw-r--r--tensorflow/compiler/xla/service/hlo_cse_test.cc6
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.cc253
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.h57
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_test.cc503
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h203
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.cc9
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc7
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h18
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc15
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.h12
-rw-r--r--tensorflow/compiler/xla/service/hlo_memory_scheduler.cc (renamed from tensorflow/compiler/xla/service/hlo_scheduling.cc)20
-rw-r--r--tensorflow/compiler/xla/service/hlo_memory_scheduler.h (renamed from tensorflow/compiler/xla/service/hlo_scheduling.h)38
-rw-r--r--tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc (renamed from tensorflow/compiler/xla/service/hlo_scheduling_test.cc)28
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.cc53
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_dce.cc22
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_dce_test.cc72
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group.cc91
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group.h81
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_test.cc142
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_test.cc95
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering_test.cc1
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc54
-rw-r--r--tensorflow/compiler/xla/service/hlo_reachability_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization.cc88
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization.h83
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization_test.cc79
-rw-r--r--tensorflow/compiler/xla/service/hlo_runner.cc28
-rw-r--r--tensorflow/compiler/xla/service/hlo_runner.h25
-rw-r--r--tensorflow/compiler/xla/service/hlo_schedule_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc5
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier_test.cc8
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis.cc6
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis.h14
-rw-r--r--tensorflow/compiler/xla/service/inliner_test.cc22
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc260
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.h34
-rw-r--r--tensorflow/compiler/xla/service/interpreter/executable.cc15
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment_test.cc107
-rw-r--r--tensorflow/compiler/xla/service/service.cc49
-rw-r--r--tensorflow/compiler/xla/service/service.h4
-rw-r--r--tensorflow/compiler/xla/service/source_map_util.cc66
-rw-r--r--tensorflow/compiler/xla/service/transfer_manager.cc12
-rw-r--r--tensorflow/compiler/xla/service/transfer_manager.h8
-rw-r--r--tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc8
-rw-r--r--tensorflow/compiler/xla/service/tuple_simplifier_test.cc20
-rw-r--r--tensorflow/compiler/xla/service/while_loop_analysis.cc19
-rw-r--r--tensorflow/compiler/xla/shape_tree.h9
-rw-r--r--tensorflow/compiler/xla/shape_util.cc13
-rw-r--r--tensorflow/compiler/xla/shape_util.h4
-rw-r--r--tensorflow/compiler/xla/shape_util_test.cc16
-rw-r--r--tensorflow/compiler/xla/tests/BUILD1
-rw-r--r--tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc256
-rw-r--r--tensorflow/compiler/xla/tests/batch_normalization_test.cc128
-rw-r--r--tensorflow/compiler/xla/tests/bfloat16_test.cc26
-rw-r--r--tensorflow/compiler/xla/tests/broadcast_simple_test.cc89
-rw-r--r--tensorflow/compiler/xla/tests/broadcast_test.cc53
-rw-r--r--tensorflow/compiler/xla/tests/call_test.cc19
-rw-r--r--tensorflow/compiler/xla/tests/check_execution_arity_test.cc14
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.cc71
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.h101
-rw-r--r--tensorflow/compiler/xla/tests/client_test.cc29
-rw-r--r--tensorflow/compiler/xla/tests/compilation_cache_test.cc19
-rw-r--r--tensorflow/compiler/xla/tests/compute_constant_test.cc26
-rw-r--r--tensorflow/compiler/xla/tests/concat_test.cc20
-rw-r--r--tensorflow/compiler/xla/tests/conditional_test.cc64
-rw-r--r--tensorflow/compiler/xla/tests/constants_test.cc25
-rw-r--r--tensorflow/compiler/xla/tests/convert_test.cc40
-rw-r--r--tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc3
-rw-r--r--tensorflow/compiler/xla/tests/convolution_test.cc148
-rw-r--r--tensorflow/compiler/xla/tests/convolution_variants_test.cc24
-rw-r--r--tensorflow/compiler/xla/tests/copy_test.cc60
-rw-r--r--tensorflow/compiler/xla/tests/cross_replica_sum_test.cc11
-rw-r--r--tensorflow/compiler/xla/tests/custom_call_test.cc12
-rw-r--r--tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc41
-rw-r--r--tensorflow/compiler/xla/tests/dot_operation_test.cc69
-rw-r--r--tensorflow/compiler/xla/tests/dynamic_ops_test.cc117
-rw-r--r--tensorflow/compiler/xla/tests/execution_profile_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc12
-rw-r--r--tensorflow/compiler/xla/tests/fusion_test.cc130
-rw-r--r--tensorflow/compiler/xla/tests/gather_operation_test.cc161
-rw-r--r--tensorflow/compiler/xla/tests/hlo_test_base.cc23
-rw-r--r--tensorflow/compiler/xla/tests/hlo_test_base.h12
-rw-r--r--tensorflow/compiler/xla/tests/literal_test_util.h30
-rw-r--r--tensorflow/compiler/xla/tests/literal_test_util_test.cc43
-rw-r--r--tensorflow/compiler/xla/tests/local_client_allocation_test.cc6
-rw-r--r--tensorflow/compiler/xla/tests/local_client_execute_test.cc253
-rw-r--r--tensorflow/compiler/xla/tests/local_client_test_base.cc2
-rw-r--r--tensorflow/compiler/xla/tests/local_client_test_base.h3
-rw-r--r--tensorflow/compiler/xla/tests/map_test.cc150
-rw-r--r--tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc22
-rw-r--r--tensorflow/compiler/xla/tests/multioutput_fusion_test.cc87
-rw-r--r--tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc30
-rw-r--r--tensorflow/compiler/xla/tests/pad_test.cc46
-rw-r--r--tensorflow/compiler/xla/tests/params_test.cc149
-rw-r--r--tensorflow/compiler/xla/tests/prng_test.cc62
-rw-r--r--tensorflow/compiler/xla/tests/reduce_hlo_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/reduce_precision_test.cc37
-rw-r--r--tensorflow/compiler/xla/tests/reduce_test.cc123
-rw-r--r--tensorflow/compiler/xla/tests/reduce_window_test.cc192
-rw-r--r--tensorflow/compiler/xla/tests/replay_test.cc16
-rw-r--r--tensorflow/compiler/xla/tests/reshape_test.cc308
-rw-r--r--tensorflow/compiler/xla/tests/reverse_test.cc14
-rw-r--r--tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc42
-rw-r--r--tensorflow/compiler/xla/tests/round_trip_transfer_test.cc51
-rw-r--r--tensorflow/compiler/xla/tests/scalar_computations_test.cc38
-rw-r--r--tensorflow/compiler/xla/tests/scatter_test.cc172
-rw-r--r--tensorflow/compiler/xla/tests/slice_test.cc16
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.cc74
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.h12
-rw-r--r--tensorflow/compiler/xla/tests/test_utils_test.cc16
-rw-r--r--tensorflow/compiler/xla/tests/token_hlo_test.cc20
-rw-r--r--tensorflow/compiler/xla/tests/transfer_manager_test.cc204
-rw-r--r--tensorflow/compiler/xla/tests/tuple_test.cc152
-rw-r--r--tensorflow/compiler/xla/tests/unary_op_test.cc18
-rw-r--r--tensorflow/compiler/xla/tests/while_test.cc66
-rw-r--r--tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc4
-rw-r--r--tensorflow/compiler/xla/text_literal_reader.cc11
-rw-r--r--tensorflow/compiler/xla/text_literal_reader.h4
-rw-r--r--tensorflow/compiler/xla/text_literal_reader_test.cc17
-rw-r--r--tensorflow/compiler/xla/text_literal_writer_test.cc2
-rw-r--r--tensorflow/compiler/xla/tools/replay_computation.cc17
-rw-r--r--tensorflow/compiler/xla/tools/show_literal.cc4
-rw-r--r--tensorflow/compiler/xla/tools/show_text_literal.cc16
-rw-r--r--tensorflow/compiler/xla/xla_data.proto3
-rw-r--r--tensorflow/compiler/xrt/kernels/xrt_state_ops.h10
-rw-r--r--tensorflow/compiler/xrt/tests/raw_api_test.cc36
-rw-r--r--tensorflow/compiler/xrt/xrt_state.cc2
-rw-r--r--tensorflow/compiler/xrt/xrt_state.h2
-rw-r--r--tensorflow/contrib/BUILD8
-rw-r--r--tensorflow/contrib/autograph/BUILD8
-rw-r--r--tensorflow/contrib/autograph/README.md7
-rw-r--r--tensorflow/contrib/autograph/__init__.py50
-rw-r--r--tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc10
-rw-r--r--tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc9
-rw-r--r--tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc10
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py2
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py10
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py32
-rw-r--r--tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py4
-rw-r--r--tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py8
-rw-r--r--tensorflow/contrib/checkpoint/__init__.py2
-rw-r--r--tensorflow/contrib/checkpoint/python/python_state.py40
-rw-r--r--tensorflow/contrib/checkpoint/python/python_state_test.py5
-rw-r--r--tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py4
-rw-r--r--tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py4
-rw-r--r--tensorflow/contrib/cmake/README.md4
-rw-r--r--tensorflow/contrib/cmake/external/png.cmake3
-rw-r--r--tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py4
-rw-r--r--tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py10
-rw-r--r--tensorflow/contrib/crf/python/kernel_tests/crf_test.py20
-rw-r--r--tensorflow/contrib/data/__init__.py15
-rw-r--r--tensorflow/contrib/data/kernels/csv_dataset_op.cc3
-rw-r--r--tensorflow/contrib/data/ops/dataset_ops.cc8
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/BUILD19
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py54
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/bucketing_test.py276
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py123
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py6
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py4
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py6
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py36
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py8
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py26
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/BUILD54
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py64
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py4
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py177
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py)37
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py28
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py4
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/resample_test.py6
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py6
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py14
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py64
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py25
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py10
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/test_utils.py11
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py22
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py6
-rw-r--r--tensorflow/contrib/data/python/ops/BUILD1
-rw-r--r--tensorflow/contrib/data/python/ops/grouping.py9
-rw-r--r--tensorflow/contrib/data/python/ops/optimization.py44
-rw-r--r--tensorflow/contrib/data/python/ops/readers.py39
-rw-r--r--tensorflow/contrib/distribute/README.md2
-rw-r--r--tensorflow/contrib/distribute/python/BUILD22
-rw-r--r--tensorflow/contrib/distribute/python/combinations.py20
-rw-r--r--tensorflow/contrib/distribute/python/examples/keras_mnist.py1
-rw-r--r--tensorflow/contrib/distribute/python/keras_test.py289
-rw-r--r--tensorflow/contrib/distribute/python/minimize_loss_test.py4
-rw-r--r--tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py16
-rw-r--r--tensorflow/contrib/distribute/python/single_loss_example.py6
-rw-r--r--tensorflow/contrib/distributions/BUILD1
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py2
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/affine.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/distribution_util.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_tril.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/wishart.py2
-rw-r--r--tensorflow/contrib/eager/python/evaluator_test.py4
-rw-r--r--tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb2
-rw-r--r--tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb2
-rw-r--r--tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb2
-rw-r--r--tensorflow/contrib/eager/python/examples/rnn_colorbot/README.md2
-rw-r--r--tensorflow/contrib/eager/python/examples/scan/BUILD25
-rw-r--r--tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py54
-rw-r--r--tensorflow/contrib/eager/python/examples/scan/scan_test.py54
-rw-r--r--tensorflow/contrib/eager/python/metrics_test.py4
-rw-r--r--tensorflow/contrib/estimator/BUILD56
-rw-r--r--tensorflow/contrib/estimator/__init__.py3
-rw-r--r--tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py434
-rw-r--r--tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations_test.py611
-rw-r--r--tensorflow/contrib/factorization/python/ops/factorization_ops_test.py16
-rw-r--r--tensorflow/contrib/factorization/python/ops/gmm_ops_test.py6
-rw-r--r--tensorflow/contrib/factorization/python/ops/kmeans_test.py2
-rw-r--r--tensorflow/contrib/factorization/python/ops/wals_test.py8
-rw-r--r--tensorflow/contrib/ffmpeg/ffmpeg_ops.py4
-rw-r--r--tensorflow/contrib/framework/python/framework/checkpoint_utils_test.py18
-rw-r--r--tensorflow/contrib/framework/python/framework/tensor_util_test.py20
-rw-r--r--tensorflow/contrib/gan/python/losses/python/losses_impl_test.py52
-rw-r--r--tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py8
-rw-r--r--tensorflow/contrib/integrate/python/ops/odes_test.py34
-rw-r--r--tensorflow/contrib/layers/python/layers/embedding_ops_test.py54
-rw-r--r--tensorflow/contrib/layers/python/layers/encoders_test.py20
-rw-r--r--tensorflow/contrib/layers/python/layers/feature_column_ops_test.py206
-rw-r--r--tensorflow/contrib/layers/python/layers/feature_column_test.py26
-rw-r--r--tensorflow/contrib/layers/python/layers/layers.py4
-rw-r--r--tensorflow/contrib/layers/python/layers/layers_test.py316
-rw-r--r--tensorflow/contrib/layers/python/layers/normalization_test.py8
-rw-r--r--tensorflow/contrib/layers/python/layers/optimizers_test.py14
-rw-r--r--tensorflow/contrib/layers/python/layers/regularizers_test.py14
-rw-r--r--tensorflow/contrib/layers/python/layers/rev_block_lib_test.py10
-rw-r--r--tensorflow/contrib/layers/python/layers/summaries_test.py12
-rw-r--r--tensorflow/contrib/layers/python/layers/utils_test.py24
-rw-r--r--tensorflow/contrib/layers/python/ops/sparse_ops_test.py46
-rw-r--r--tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py4
-rw-r--r--tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py26
-rw-r--r--tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py18
-rw-r--r--tensorflow/contrib/learn/python/learn/ops/ops_test.py6
-rw-r--r--tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py6
-rw-r--r--tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py4
-rw-r--r--tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py6
-rw-r--r--tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py4
-rw-r--r--tensorflow/contrib/lite/build_def.bzl59
-rw-r--r--tensorflow/contrib/lite/builtin_ops.h1
-rw-r--r--tensorflow/contrib/lite/c/builtin_op_data.h7
-rw-r--r--tensorflow/contrib/lite/c/c_api_internal.h7
-rw-r--r--tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc4
-rw-r--r--tensorflow/contrib/lite/delegates/eager/delegate_test.cc28
-rw-r--r--tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc8
-rw-r--r--tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc19
-rw-r--r--tensorflow/contrib/lite/experimental/writer/BUILD66
-rw-r--r--tensorflow/contrib/lite/experimental/writer/enum_mapping.h116
-rw-r--r--tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc370
-rw-r--r--tensorflow/contrib/lite/experimental/writer/writer.cc41
-rw-r--r--tensorflow/contrib/lite/experimental/writer/writer_lib.cc287
-rw-r--r--tensorflow/contrib/lite/experimental/writer/writer_lib.h131
-rw-r--r--tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc62
-rw-r--r--tensorflow/contrib/lite/g3doc/_index.yaml5
-rw-r--r--tensorflow/contrib/lite/g3doc/models.md91
-rw-r--r--tensorflow/contrib/lite/interpreter.cc1
-rw-r--r--tensorflow/contrib/lite/interpreter.h13
-rw-r--r--tensorflow/contrib/lite/java/demo/README.md6
-rw-r--r--tensorflow/contrib/lite/kernels/conv_test.cc24
-rw-r--r--tensorflow/contrib/lite/kernels/depthwise_conv.cc61
-rw-r--r--tensorflow/contrib/lite/kernels/depthwise_conv_test.cc116
-rw-r--r--tensorflow/contrib/lite/kernels/elementwise.cc12
-rw-r--r--tensorflow/contrib/lite/kernels/elementwise_test.cc9
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h20
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h24
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h4
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h1077
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h24
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h28
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h86
-rw-r--r--tensorflow/contrib/lite/kernels/internal/types.h42
-rw-r--r--tensorflow/contrib/lite/kernels/register.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/test_util.cc6
-rw-r--r--tensorflow/contrib/lite/kernels/test_util.h3
-rw-r--r--tensorflow/contrib/lite/model.cc5
-rw-r--r--tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h31
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.cc6
-rw-r--r--tensorflow/contrib/lite/python/convert.py43
-rw-r--r--tensorflow/contrib/lite/python/convert_test.py12
-rw-r--r--tensorflow/contrib/lite/python/lite.py11
-rw-r--r--tensorflow/contrib/lite/python/lite_test.py22
-rw-r--r--tensorflow/contrib/lite/python/tflite_convert.py11
-rw-r--r--tensorflow/contrib/lite/schema/BUILD14
-rw-r--r--tensorflow/contrib/lite/schema/schema.fbs9
-rwxr-xr-xtensorflow/contrib/lite/schema/schema_generated.h162
-rw-r--r--tensorflow/contrib/lite/testing/BUILD7
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py34
-rw-r--r--tensorflow/contrib/lite/toco/BUILD6
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc20
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow_test.cc76
-rw-r--r--tensorflow/contrib/lite/toco/model.h5
-rw-r--r--tensorflow/contrib/lite/toco/python/toco_from_protos_test.py2
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc16
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator_test.cc2
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc70
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h3
-rw-r--r--tensorflow/contrib/lite/tools/optimize/quantize_weights.cc8
-rw-r--r--tensorflow/contrib/lite/tools/visualize.py2
-rw-r--r--tensorflow/contrib/lite/tutorials/post_training_quant.ipynb702
-rw-r--r--tensorflow/contrib/lookup/lookup_ops_test.py206
-rw-r--r--tensorflow/contrib/losses/python/losses/loss_ops_test.py214
-rw-r--r--tensorflow/contrib/makefile/proto_text_cc_files.txt113
-rw-r--r--tensorflow/contrib/makefile/proto_text_pb_cc_files.txt75
-rw-r--r--tensorflow/contrib/makefile/proto_text_pb_h_files.txt74
-rw-r--r--tensorflow/contrib/makefile/tf_op_files.txt523
-rw-r--r--tensorflow/contrib/makefile/tf_pb_text_files.txt57
-rw-r--r--tensorflow/contrib/makefile/tf_proto_files.txt77
-rw-r--r--tensorflow/contrib/meta_graph_transform/meta_graph_transform.py10
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py2
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops_test.py456
-rw-r--r--tensorflow/contrib/model_pruning/python/layers/rnn_cells_test.py4
-rw-r--r--tensorflow/contrib/nccl/BUILD21
-rw-r--r--tensorflow/contrib/nccl/kernels/nccl_rewrite.cc1
-rw-r--r--tensorflow/contrib/nearest_neighbor/python/kernel_tests/hyperplane_lsh_probes_test.py2
-rw-r--r--tensorflow/contrib/optimizer_v2/adagrad.py13
-rw-r--r--tensorflow/contrib/optimizer_v2/optimizer_v2.py6
-rw-r--r--tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py14
-rw-r--r--tensorflow/contrib/predictor/saved_model_predictor.py19
-rw-r--r--tensorflow/contrib/recurrent/python/kernel_tests/recurrent_test.py4
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py2
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py4
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py56
-rw-r--r--tensorflow/contrib/saved_model/BUILD18
-rw-r--r--tensorflow/contrib/saved_model/__init__.py2
-rw-r--r--tensorflow/contrib/saved_model/python/saved_model/__init__.py1
-rw-r--r--tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py10
-rw-r--r--tensorflow/contrib/saved_model/python/saved_model/signature_def_utils.py42
-rw-r--r--tensorflow/contrib/saved_model/python/saved_model/signature_def_utils_test.py191
-rw-r--r--tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py25
-rw-r--r--tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py175
-rw-r--r--tensorflow/contrib/slim/python/slim/data/dataset_data_provider_test.py6
-rw-r--r--tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py14
-rw-r--r--tensorflow/contrib/slim/python/slim/data/prefetch_queue_test.py8
-rw-r--r--tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py46
-rw-r--r--tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py2
-rw-r--r--tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py2
-rw-r--r--tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py2
-rw-r--r--tensorflow/contrib/solvers/python/kernel_tests/util_test.py6
-rw-r--r--tensorflow/contrib/specs/python/specs_test.py22
-rw-r--r--tensorflow/contrib/specs/python/summaries_test.py8
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/stats_ops.cc3
-rw-r--r--tensorflow/contrib/tensorrt/BUILD31
-rw-r--r--tensorflow/contrib/tensorrt/python/trt_convert.py319
-rw-r--r--tensorflow/contrib/tensorrt/python/trt_convert_test.py293
-rw-r--r--tensorflow/contrib/tensorrt/test/test_tftrt.py6
-rw-r--r--tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py28
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/ar_model.py16
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/estimators_test.py9
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/head_test.py2
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py6
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/math_utils.py4
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py23
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py2
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/state_management_test.py6
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor_test.py2
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter_test.py22
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py22
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma_test.py6
-rw-r--r--tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc5
-rw-r--r--tensorflow/contrib/tpu/profiler/op_profile.proto10
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu.py13
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_estimator.py9
-rw-r--r--tensorflow/core/BUILD18
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt34
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateQuantileStreamResource.pbtxt29
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeQuantileSummaries.pbtxt40
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceAddSummaries.pbtxt22
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceFlush.pbtxt31
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceGetBucketBoundaries.pbtxt27
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceHandleOp.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_DecodeCSV.pbtxt3
-rw-r--r--tensorflow/core/api_def/base_api/api_def_IsBoostedTreesQuantileStreamResourceInitialized.pbtxt20
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ModelDataset.pbtxt14
-rw-r--r--tensorflow/core/api_def/base_api/api_def_Substr.pbtxt6
-rw-r--r--tensorflow/core/common_runtime/copy_tensor.cc2
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc2
-rw-r--r--tensorflow/core/common_runtime/direct_session_test.cc4
-rw-r--r--tensorflow/core/common_runtime/eager/context.cc11
-rw-r--r--tensorflow/core/common_runtime/eager/context.h4
-rw-r--r--tensorflow/core/common_runtime/eager/execute.cc2
-rw-r--r--tensorflow/core/common_runtime/eager/kernel_and_device.cc16
-rw-r--r--tensorflow/core/common_runtime/eager/kernel_and_device.h11
-rw-r--r--tensorflow/core/common_runtime/eager/kernel_and_device_test.cc4
-rw-r--r--tensorflow/core/common_runtime/graph_execution_state.cc4
-rw-r--r--tensorflow/core/common_runtime/mkl_cpu_allocator.h200
-rw-r--r--tensorflow/core/common_runtime/rendezvous_util.cc1
-rw-r--r--tensorflow/core/common_runtime/single_threaded_cpu_device.h1
-rw-r--r--tensorflow/core/example/example.proto8
-rw-r--r--tensorflow/core/framework/allocator.cc9
-rw-r--r--tensorflow/core/framework/allocator.h11
-rw-r--r--tensorflow/core/framework/allocator_registry.h1
-rw-r--r--tensorflow/core/framework/attr_value_util_test.cc1
-rw-r--r--tensorflow/core/framework/dataset.h108
-rw-r--r--tensorflow/core/framework/function.cc11
-rw-r--r--tensorflow/core/framework/function.h4
-rw-r--r--tensorflow/core/framework/function_testlib.cc16
-rw-r--r--tensorflow/core/framework/function_testlib.h3
-rw-r--r--tensorflow/core/framework/model.cc396
-rw-r--r--tensorflow/core/framework/model.h396
-rw-r--r--tensorflow/core/framework/model.proto30
-rw-r--r--tensorflow/core/framework/resource_mgr.cc2
-rw-r--r--tensorflow/core/framework/resource_mgr.h6
-rw-r--r--tensorflow/core/framework/tensor.h3
-rw-r--r--tensorflow/core/framework/tensor_test.cc1
-rw-r--r--tensorflow/core/framework/tensor_util.h1
-rw-r--r--tensorflow/core/framework/types.h3
-rw-r--r--tensorflow/core/framework/variant.cc25
-rw-r--r--tensorflow/core/framework/variant.h60
-rw-r--r--tensorflow/core/framework/variant_encode_decode.h32
-rw-r--r--tensorflow/core/framework/variant_op_copy_test.cc6
-rw-r--r--tensorflow/core/framework/variant_op_registry.cc85
-rw-r--r--tensorflow/core/framework/variant_op_registry.h216
-rw-r--r--tensorflow/core/framework/variant_op_registry_test.cc96
-rw-r--r--tensorflow/core/framework/variant_tensor_data.cc22
-rw-r--r--tensorflow/core/framework/variant_tensor_data.h10
-rw-r--r--tensorflow/core/framework/variant_test.cc15
-rw-r--r--tensorflow/core/graph/graph_constructor.cc4
-rw-r--r--tensorflow/core/graph/mkl_layout_pass.cc1
-rw-r--r--tensorflow/core/grappler/costs/graph_properties.cc199
-rw-r--r--tensorflow/core/grappler/costs/graph_properties_test.cc385
-rw-r--r--tensorflow/core/grappler/costs/utils.cc8
-rw-r--r--tensorflow/core/grappler/costs/utils.h2
-rw-r--r--tensorflow/core/grappler/costs/virtual_scheduler_test.cc8
-rw-r--r--tensorflow/core/grappler/inputs/utils.cc7
-rw-r--r--tensorflow/core/grappler/inputs/utils.h4
-rw-r--r--tensorflow/core/grappler/op_types.cc2
-rw-r--r--tensorflow/core/grappler/optimizers/BUILD65
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc4
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.cc2
-rw-r--r--tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc93
-rw-r--r--tensorflow/core/grappler/optimizers/experimental_implementation_selector.h115
-rw-r--r--tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc139
-rw-r--r--tensorflow/core/grappler/optimizers/function_api_info.cc167
-rw-r--r--tensorflow/core/grappler/optimizers/function_api_info.h80
-rw-r--r--tensorflow/core/grappler/optimizers/function_api_info_test.cc160
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.cc27
-rw-r--r--tensorflow/core/grappler/utils/scc.h7
-rw-r--r--tensorflow/core/kernels/BUILD20
-rw-r--r--tensorflow/core/kernels/boosted_trees/BUILD16
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantile_ops.cc453
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantiles/BUILD4
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h96
-rw-r--r--tensorflow/core/kernels/conv_3d.h43
-rw-r--r--tensorflow/core/kernels/conv_grad_ops.cc11
-rw-r--r--tensorflow/core/kernels/conv_grad_ops.h10
-rw-r--r--tensorflow/core/kernels/conv_grad_ops_3d.cc1324
-rw-r--r--tensorflow/core/kernels/data/BUILD14
-rw-r--r--tensorflow/core/kernels/data/batch_dataset_op.cc1
-rw-r--r--tensorflow/core/kernels/data/cache_dataset_ops.cc4
-rw-r--r--tensorflow/core/kernels/data/captured_function.cc66
-rw-r--r--tensorflow/core/kernels/data/captured_function.h3
-rw-r--r--tensorflow/core/kernels/data/iterator_ops.cc4
-rw-r--r--tensorflow/core/kernels/data/map_and_batch_dataset_op.cc12
-rw-r--r--tensorflow/core/kernels/data/map_defun_op.cc234
-rw-r--r--tensorflow/core/kernels/data/model_dataset_op.cc127
-rw-r--r--tensorflow/core/kernels/data/optional_ops.cc7
-rw-r--r--tensorflow/core/kernels/data/padded_batch_dataset_op.cc1
-rw-r--r--tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc38
-rw-r--r--tensorflow/core/kernels/data/parallel_map_dataset_op.cc39
-rw-r--r--tensorflow/core/kernels/data/parallel_map_iterator.cc16
-rw-r--r--tensorflow/core/kernels/data/prefetch_autotuner.cc13
-rw-r--r--tensorflow/core/kernels/data/prefetch_dataset_op.cc47
-rw-r--r--tensorflow/core/kernels/decode_bmp_op.cc7
-rw-r--r--tensorflow/core/kernels/decode_csv_op.cc3
-rw-r--r--tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h311
-rw-r--r--tensorflow/core/kernels/eigen_backward_spatial_convolutions.h41
-rw-r--r--tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc31
-rw-r--r--tensorflow/core/kernels/eigen_cuboid_convolution.h1356
-rw-r--r--tensorflow/core/kernels/gather_functor.h1
-rw-r--r--tensorflow/core/kernels/list_kernels.cc12
-rw-r--r--tensorflow/core/kernels/list_kernels.cu.cc3
-rw-r--r--tensorflow/core/kernels/partitioned_function_ops.cc4
-rw-r--r--tensorflow/core/kernels/queue_ops.cc2
-rw-r--r--tensorflow/core/kernels/reduction_ops_sum.cc10
-rw-r--r--tensorflow/core/kernels/resource_variable_ops.cc2
-rw-r--r--tensorflow/core/kernels/reverse_sequence_op.cc5
-rw-r--r--tensorflow/core/kernels/shape_op_test.cc10
-rw-r--r--tensorflow/core/kernels/split_op.cc7
-rw-r--r--tensorflow/core/kernels/stack_ops.cc26
-rw-r--r--tensorflow/core/kernels/substr_op.cc50
-rw-r--r--tensorflow/core/kernels/substr_op_test.cc105
-rw-r--r--tensorflow/core/kernels/tensor_array_ops.cc2
-rw-r--r--tensorflow/core/lib/core/stringpiece.cc54
-rw-r--r--tensorflow/core/lib/core/stringpiece.h117
-rw-r--r--tensorflow/core/lib/io/record_reader.cc3
-rw-r--r--tensorflow/core/lib/io/record_reader.h8
-rw-r--r--tensorflow/core/lib/io/record_writer.cc15
-rw-r--r--tensorflow/core/lib/io/record_writer.h32
-rw-r--r--tensorflow/core/lib/io/recordio_test.cc2
-rw-r--r--tensorflow/core/lib/io/table_test.cc2
-rw-r--r--tensorflow/core/lib/io/zlib_outputbuffer.cc2
-rw-r--r--tensorflow/core/lib/io/zlib_outputbuffer.h2
-rw-r--r--tensorflow/core/lib/strings/strcat.h3
-rw-r--r--tensorflow/core/lib/wav/wav_io.cc5
-rw-r--r--tensorflow/core/ops/boosted_trees_ops.cc125
-rw-r--r--tensorflow/core/ops/compat/ops_history.v1.pbtxt184
-rw-r--r--tensorflow/core/ops/dataset_ops.cc7
-rw-r--r--tensorflow/core/ops/ops.pbtxt184
-rw-r--r--tensorflow/core/ops/parsing_ops.cc7
-rw-r--r--tensorflow/core/ops/parsing_ops_test.cc7
-rw-r--r--tensorflow/core/platform/abi.cc4
-rw-r--r--tensorflow/core/platform/abi.h3
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system.cc2
-rw-r--r--tensorflow/core/platform/cloud/retrying_file_system.h2
-rw-r--r--tensorflow/core/platform/cloud/retrying_file_system_test.cc2
-rw-r--r--tensorflow/core/platform/cord.h26
-rw-r--r--tensorflow/core/platform/default/cord.h24
-rw-r--r--tensorflow/core/platform/env_test.cc7
-rw-r--r--tensorflow/core/platform/file_system.h8
-rw-r--r--tensorflow/core/platform/hadoop/hadoop_file_system.cc2
-rw-r--r--tensorflow/core/platform/posix/posix_file_system.cc2
-rw-r--r--tensorflow/core/platform/s3/s3_file_system.cc2
-rw-r--r--tensorflow/core/platform/windows/windows_file_system.cc2
-rw-r--r--tensorflow/core/util/sparse/group_iterator.cc10
-rw-r--r--tensorflow/core/util/sparse/group_iterator.h4
-rw-r--r--tensorflow/examples/android/README.md8
-rw-r--r--tensorflow/examples/autograph/integration_tests/BUILD (renamed from tensorflow/contrib/autograph/examples/integration_tests/BUILD)0
-rw-r--r--tensorflow/examples/autograph/integration_tests/errors_test.py (renamed from tensorflow/contrib/autograph/examples/integration_tests/errors_test.py)34
-rw-r--r--tensorflow/examples/autograph/integration_tests/keras_test.py (renamed from tensorflow/contrib/autograph/examples/integration_tests/keras_test.py)2
-rw-r--r--tensorflow/examples/autograph/integration_tests/list_literals_test.py (renamed from tensorflow/contrib/autograph/examples/integration_tests/list_literals_test.py)2
-rw-r--r--tensorflow/examples/speech_commands/freeze_test.py6
-rw-r--r--tensorflow/examples/speech_commands/input_data_test.py4
-rw-r--r--tensorflow/examples/speech_commands/label_wav_test.py2
-rw-r--r--tensorflow/examples/speech_commands/models_test.py12
-rw-r--r--tensorflow/go/op/wrappers.go116
-rw-r--r--tensorflow/python/BUILD1
-rw-r--r--tensorflow/python/autograph/BUILD31
-rw-r--r--tensorflow/python/autograph/CONTRIBUTING.md (renamed from tensorflow/contrib/autograph/CONTRIBUTING.md)9
-rw-r--r--tensorflow/python/autograph/LIMITATIONS.md (renamed from tensorflow/contrib/autograph/LIMITATIONS.md)0
-rw-r--r--tensorflow/python/autograph/README.md143
-rw-r--r--tensorflow/python/autograph/STYLE_GUIDE.md (renamed from tensorflow/contrib/autograph/STYLE_GUIDE.md)0
-rw-r--r--tensorflow/python/autograph/__init__.py68
-rw-r--r--tensorflow/python/autograph/converters/BUILD (renamed from tensorflow/contrib/autograph/converters/BUILD)54
-rw-r--r--tensorflow/python/autograph/converters/__init__.py (renamed from tensorflow/contrib/autograph/converters/__init__.py)0
-rw-r--r--tensorflow/python/autograph/converters/asserts.py (renamed from tensorflow/contrib/autograph/converters/asserts.py)4
-rw-r--r--tensorflow/python/autograph/converters/asserts_test.py (renamed from tensorflow/contrib/autograph/converters/asserts_test.py)4
-rw-r--r--tensorflow/python/autograph/converters/break_statements.py (renamed from tensorflow/contrib/autograph/converters/break_statements.py)8
-rw-r--r--tensorflow/python/autograph/converters/break_statements_test.py (renamed from tensorflow/contrib/autograph/converters/break_statements_test.py)4
-rw-r--r--tensorflow/python/autograph/converters/builtin_functions.py (renamed from tensorflow/contrib/autograph/converters/builtin_functions.py)8
-rw-r--r--tensorflow/python/autograph/converters/builtin_functions_test.py (renamed from tensorflow/contrib/autograph/converters/builtin_functions_test.py)4
-rw-r--r--tensorflow/python/autograph/converters/call_trees.py (renamed from tensorflow/contrib/autograph/converters/call_trees.py)12
-rw-r--r--tensorflow/python/autograph/converters/call_trees_test.py (renamed from tensorflow/contrib/autograph/converters/call_trees_test.py)4
-rw-r--r--tensorflow/python/autograph/converters/conditional_expressions.py (renamed from tensorflow/contrib/autograph/converters/conditional_expressions.py)8
-rw-r--r--tensorflow/python/autograph/converters/conditional_expressions_test.py (renamed from tensorflow/contrib/autograph/converters/conditional_expressions_test.py)4
-rw-r--r--tensorflow/python/autograph/converters/continue_statements.py (renamed from tensorflow/contrib/autograph/converters/continue_statements.py)8
-rw-r--r--tensorflow/python/autograph/converters/continue_statements_test.py (renamed from tensorflow/contrib/autograph/converters/continue_statements_test.py)4
-rw-r--r--tensorflow/python/autograph/converters/control_flow.py (renamed from tensorflow/contrib/autograph/converters/control_flow.py)12
-rw-r--r--tensorflow/python/autograph/converters/control_flow_test.py (renamed from tensorflow/contrib/autograph/converters/control_flow_test.py)6
-rw-r--r--tensorflow/python/autograph/converters/decorators.py (renamed from tensorflow/contrib/autograph/converters/decorators.py)4
-rw-r--r--tensorflow/python/autograph/converters/decorators_test.py (renamed from tensorflow/contrib/autograph/converters/decorators_test.py)16
-rw-r--r--tensorflow/python/autograph/converters/directives.py (renamed from tensorflow/contrib/autograph/converters/directives.py)6
-rw-r--r--tensorflow/python/autograph/converters/directives_test.py (renamed from tensorflow/contrib/autograph/converters/directives_test.py)12
-rw-r--r--tensorflow/python/autograph/converters/error_handlers.py (renamed from tensorflow/contrib/autograph/converters/error_handlers.py)6
-rw-r--r--tensorflow/python/autograph/converters/error_handlers_test.py (renamed from tensorflow/contrib/autograph/converters/error_handlers_test.py)10
-rw-r--r--tensorflow/python/autograph/converters/list_comprehensions.py (renamed from tensorflow/contrib/autograph/converters/list_comprehensions.py)4
-rw-r--r--tensorflow/python/autograph/converters/list_comprehensions_test.py (renamed from tensorflow/contrib/autograph/converters/list_comprehensions_test.py)4
-rw-r--r--tensorflow/python/autograph/converters/lists.py (renamed from tensorflow/contrib/autograph/converters/lists.py)12
-rw-r--r--tensorflow/python/autograph/converters/lists_test.py (renamed from tensorflow/contrib/autograph/converters/lists_test.py)12
-rw-r--r--tensorflow/python/autograph/converters/logical_expressions.py (renamed from tensorflow/contrib/autograph/converters/logical_expressions.py)29
-rw-r--r--tensorflow/python/autograph/converters/logical_expressions_test.py (renamed from tensorflow/contrib/autograph/converters/logical_expressions_test.py)14
-rw-r--r--tensorflow/python/autograph/converters/name_scopes.py (renamed from tensorflow/contrib/autograph/converters/name_scopes.py)4
-rw-r--r--tensorflow/python/autograph/converters/name_scopes_test.py (renamed from tensorflow/contrib/autograph/converters/name_scopes_test.py)4
-rw-r--r--tensorflow/python/autograph/converters/return_statements.py (renamed from tensorflow/contrib/autograph/converters/return_statements.py)10
-rw-r--r--tensorflow/python/autograph/converters/return_statements_test.py (renamed from tensorflow/contrib/autograph/converters/return_statements_test.py)4
-rw-r--r--tensorflow/python/autograph/converters/side_effect_guards.py (renamed from tensorflow/contrib/autograph/converters/side_effect_guards.py)12
-rw-r--r--tensorflow/python/autograph/converters/side_effect_guards_test.py (renamed from tensorflow/contrib/autograph/converters/side_effect_guards_test.py)4
-rw-r--r--tensorflow/python/autograph/converters/slices.py (renamed from tensorflow/contrib/autograph/converters/slices.py)6
-rw-r--r--tensorflow/python/autograph/converters/slices_test.py (renamed from tensorflow/contrib/autograph/converters/slices_test.py)12
-rw-r--r--tensorflow/python/autograph/core/BUILD (renamed from tensorflow/contrib/autograph/core/BUILD)14
-rw-r--r--tensorflow/python/autograph/core/config.py (renamed from tensorflow/contrib/autograph/core/config.py)4
-rw-r--r--tensorflow/python/autograph/core/converter.py (renamed from tensorflow/contrib/autograph/core/converter.py)26
-rw-r--r--tensorflow/python/autograph/core/converter_testing.py (renamed from tensorflow/contrib/autograph/core/converter_testing.py)18
-rw-r--r--tensorflow/python/autograph/core/errors.py (renamed from tensorflow/contrib/autograph/core/errors.py)2
-rw-r--r--tensorflow/python/autograph/core/errors_test.py (renamed from tensorflow/contrib/autograph/core/errors_test.py)4
-rw-r--r--tensorflow/python/autograph/core/naming.py (renamed from tensorflow/contrib/autograph/core/naming.py)2
-rw-r--r--tensorflow/python/autograph/core/naming_test.py (renamed from tensorflow/contrib/autograph/core/naming_test.py)2
-rw-r--r--tensorflow/python/autograph/docs/pyfunc_dtypes.md (renamed from tensorflow/contrib/autograph/docs/pyfunc_dtypes.md)2
-rw-r--r--tensorflow/python/autograph/impl/BUILD (renamed from tensorflow/contrib/autograph/impl/BUILD)14
-rw-r--r--tensorflow/python/autograph/impl/api.py (renamed from tensorflow/contrib/autograph/impl/api.py)22
-rw-r--r--tensorflow/python/autograph/impl/api_test.py (renamed from tensorflow/contrib/autograph/impl/api_test.py)10
-rw-r--r--tensorflow/python/autograph/impl/conversion.py (renamed from tensorflow/contrib/autograph/impl/conversion.py)56
-rw-r--r--tensorflow/python/autograph/impl/conversion_test.py (renamed from tensorflow/contrib/autograph/impl/conversion_test.py)10
-rw-r--r--tensorflow/python/autograph/lang/BUILD (renamed from tensorflow/contrib/autograph/lang/BUILD)2
-rw-r--r--tensorflow/python/autograph/lang/directives.py (renamed from tensorflow/contrib/autograph/lang/directives.py)0
-rw-r--r--tensorflow/python/autograph/lang/special_functions.py (renamed from tensorflow/contrib/autograph/lang/special_functions.py)2
-rw-r--r--tensorflow/python/autograph/lang/special_functions_test.py (renamed from tensorflow/contrib/autograph/lang/special_functions_test.py)2
-rw-r--r--tensorflow/python/autograph/operators/BUILD (renamed from tensorflow/contrib/autograph/operators/BUILD)3
-rw-r--r--tensorflow/python/autograph/operators/__init__.py (renamed from tensorflow/contrib/autograph/operators/__init__.py)32
-rw-r--r--tensorflow/python/autograph/operators/control_flow.py (renamed from tensorflow/contrib/autograph/operators/control_flow.py)2
-rw-r--r--tensorflow/python/autograph/operators/control_flow_test.py (renamed from tensorflow/contrib/autograph/operators/control_flow_test.py)2
-rw-r--r--tensorflow/python/autograph/operators/data_structures.py (renamed from tensorflow/contrib/autograph/operators/data_structures.py)0
-rw-r--r--tensorflow/python/autograph/operators/data_structures_test.py (renamed from tensorflow/contrib/autograph/operators/data_structures_test.py)2
-rw-r--r--tensorflow/python/autograph/operators/dispatch_context.py (renamed from tensorflow/contrib/autograph/operators/dispatch_context.py)0
-rw-r--r--tensorflow/python/autograph/operators/py_builtins.py (renamed from tensorflow/contrib/autograph/operators/py_builtins.py)4
-rw-r--r--tensorflow/python/autograph/operators/py_builtins_test.py (renamed from tensorflow/contrib/autograph/operators/py_builtins_test.py)4
-rw-r--r--tensorflow/python/autograph/operators/slices.py (renamed from tensorflow/contrib/autograph/operators/slices.py)9
-rw-r--r--tensorflow/python/autograph/operators/slices_test.py (renamed from tensorflow/contrib/autograph/operators/slices_test.py)17
-rw-r--r--tensorflow/python/autograph/pyct/BUILD (renamed from tensorflow/contrib/autograph/pyct/BUILD)0
-rw-r--r--tensorflow/python/autograph/pyct/__init__.py (renamed from tensorflow/contrib/autograph/pyct/__init__.py)0
-rw-r--r--tensorflow/python/autograph/pyct/anno.py (renamed from tensorflow/contrib/autograph/pyct/anno.py)0
-rw-r--r--tensorflow/python/autograph/pyct/anno_test.py (renamed from tensorflow/contrib/autograph/pyct/anno_test.py)2
-rw-r--r--tensorflow/python/autograph/pyct/ast_util.py (renamed from tensorflow/contrib/autograph/pyct/ast_util.py)4
-rw-r--r--tensorflow/python/autograph/pyct/ast_util_test.py (renamed from tensorflow/contrib/autograph/pyct/ast_util_test.py)10
-rw-r--r--tensorflow/python/autograph/pyct/cfg.py (renamed from tensorflow/contrib/autograph/pyct/cfg.py)2
-rw-r--r--tensorflow/python/autograph/pyct/cfg_test.py (renamed from tensorflow/contrib/autograph/pyct/cfg_test.py)4
-rw-r--r--tensorflow/python/autograph/pyct/common_transformers/BUILD (renamed from tensorflow/contrib/autograph/pyct/common_transformers/BUILD)2
-rw-r--r--tensorflow/python/autograph/pyct/common_transformers/__init__.py (renamed from tensorflow/contrib/autograph/pyct/common_transformers/__init__.py)0
-rw-r--r--tensorflow/python/autograph/pyct/common_transformers/anf.py (renamed from tensorflow/contrib/autograph/pyct/common_transformers/anf.py)4
-rw-r--r--tensorflow/python/autograph/pyct/common_transformers/anf_test.py (renamed from tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py)8
-rw-r--r--tensorflow/python/autograph/pyct/compiler.py (renamed from tensorflow/contrib/autograph/pyct/compiler.py)2
-rw-r--r--tensorflow/python/autograph/pyct/compiler_test.py (renamed from tensorflow/contrib/autograph/pyct/compiler_test.py)4
-rw-r--r--tensorflow/python/autograph/pyct/inspect_utils.py (renamed from tensorflow/contrib/autograph/pyct/inspect_utils.py)0
-rw-r--r--tensorflow/python/autograph/pyct/inspect_utils_test.py (renamed from tensorflow/contrib/autograph/pyct/inspect_utils_test.py)2
-rw-r--r--tensorflow/python/autograph/pyct/origin_info.py (renamed from tensorflow/contrib/autograph/pyct/origin_info.py)6
-rw-r--r--tensorflow/python/autograph/pyct/origin_info_test.py (renamed from tensorflow/contrib/autograph/pyct/origin_info_test.py)8
-rw-r--r--tensorflow/python/autograph/pyct/parser.py (renamed from tensorflow/contrib/autograph/pyct/parser.py)0
-rw-r--r--tensorflow/python/autograph/pyct/parser_test.py (renamed from tensorflow/contrib/autograph/pyct/parser_test.py)2
-rw-r--r--tensorflow/python/autograph/pyct/pretty_printer.py (renamed from tensorflow/contrib/autograph/pyct/pretty_printer.py)0
-rw-r--r--tensorflow/python/autograph/pyct/pretty_printer_test.py (renamed from tensorflow/contrib/autograph/pyct/pretty_printer_test.py)2
-rw-r--r--tensorflow/python/autograph/pyct/qual_names.py (renamed from tensorflow/contrib/autograph/pyct/qual_names.py)4
-rw-r--r--tensorflow/python/autograph/pyct/qual_names_test.py (renamed from tensorflow/contrib/autograph/pyct/qual_names_test.py)10
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/BUILD (renamed from tensorflow/contrib/autograph/pyct/static_analysis/BUILD)16
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/__init__.py (renamed from tensorflow/contrib/autograph/pyct/static_analysis/__init__.py)0
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/activity.py (renamed from tensorflow/contrib/autograph/pyct/static_analysis/activity.py)8
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/activity_test.py (renamed from tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py)14
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/annos.py (renamed from tensorflow/contrib/autograph/pyct/static_analysis/annos.py)0
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/live_values.py (renamed from tensorflow/contrib/autograph/pyct/static_analysis/live_values.py)6
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/live_values_test.py (renamed from tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py)18
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/liveness.py (renamed from tensorflow/contrib/autograph/pyct/static_analysis/liveness.py)8
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/liveness_test.py (renamed from tensorflow/contrib/autograph/pyct/static_analysis/liveness_test.py)14
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/reaching_definitions.py (renamed from tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py)8
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py (renamed from tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions_test.py)14
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/type_info.py (renamed from tensorflow/contrib/autograph/pyct/static_analysis/type_info.py)6
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/type_info_test.py (renamed from tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py)18
-rw-r--r--tensorflow/python/autograph/pyct/templates.py (renamed from tensorflow/contrib/autograph/pyct/templates.py)14
-rw-r--r--tensorflow/python/autograph/pyct/templates_test.py (renamed from tensorflow/contrib/autograph/pyct/templates_test.py)42
-rw-r--r--tensorflow/python/autograph/pyct/testing/BUILD (renamed from tensorflow/contrib/autograph/pyct/testing/BUILD)6
-rw-r--r--tensorflow/python/autograph/pyct/testing/codegen.py (renamed from tensorflow/contrib/autograph/pyct/testing/codegen.py)2
-rw-r--r--tensorflow/python/autograph/pyct/testing/codegen_test.py (renamed from tensorflow/contrib/autograph/pyct/testing/codegen_test.py)4
-rw-r--r--tensorflow/python/autograph/pyct/transformer.py (renamed from tensorflow/contrib/autograph/pyct/transformer.py)6
-rw-r--r--tensorflow/python/autograph/pyct/transformer_test.py (renamed from tensorflow/contrib/autograph/pyct/transformer_test.py)6
-rw-r--r--tensorflow/python/autograph/utils/BUILD (renamed from tensorflow/contrib/autograph/utils/BUILD)2
-rw-r--r--tensorflow/python/autograph/utils/__init__.py (renamed from tensorflow/contrib/autograph/utils/__init__.py)16
-rw-r--r--tensorflow/python/autograph/utils/context_managers.py (renamed from tensorflow/contrib/autograph/utils/context_managers.py)0
-rw-r--r--tensorflow/python/autograph/utils/context_managers_test.py (renamed from tensorflow/contrib/autograph/utils/context_managers_test.py)2
-rw-r--r--tensorflow/python/autograph/utils/misc.py (renamed from tensorflow/contrib/autograph/utils/misc.py)0
-rw-r--r--tensorflow/python/autograph/utils/misc_test.py (renamed from tensorflow/contrib/autograph/utils/misc_test.py)6
-rw-r--r--tensorflow/python/autograph/utils/multiple_dispatch.py (renamed from tensorflow/contrib/autograph/utils/multiple_dispatch.py)12
-rw-r--r--tensorflow/python/autograph/utils/multiple_dispatch_test.py (renamed from tensorflow/contrib/autograph/utils/multiple_dispatch_test.py)31
-rw-r--r--tensorflow/python/autograph/utils/py_func.py (renamed from tensorflow/contrib/autograph/utils/py_func.py)0
-rw-r--r--tensorflow/python/autograph/utils/py_func_test.py (renamed from tensorflow/contrib/autograph/utils/py_func_test.py)10
-rw-r--r--tensorflow/python/autograph/utils/tensor_list.py (renamed from tensorflow/contrib/autograph/utils/tensor_list.py)0
-rw-r--r--tensorflow/python/autograph/utils/tensor_list_test.py (renamed from tensorflow/contrib/autograph/utils/tensor_list_test.py)10
-rw-r--r--tensorflow/python/autograph/utils/tensors.py (renamed from tensorflow/contrib/autograph/utils/tensors.py)0
-rw-r--r--tensorflow/python/autograph/utils/tensors_test.py (renamed from tensorflow/contrib/autograph/utils/tensors_test.py)2
-rw-r--r--tensorflow/python/autograph/utils/testing.py (renamed from tensorflow/contrib/autograph/utils/testing.py)0
-rw-r--r--tensorflow/python/autograph/utils/type_check.py (renamed from tensorflow/contrib/autograph/utils/type_check.py)0
-rw-r--r--tensorflow/python/autograph/utils/type_check_test.py (renamed from tensorflow/contrib/autograph/utils/type_check_test.py)2
-rw-r--r--tensorflow/python/client/session_test.py4
-rw-r--r--tensorflow/python/client/timeline_test.py2
-rw-r--r--tensorflow/python/compat/compat.py2
-rw-r--r--tensorflow/python/data/kernel_tests/batch_dataset_op_test.py22
-rw-r--r--tensorflow/python/data/kernel_tests/cache_dataset_op_test.py14
-rw-r--r--tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py4
-rw-r--r--tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py16
-rw-r--r--tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py28
-rw-r--r--tensorflow/python/data/kernel_tests/dataset_ops_test.py2
-rw-r--r--tensorflow/python/data/kernel_tests/filter_dataset_op_test.py14
-rw-r--r--tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py8
-rw-r--r--tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py18
-rw-r--r--tensorflow/python/data/kernel_tests/map_dataset_op_test.py47
-rw-r--r--tensorflow/python/data/kernel_tests/optional_ops_test.py2
-rw-r--r--tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py4
-rw-r--r--tensorflow/python/data/kernel_tests/range_dataset_op_test.py16
-rw-r--r--tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py26
-rw-r--r--tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py10
-rw-r--r--tensorflow/python/data/kernel_tests/shard_dataset_op_test.py14
-rw-r--r--tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py12
-rw-r--r--tensorflow/python/data/kernel_tests/zip_dataset_op_test.py4
-rw-r--r--tensorflow/python/data/util/convert_test.py16
-rw-r--r--tensorflow/python/data/util/nest.py1
-rw-r--r--tensorflow/python/data/util/sparse_test.py2
-rw-r--r--tensorflow/python/eager/BUILD1
-rw-r--r--tensorflow/python/eager/benchmarks_test.py20
-rw-r--r--tensorflow/python/eager/function.py128
-rw-r--r--tensorflow/python/eager/function_test.py223
-rw-r--r--tensorflow/python/eager/graph_only_ops_test.py4
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc312
-rw-r--r--tensorflow/python/eager/pywrap_tfe_test.py25
-rw-r--r--tensorflow/python/eager/tape_test.py4
-rw-r--r--tensorflow/python/estimator/BUILD4
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees_test.py16
-rw-r--r--tensorflow/python/estimator/canned/head_test.py208
-rw-r--r--tensorflow/python/estimator/inputs/numpy_io_test.py34
-rw-r--r--tensorflow/python/estimator/inputs/pandas_io_test.py24
-rw-r--r--tensorflow/python/estimator/keras_test.py166
-rw-r--r--tensorflow/python/feature_column/feature_column.py25
-rw-r--r--tensorflow/python/feature_column/feature_column_test.py34
-rw-r--r--tensorflow/python/framework/error_interpolation.py2
-rw-r--r--tensorflow/python/framework/error_interpolation_test.py7
-rw-r--r--tensorflow/python/framework/file_system_test.py2
-rw-r--r--tensorflow/python/framework/function_test.py10
-rw-r--r--tensorflow/python/framework/importer_test.py18
-rw-r--r--tensorflow/python/framework/meta_graph_test.py9
-rw-r--r--tensorflow/python/framework/ops.py20
-rw-r--r--tensorflow/python/framework/ops_test.py50
-rw-r--r--tensorflow/python/framework/python_op_gen_internal.cc13
-rw-r--r--tensorflow/python/framework/sparse_tensor_test.py6
-rw-r--r--tensorflow/python/framework/subscribe_test.py14
-rw-r--r--tensorflow/python/framework/tensor_util_test.py2
-rw-r--r--tensorflow/python/framework/test_util.py82
-rwxr-xr-xtensorflow/python/keras/BUILD5
-rw-r--r--tensorflow/python/keras/backend.py23
-rw-r--r--tensorflow/python/keras/backend_test.py2
-rw-r--r--tensorflow/python/keras/callbacks_test.py34
-rw-r--r--tensorflow/python/keras/engine/distributed_training_utils.py77
-rw-r--r--tensorflow/python/keras/engine/network.py5
-rw-r--r--tensorflow/python/keras/engine/saving.py2
-rw-r--r--tensorflow/python/keras/engine/saving_test.py38
-rw-r--r--tensorflow/python/keras/engine/sequential_test.py4
-rw-r--r--tensorflow/python/keras/engine/topology_test.py36
-rw-r--r--tensorflow/python/keras/engine/training.py282
-rw-r--r--tensorflow/python/keras/engine/training_distributed.py356
-rw-r--r--tensorflow/python/keras/engine/training_test.py134
-rw-r--r--tensorflow/python/keras/engine/training_utils.py12
-rw-r--r--tensorflow/python/keras/layers/convolutional.py71
-rw-r--r--tensorflow/python/keras/layers/convolutional_test.py4
-rw-r--r--tensorflow/python/keras/layers/gru_test.py8
-rw-r--r--tensorflow/python/keras/layers/lstm_test.py22
-rw-r--r--tensorflow/python/keras/layers/recurrent_test.py4
-rw-r--r--tensorflow/python/keras/layers/simplernn_test.py8
-rw-r--r--tensorflow/python/keras/model_subclassing_test.py14
-rw-r--r--tensorflow/python/keras/models.py2
-rw-r--r--tensorflow/python/keras/optimizers_test.py20
-rw-r--r--tensorflow/python/keras/testing_utils.py6
-rw-r--r--tensorflow/python/keras/utils/conv_utils.py6
-rw-r--r--tensorflow/python/keras/utils/data_utils.py9
-rw-r--r--tensorflow/python/keras/utils/layer_utils.py1
-rw-r--r--tensorflow/python/kernel_tests/BUILD4
-rw-r--r--tensorflow/python/kernel_tests/accumulate_n_test.py12
-rw-r--r--tensorflow/python/kernel_tests/ackermann_test.py2
-rw-r--r--tensorflow/python/kernel_tests/argmax_op_test.py6
-rw-r--r--tensorflow/python/kernel_tests/array_ops_test.py56
-rw-r--r--tensorflow/python/kernel_tests/as_string_op_test.py12
-rw-r--r--tensorflow/python/kernel_tests/atrous_convolution_test.py2
-rw-r--r--tensorflow/python/kernel_tests/attention_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/barrier_ops_test.py32
-rw-r--r--tensorflow/python/kernel_tests/base64_ops_test.py6
-rw-r--r--tensorflow/python/kernel_tests/basic_gpu_test.py4
-rw-r--r--tensorflow/python/kernel_tests/batch_gather_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/batchtospace_op_test.py6
-rw-r--r--tensorflow/python/kernel_tests/bcast_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/betainc_op_test.py12
-rw-r--r--tensorflow/python/kernel_tests/bincount_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/boosted_trees/BUILD13
-rw-r--r--tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py22
-rw-r--r--tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py140
-rw-r--r--tensorflow/python/kernel_tests/boosted_trees/resource_ops_test.py6
-rw-r--r--tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py20
-rw-r--r--tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py20
-rw-r--r--tensorflow/python/kernel_tests/broadcast_to_ops_test.py44
-rw-r--r--tensorflow/python/kernel_tests/candidate_sampler_ops_test.py12
-rw-r--r--tensorflow/python/kernel_tests/cast_op_test.py10
-rw-r--r--tensorflow/python/kernel_tests/check_ops_test.py13
-rw-r--r--tensorflow/python/kernel_tests/checkpoint_ops_test.py32
-rw-r--r--tensorflow/python/kernel_tests/clip_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/concat_op_test.py28
-rw-r--r--tensorflow/python/kernel_tests/cond_v2_test.py15
-rw-r--r--tensorflow/python/kernel_tests/conditional_accumulator_test.py38
-rw-r--r--tensorflow/python/kernel_tests/confusion_matrix_test.py28
-rw-r--r--tensorflow/python/kernel_tests/constant_op_test.py52
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py323
-rw-r--r--tensorflow/python/kernel_tests/conv1d_test.py2
-rw-r--r--tensorflow/python/kernel_tests/conv2d_backprop_filter_grad_test.py2
-rw-r--r--tensorflow/python/kernel_tests/conv2d_transpose_test.py8
-rw-r--r--tensorflow/python/kernel_tests/conv3d_backprop_filter_v2_grad_test.py2
-rw-r--r--tensorflow/python/kernel_tests/conv3d_transpose_test.py10
-rw-r--r--tensorflow/python/kernel_tests/conv_ops_3d_test.py4
-rw-r--r--tensorflow/python/kernel_tests/conv_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/cross_grad_test.py2
-rw-r--r--tensorflow/python/kernel_tests/cwise_ops_test.py56
-rw-r--r--tensorflow/python/kernel_tests/decode_bmp_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/decode_compressed_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/decode_csv_op_test.py55
-rw-r--r--tensorflow/python/kernel_tests/decode_image_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/decode_png_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/decode_raw_op_test.py14
-rw-r--r--tensorflow/python/kernel_tests/dense_update_ops_no_tsan_test.py8
-rw-r--r--tensorflow/python/kernel_tests/dense_update_ops_test.py6
-rw-r--r--tensorflow/python/kernel_tests/depthwise_conv_op_test.py13
-rw-r--r--tensorflow/python/kernel_tests/division_future_test.py2
-rw-r--r--tensorflow/python/kernel_tests/division_past_test.py2
-rw-r--r--tensorflow/python/kernel_tests/duplicate_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/dynamic_partition_op_test.py8
-rw-r--r--tensorflow/python/kernel_tests/dynamic_stitch_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/embedding_ops_test.py60
-rw-r--r--tensorflow/python/kernel_tests/extract_image_patches_grad_test.py2
-rw-r--r--tensorflow/python/kernel_tests/fft_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/fifo_queue_test.py128
-rw-r--r--tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py18
-rw-r--r--tensorflow/python/kernel_tests/fractional_max_pool_op_test.py18
-rw-r--r--tensorflow/python/kernel_tests/gather_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/gradient_correctness_test.py8
-rw-r--r--tensorflow/python/kernel_tests/identity_n_op_py_test.py8
-rw-r--r--tensorflow/python/kernel_tests/identity_op_py_test.py10
-rw-r--r--tensorflow/python/kernel_tests/in_topk_op_test.py6
-rw-r--r--tensorflow/python/kernel_tests/init_ops_test.py2
-rw-r--r--tensorflow/python/kernel_tests/inplace_ops_test.py2
-rw-r--r--tensorflow/python/kernel_tests/io_ops_test.py8
-rw-r--r--tensorflow/python/kernel_tests/linalg/BUILD16
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_addition_test.py412
-rw-r--r--tensorflow/python/kernel_tests/linalg_grad_test.py2
-rw-r--r--tensorflow/python/kernel_tests/linalg_ops_test.py2
-rw-r--r--tensorflow/python/kernel_tests/listdiff_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/logging_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/lookup_ops_test.py156
-rw-r--r--tensorflow/python/kernel_tests/losses_test.py216
-rw-r--r--tensorflow/python/kernel_tests/manip_ops_test.py16
-rw-r--r--tensorflow/python/kernel_tests/matmul_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/matrix_inverse_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py6
-rw-r--r--tensorflow/python/kernel_tests/metrics_test.py258
-rw-r--r--tensorflow/python/kernel_tests/pad_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/padding_fifo_queue_test.py124
-rw-r--r--tensorflow/python/kernel_tests/parse_single_example_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/parsing_ops_test.py18
-rw-r--r--tensorflow/python/kernel_tests/partitioned_variables_test.py40
-rw-r--r--tensorflow/python/kernel_tests/priority_queue_test.py20
-rw-r--r--tensorflow/python/kernel_tests/reader_ops_test.py36
-rw-r--r--tensorflow/python/kernel_tests/record_input_test.py14
-rw-r--r--tensorflow/python/kernel_tests/reduce_join_op_test.py16
-rw-r--r--tensorflow/python/kernel_tests/reduction_ops_test.py30
-rw-r--r--tensorflow/python/kernel_tests/regex_full_match_op_test.py6
-rw-r--r--tensorflow/python/kernel_tests/regex_replace_op_test.py27
-rw-r--r--tensorflow/python/kernel_tests/relu_op_test.py36
-rw-r--r--tensorflow/python/kernel_tests/reshape_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/reverse_sequence_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/scatter_nd_ops_test.py32
-rw-r--r--tensorflow/python/kernel_tests/segment_reduction_ops_test.py12
-rw-r--r--tensorflow/python/kernel_tests/session_ops_test.py32
-rw-r--r--tensorflow/python/kernel_tests/sets_test.py10
-rw-r--r--tensorflow/python/kernel_tests/shape_ops_test.py34
-rw-r--r--tensorflow/python/kernel_tests/slice_op_test.py23
-rw-r--r--tensorflow/python/kernel_tests/softmax_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/softplus_op_test.py8
-rw-r--r--tensorflow/python/kernel_tests/softsign_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/spacetobatch_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py40
-rw-r--r--tensorflow/python/kernel_tests/sparse_cross_op_test.py34
-rw-r--r--tensorflow/python/kernel_tests/sparse_matmul_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/sparse_ops_test.py2
-rw-r--r--tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py16
-rw-r--r--tensorflow/python/kernel_tests/sparsemask_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/string_join_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/string_length_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/string_split_op_test.py30
-rw-r--r--tensorflow/python/kernel_tests/string_strip_op_test.py6
-rw-r--r--tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py14
-rw-r--r--tensorflow/python/kernel_tests/string_to_number_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/substr_op_test.py149
-rw-r--r--tensorflow/python/kernel_tests/summary_ops_test.py6
-rw-r--r--tensorflow/python/kernel_tests/summary_tensor_op_test.py14
-rw-r--r--tensorflow/python/kernel_tests/tensordot_op_test.py6
-rw-r--r--tensorflow/python/kernel_tests/transpose_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/unique_op_test.py20
-rw-r--r--tensorflow/python/kernel_tests/unstack_op_test.py8
-rw-r--r--tensorflow/python/kernel_tests/variable_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/variable_scope_test.py60
-rw-r--r--tensorflow/python/kernel_tests/variables_test.py58
-rw-r--r--tensorflow/python/kernel_tests/weights_broadcast_test.py8
-rw-r--r--tensorflow/python/kernel_tests/xent_op_test.py10
-rw-r--r--tensorflow/python/ops/array_grad.py19
-rw-r--r--tensorflow/python/ops/boosted_trees_ops.py6
-rw-r--r--tensorflow/python/ops/check_ops.py6
-rw-r--r--tensorflow/python/ops/control_flow_ops.py4
-rw-r--r--tensorflow/python/ops/ctc_ops.py6
-rw-r--r--tensorflow/python/ops/distributions/bijector_impl.py6
-rw-r--r--tensorflow/python/ops/distributions/categorical.py4
-rw-r--r--tensorflow/python/ops/gradients_impl.py8
-rw-r--r--tensorflow/python/ops/gradients_test.py21
-rw-r--r--tensorflow/python/ops/linalg/linear_operator_addition.py432
-rw-r--r--tensorflow/python/ops/linalg/linear_operator_circulant.py18
-rw-r--r--tensorflow/python/ops/math_ops.py39
-rw-r--r--tensorflow/python/ops/nn_ops.py2
-rw-r--r--tensorflow/python/ops/parallel_for/control_flow_ops_test.py2
-rw-r--r--tensorflow/python/ops/parallel_for/gradients_test.py6
-rw-r--r--tensorflow/python/ops/parsing_ops.py3
-rw-r--r--tensorflow/python/ops/rnn.py4
-rw-r--r--tensorflow/python/ops/rnn_cell_impl.py2
-rw-r--r--tensorflow/python/ops/string_ops.py5
-rw-r--r--tensorflow/python/ops/summary_ops_v2.py1
-rw-r--r--tensorflow/python/platform/gfile.py2
-rwxr-xr-xtensorflow/python/pywrap_tfe.i5
-rw-r--r--tensorflow/python/saved_model/README.md15
-rw-r--r--tensorflow/python/summary/writer/event_file_writer.py2
-rw-r--r--tensorflow/python/tools/saved_model_cli.py7
-rw-r--r--tensorflow/python/training/adadelta_test.py4
-rw-r--r--tensorflow/python/training/adagrad_da_test.py10
-rw-r--r--tensorflow/python/training/adagrad_test.py16
-rw-r--r--tensorflow/python/training/adam_test.py10
-rw-r--r--tensorflow/python/training/basic_session_run_hooks_test.py10
-rw-r--r--tensorflow/python/training/checkpoint_management_test.py6
-rw-r--r--tensorflow/python/training/checkpoint_ops_test.py18
-rw-r--r--tensorflow/python/training/checkpoint_utils_test.py24
-rw-r--r--tensorflow/python/training/checkpointable/data_structures.py43
-rw-r--r--tensorflow/python/training/checkpointable/data_structures_test.py99
-rw-r--r--tensorflow/python/training/checkpointable/tracking_test.py2
-rw-r--r--tensorflow/python/training/checkpointable/util_test.py2
-rw-r--r--tensorflow/python/training/ftrl_test.py28
-rw-r--r--tensorflow/python/training/gradient_descent_test.py18
-rw-r--r--tensorflow/python/training/input_test.py94
-rw-r--r--tensorflow/python/training/learning_rate_decay_test.py2
-rw-r--r--tensorflow/python/training/momentum_test.py14
-rw-r--r--tensorflow/python/training/monitored_session_test.py58
-rw-r--r--tensorflow/python/training/moving_averages_test.py30
-rw-r--r--tensorflow/python/training/optimizer_test.py8
-rw-r--r--tensorflow/python/training/proximal_adagrad_test.py18
-rw-r--r--tensorflow/python/training/proximal_gradient_descent_test.py16
-rw-r--r--tensorflow/python/training/queue_runner_test.py26
-rw-r--r--tensorflow/python/training/rmsprop_test.py4
-rw-r--r--tensorflow/python/training/saver_test.py54
-rw-r--r--tensorflow/python/training/session_manager_test.py28
-rw-r--r--tensorflow/python/training/slot_creator_test.py14
-rw-r--r--tensorflow/python/training/supervisor_test.py6
-rw-r--r--tensorflow/python/training/warm_starting_util_test.py2
-rw-r--r--tensorflow/python/util/memory.py45
-rw-r--r--tensorflow/python/util/nest_test.py2
-rw-r--r--tensorflow/python/util/tf_inspect.py5
-rw-r--r--tensorflow/python/util/tf_should_use_test.py5
-rw-r--r--tensorflow/stream_executor/stream_executor_pimpl.cc24
-rw-r--r--tensorflow/stream_executor/stream_executor_pimpl.h7
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.feature_column.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-ordered-enqueuer.pbtxt26
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.feature_column.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-ordered-enqueuer.pbtxt26
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt8
-rw-r--r--tensorflow/tools/api/tests/api_compatibility_test.py2
-rw-r--r--tensorflow/tools/ci_build/Dockerfile.rbe.cuda9.0-cudnn7-ubuntu14.0483
-rw-r--r--tensorflow/tools/ci_build/Dockerfile.rbe.gcc.gpu43
-rwxr-xr-xtensorflow/tools/ci_build/builds/run_pip_tests.sh4
-rwxr-xr-xtensorflow/tools/ci_build/ci_parameterized_build.sh14
-rwxr-xr-xtensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh77
-rwxr-xr-xtensorflow/tools/ci_build/linux/cpu/run_mkl.sh2
-rw-r--r--tensorflow/tools/compatibility/testdata/test_file_v0_11.py16
-rw-r--r--tensorflow/tools/compatibility/testdata/test_file_v1_10.py2
-rw-r--r--tensorflow/tools/dockerfiles/README.md6
-rw-r--r--tensorflow/tools/docs/parser.py2
-rw-r--r--tensorflow/tools/pip_package/BUILD20
-rwxr-xr-xtensorflow/tools/pip_package/build_pip_package.sh3
-rwxr-xr-xtensorflow/workspace.bzl8
-rw-r--r--third_party/flatbuffers/BUILD.bazel1
-rw-r--r--third_party/flatbuffers/build_defs.bzl19
-rw-r--r--third_party/jpeg/jpeg.BUILD139
-rw-r--r--third_party/llvm/llvm.bzl172
-rw-r--r--third_party/nasm.BUILD5
-rw-r--r--third_party/toolchains/BUILD15
-rw-r--r--third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/WORKSPACE2
-rwxr-xr-xthird_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/BUILD1268
-rwxr-xr-xthird_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/build_defs.bzl33
-rwxr-xr-xthird_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/cuda/cuda_config.h26
-rwxr-xr-xthird_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/BUILD73
-rwxr-xr-xthird_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/CROSSTOOL1410
-rwxr-xr-xthird_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/clang/bin/crosstool_wrapper_driver_is_not_gcc264
-rwxr-xr-xthird_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/windows/msvc_wrapper_for_nvcc.bat20
-rwxr-xr-xthird_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/windows/msvc_wrapper_for_nvcc.py192
-rwxr-xr-xthird_party/toolchains/preconfig/ubuntu14.04/nccl2/BUILD25
-rw-r--r--third_party/toolchains/preconfig/ubuntu14.04/nccl2/WORKSPACE2
-rwxr-xr-xthird_party/toolchains/preconfig/ubuntu14.04/py3/BUILD176
-rw-r--r--third_party/toolchains/preconfig/ubuntu14.04/py3/WORKSPACE2
1138 files changed, 33805 insertions, 13790 deletions
diff --git a/RELEASE.md b/RELEASE.md
index 763ef3b279..bdc23795e5 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -1,3 +1,9 @@
+# Release 1.10.1
+## Bug Fixes and Other Changes
+
+* `tf.keras`:
+ * Fixing keras on Cloud TPUs. No new binaries will be built for Windows.
+
# Release 1.10.0
## Major Features And Improvements
diff --git a/configure.py b/configure.py
index 361bd4764d..52a513779e 100644
--- a/configure.py
+++ b/configure.py
@@ -852,7 +852,7 @@ def set_tf_cuda_version(environ_cp):
# Reset and retry
print('Invalid path to CUDA %s toolkit. %s cannot be found' %
- (tf_cuda_version, cuda_toolkit_path_full))
+ (tf_cuda_version, cuda_toolkit_paths_full))
environ_cp['TF_CUDA_VERSION'] = ''
environ_cp['CUDA_TOOLKIT_PATH'] = ''
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 173bbea596..79811ceae5 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -39,6 +39,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index c046bd66cd..c195c9e01c 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/strings/strcat.h"
diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc
index a2c5a42c11..f68f8a3e90 100644
--- a/tensorflow/c/c_api_function.cc
+++ b/tensorflow/c/c_api_function.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/strings/base64.h"
diff --git a/tensorflow/cc/framework/ops.h b/tensorflow/cc/framework/ops.h
index a085e1d6e2..0717e7dd4b 100644
--- a/tensorflow/cc/framework/ops.h
+++ b/tensorflow/cc/framework/ops.h
@@ -150,7 +150,7 @@ class Input {
Initializer(const std::initializer_list<T>& v, const TensorShape& shape) {
typedef typename RealType<T>::type RealT;
Tensor t(DataTypeToEnum<RealT>::v(), shape);
- if (t.NumElements() != v.size()) {
+ if (t.NumElements() != static_cast<int64>(v.size())) {
status = errors::InvalidArgument(
"Cannot construct a tensor with ", t.NumElements(),
" from an initializer list with ", v.size(), " elements");
diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.h b/tensorflow/compiler/aot/embedded_protocol_buffers.h
index bd270045e3..cf5c04ac4b 100644
--- a/tensorflow/compiler/aot/embedded_protocol_buffers.h
+++ b/tensorflow/compiler/aot/embedded_protocol_buffers.h
@@ -20,7 +20,6 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_
#define TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_
-#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/platform/protobuf.h"
diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD
index 8d94f5495c..7a0932d44d 100644
--- a/tensorflow/compiler/aot/tests/BUILD
+++ b/tensorflow/compiler/aot/tests/BUILD
@@ -231,6 +231,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_profile_printer",
"//tensorflow/core:lib",
+ "//tensorflow/core:regexp_internal",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//third_party/eigen3",
diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc
index dd2b151098..7ac90fb8a9 100644
--- a/tensorflow/compiler/aot/tests/tfcompile_test.cc
+++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -543,7 +544,13 @@ TEST(TFCompileTest, HloProfiling) {
string hlo_profile_as_string =
xla::PrintHloProfile(fn.hlo_profile_printer_data(), fn.profile_counters(),
/*clock_rate_ghz=*/1.0);
- VLOG(1) << "HLO profile string:\n" << hlo_profile_as_string;
+ VLOG(1) << "Original HLO profile string:\n" << hlo_profile_as_string;
+
+ // Strip away identifier details from the profile string to avoid this test
+ // being a change detector for xla internals. Identifiers such as '%dot.0.7'
+ // just become '%dot'.
+ RE2::GlobalReplace(&hlo_profile_as_string, "(%[a-zA-Z0-9]*)[.0-9]*", "\\1");
+ VLOG(1) << "Stripped HLO profile string:\n" << hlo_profile_as_string;
std::vector<string> hlo_profile_lines =
absl::StrSplit(hlo_profile_as_string, '\n');
@@ -551,16 +558,14 @@ TEST(TFCompileTest, HloProfiling) {
auto header = HasSubstr("Execution profile for");
auto total_cycles_profile_line = HasSubstr("[total]");
auto dot_profile_line = HasSubstr(
- "%dot.0.4 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
- "%arg1.0.1)");
+ "%dot = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)");
auto add_profile_line = HasSubstr(
- "%add.0.6 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
- "%arg1.0.1)");
+ "%add = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)");
auto tuple_profile_line = HasSubstr(
- "%tuple.0.8 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} "
- "%dot.0.4, f32[2,2]{1,0} %add.0.6)");
- auto arg0_profile_line = HasSubstr("%arg0.0.0 = f32[2,2]{1,0} parameter(0)");
- auto arg1_profile_line = HasSubstr("%arg1.0.1 = f32[2,2]{1,0} parameter(1)");
+ "%tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} %dot, "
+ "f32[2,2]{1,0} %add)");
+ auto arg0_profile_line = HasSubstr("%arg0 = f32[2,2]{1,0} parameter(0)");
+ auto arg1_profile_line = HasSubstr("%arg1 = f32[2,2]{1,0} parameter(1)");
EXPECT_THAT(hlo_profile_lines,
IsSupersetOf({header, total_cycles_profile_line, dot_profile_line,
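The new RE2::GlobalReplace call strips the ".N.M" identifier suffixes so the expectations above stay stable across XLA-internal renamings. A self-contained approximation of that rewrite is sketched below; it uses std::regex purely for illustration, whereas the test itself links against RE2:

#include <iostream>
#include <regex>
#include <string>

int main() {
  std::string line =
      "%dot.0.4 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} %arg1.0.1)";
  // Same pattern as the test: keep the instruction name, drop trailing ".N.M" ids.
  std::regex id_suffix("(%[a-zA-Z0-9]*)[.0-9]*");
  std::cout << std::regex_replace(line, id_suffix, "$1") << "\n";
  // Prints: %dot = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)
  return 0;
}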
diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc
index 1c9d30d7b0..b95b063348 100644
--- a/tensorflow/compiler/aot/tfcompile_main.cc
+++ b/tensorflow/compiler/aot/tfcompile_main.cc
@@ -35,7 +35,6 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
@@ -93,8 +92,9 @@ Status Main(const MainFlags& flags) {
// Write output files.
Env* env = Env::Default();
const std::vector<char>& obj = compile_result.aot->object_file_data();
- TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_function_object,
- StringPiece(obj.data(), obj.size())));
+ TF_RETURN_IF_ERROR(
+ WriteStringToFile(env, flags.out_function_object,
+ absl::string_view(obj.data(), obj.size())));
CodegenOpts codegen_opts;
codegen_opts.gen_name_to_index = flags.gen_name_to_index;
codegen_opts.gen_program_shape = flags.gen_program_shape;
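The WriteStringToFile change wraps the object-file bytes in an absl::string_view instead of a StringPiece, so the buffer is handed over without a copy. A tiny sketch of the same non-owning wrapping, shown with std::string_view since absl is not assumed here:

#include <iostream>
#include <string_view>
#include <vector>

int main() {
  // Hypothetical object-file bytes standing in for object_file_data().
  std::vector<char> obj = {'\x7f', 'E', 'L', 'F'};
  std::string_view view(obj.data(), obj.size());  // no copy, just a view
  std::cout << view.size() << " bytes, tail: " << view.substr(1) << "\n";
  return 0;
}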
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index a989f15a1c..f4e1bc5e83 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -265,6 +265,7 @@ cc_library(
srcs = ["jit_compilation_pass_registration.cc"],
deps = [
":compilation_passes",
+ "//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration",
"//tensorflow/core:core_cpu_internal",
],
alwayslink = 1,
@@ -362,6 +363,7 @@ cc_library(
"deadness_analysis.cc",
"deadness_analysis_internal.h",
"encapsulate_subgraphs_pass.cc",
+ "encapsulate_xla_computations_pass.cc",
"mark_for_compilation_pass.cc",
"mark_for_compilation_pass_test_helper.cc",
"partially_decluster_pass.cc",
@@ -370,6 +372,7 @@ cc_library(
"build_xla_launch_ops_pass.h",
"deadness_analysis.h",
"encapsulate_subgraphs_pass.h",
+ "encapsulate_xla_computations_pass.h",
"mark_for_compilation_pass.h",
"mark_for_compilation_pass_test_helper.h",
"partially_decluster_pass.h",
@@ -396,6 +399,7 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:bounds_check",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
)
@@ -474,6 +478,7 @@ tf_cc_test(
size = "small",
srcs = [
"encapsulate_subgraphs_pass_test.cc",
+ "encapsulate_xla_computations_pass_test.cc",
"mark_for_compilation_pass_test.cc",
"partially_decluster_pass_test.cc",
],
@@ -489,7 +494,9 @@ tf_cc_test(
"//tensorflow/cc:resource_variable_ops",
"//tensorflow/cc:sendrecv_ops",
"//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/tf2xla:test_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
index ae7a22f451..e0632ff7e4 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <unordered_map>
#include <vector>
+#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
@@ -58,6 +59,22 @@ const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs";
const char* const kXlaHostTransferSequencerAttr =
"_xla_host_transfer_sequencer";
+void SortControlInputs(GraphDef* gdef) {
+ int64 num_nodes = gdef->node_size();
+ for (int64 i = 0; i < num_nodes; ++i) {
+ NodeDef* node = gdef->mutable_node(i);
+ // Stable sort control inputs and leave the order of data inputs unchanged.
+ std::stable_sort(node->mutable_input()->begin(),
+ node->mutable_input()->end(),
+ [](const string& a, const string& b) {
+ bool a_is_control = absl::StartsWith(a, "^");
+ bool b_is_control = absl::StartsWith(b, "^");
+ return (!a_is_control && b_is_control) ||
+ (a_is_control && b_is_control && a < b);
+ });
+ }
+}
+
namespace {
bool AreAllParentsGuaranteedConst(
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
index 926589546f..90354a801a 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
@@ -102,6 +102,12 @@ extern const char* const kXlaNumConstantArgsAttr;
// Name of the attribute containing the number of resource variable arguments.
extern const char* const kXlaNumResourceArgsAttr;
+// Sorts each node's control inputs by their names. This guarantees that for two
+// structurally equivalent GraphDefs, we get the same traversal ordering on
+// each node's control input fields.
+// TODO(hpucha): Move the utilities to a more appropriate place.
+void SortControlInputs(GraphDef* gdef);
+
class EncapsulateSubgraphsPass : public GraphOptimizationPass {
public:
Status Run(const GraphOptimizationPassOptions& options) override;
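SortControlInputs, declared above and defined in encapsulate_subgraphs_pass.cc earlier in this diff, leaves data inputs in place and orders control inputs ("^name") after them, sorted by name. A standalone sketch of that comparator on plain strings (hypothetical input names, no NodeDef involved):

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Hypothetical NodeDef-style input list: data inputs, then control inputs.
  std::vector<std::string> inputs = {"b:0", "a:1", "^ctrl_z", "^ctrl_a"};
  std::stable_sort(inputs.begin(), inputs.end(),
                   [](const std::string& a, const std::string& b) {
                     bool a_is_control = !a.empty() && a[0] == '^';
                     bool b_is_control = !b.empty() && b[0] == '^';
                     // Data inputs keep their relative order (stable sort);
                     // control inputs go last, sorted among themselves by name.
                     return (!a_is_control && b_is_control) ||
                            (a_is_control && b_is_control && a < b);
                   });
  for (const auto& s : inputs) std::cout << s << "\n";
  // Prints: b:0, a:1, ^ctrl_a, ^ctrl_z
  return 0;
}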
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
new file mode 100644
index 0000000000..97ef8cd3cb
--- /dev/null
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
@@ -0,0 +1,360 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/lib/strings/proto_serialization.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/fingerprint.h"
+
+namespace tensorflow {
+
+const char* const EncapsulateXlaComputationsPass::kXlaClusterAttr =
+ "_xla_compile_id";
+
+namespace {
+
+const char* const kXlaClusterOutput = "XlaClusterOutput";
+
+// Checks if a graph node is marked to be a guaranteed constant.
+bool is_guaranteed_constant(const Node& n) {
+ bool guaranteed_constant = false;
+ if (!GetNodeAttr(n.attrs(), "_is_guaranteed_constant", &guaranteed_constant)
+ .ok()) {
+ return false;
+ }
+ return guaranteed_constant;
+}
+
+// Finds the `index` of an _Arg or _Retval node.
+Status GetIndexAttr(const Node& n, int num_args, int* index) {
+ TF_RETURN_IF_ERROR(GetNodeAttr(n.attrs(), "index", index));
+ if (*index < 0 || *index >= num_args) {
+ return errors::InvalidArgument("Invalid ", n.type_string(), " number ",
+ *index);
+ }
+ return Status::OK();
+}
+
+// Returns the data type of the destination of an edge.
+DataType EdgeType(const Edge* edge) {
+ return edge->dst()->input_type(edge->dst_input());
+}
+
+// Adds the control inputs of `node` to `*deps`.
+void AddControlInputs(const Node& node, gtl::FlatSet<Node*>* deps) {
+ for (const Edge* edge : node.in_edges()) {
+ if (edge->IsControlEdge()) {
+ deps->insert(edge->src());
+ }
+ }
+}
+
+// Adds the control outputs of `node` to `*deps`.
+void AddControlOutputs(const Node& node, gtl::FlatSet<Node*>* deps) {
+ for (const Edge* edge : node.out_edges()) {
+ if (edge->IsControlEdge()) {
+ deps->insert(edge->dst());
+ }
+ }
+}
+
+// Rewrite function to be passed to EncapsulateSubgraphsInFunctions that sorts
+// the arguments into the order expected by XlaLaunch computations:
+// 1) arguments
+// 2) resource variable arguments
+// See the documentation of EncapsulateSubgraphsInFunctions for the meaning
+// of the arguments.
+//
+// TODO(b/113166435): Ordering constraints on XlaLaunch op can be relaxed.
+Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
+ std::unique_ptr<Graph>* graph_ptr,
+ std::vector<int>* input_permutation,
+ std::vector<int>* output_permutation,
+ NodeDef* call_def) {
+ Graph* graph = graph_ptr->get();
+ const int num_args = input_permutation->size();
+ const int num_retvals = output_permutation->size();
+
+ std::vector<Node*> args;
+ std::vector<Node*> retvals;
+ args.reserve(num_args);
+ retvals.reserve(num_retvals);
+ for (Node* n : graph->nodes()) {
+ if (n->type_string() == "_Arg") {
+ // Check if this is a guaranteed constant.
+ if (is_guaranteed_constant(*n)) {
+ return errors::InvalidArgument(
+ "Guaranteed constants are not supported (", n->name(), ")");
+ }
+ args.push_back(n);
+ } else if (n->type_string() == "_Retval") {
+ retvals.push_back(n);
+ }
+ }
+
+ if (std::find(args.begin(), args.end(), nullptr) != args.end()) {
+ return errors::InvalidArgument("Missing or non-consecutive arguments");
+ }
+
+ // Reorders the arguments.
+ std::sort(args.begin(), args.end(), [&](Node* a, Node* b) {
+ // Non-resources appear before resources
+ bool a_is_resource = (a->output_type(0) == DT_RESOURCE);
+ bool b_is_resource = (b->output_type(0) == DT_RESOURCE);
+ // Uses the name as a tiebreaker so the output is deterministic.
+ StringPiece a_name(a->name());
+ StringPiece b_name(b->name());
+ return std::tie(a_is_resource, a_name) < std::tie(b_is_resource, b_name);
+ });
+
+ // Sorts the retvals by name so the order is deterministic.
+ std::sort(retvals.begin(), retvals.end(),
+ [](Node* a, Node* b) { return a->name() < b->name(); });
+
+ // Computes the permutation to produce the correct argument order, and updates
+ // the argument indices.
+ int variable_start_index = num_args;
+ for (int i = 0; i < num_args; ++i) {
+ int index;
+ TF_RETURN_IF_ERROR(GetIndexAttr(*args[i], num_args, &index));
+ if (args[i]->output_type(0) == DT_RESOURCE &&
+ variable_start_index == num_args) {
+ variable_start_index = i;
+ }
+ (*input_permutation)[index] = i;
+ args[i]->AddAttr("index", i);
+ }
+ VLOG(4) << "variable_start_index: " << variable_start_index;
+
+ // Computes the permutation to produce the correct retval order, and updates
+ // the retval indices.
+ for (int i = 0; i < num_retvals; ++i) {
+ int index;
+ TF_RETURN_IF_ERROR(GetIndexAttr(*retvals[i], num_retvals, &index));
+ (*output_permutation)[index] = i;
+ retvals[i]->AddAttr("index", i);
+ }
+
+ AddNodeAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, call_def->name(),
+ call_def);
+ AddNodeAttr("_variable_start_index", variable_start_index, call_def);
+
+ // Uniquify the function name.
+ GraphDef gdef;
+ graph->ToGraphDef(&gdef);
+
+ // Before serialization, sort each node's control inputs to achieve
+ // determinism. Sorting control inputs helps produce, but does not guarantee,
+ // a deterministic serialization and fingerprint. Other sources of
+ // nondeterminism include unstable node ordering.
+ SortControlInputs(&gdef);
+ // Fingerprint the function.
+ // Nondeterminism in serialization would not lead to incorrect results, but
+ // may cause spurious cache misses. SerializeToStringDeterministic is only a
+ // best-effort deterministic serialization.
+ string serialized;
+ TF_RET_CHECK(SerializeToStringDeterministic(gdef, &serialized));
+ uint64 fingerprint = Fingerprint64(serialized);
+ LOG(INFO) << "Subgraph fingerprint:" << fingerprint;
+ call_def->set_op(absl::StrCat(call_def->op(), "_", fingerprint));
+ return Status::OK();
+}
+
+} // namespace
+
+/*static*/ Status EncapsulateXlaComputationsPass::Encapsulate(
+ std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def) {
+ // Check for undeclared outputs before Encapsulation, so we can give a better
+ // error message.
+ // TODO(phawkins): merge this with the encapsulation code to avoid the extra
+ // O(n) pass over the edges.
+ for (const Edge* e : (*graph)->edges()) {
+ if (!e->IsControlEdge() &&
+ e->src()->attrs().Find(kXlaClusterAttr) != nullptr &&
+ e->dst()->attrs().Find(kXlaClusterAttr) == nullptr &&
+ e->dst()->type_string() != kXlaClusterOutput) {
+ return errors::InvalidArgument(
+ "Undeclared output of XLA computation. A common cause of this error "
+ "is variable initializers that depend on the XLA computation. Edge: ",
+ e->src()->name(), ":", e->src_output(), " -> ", e->dst()->name(), ":",
+ e->dst_input());
+ }
+ }
+
+ auto output = absl::make_unique<Graph>((*graph)->op_registry());
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ EncapsulateSubgraphsInFunctions(
+ kXlaClusterAttr, "", **graph, RewriteSubgraph,
+ /*reuse_existing_functions=*/true, &output, flib_def),
+ "EncapsulateXlaComputationsPass failed");
+ graph->swap(output);
+ return Status::OK();
+}
+
+/*static*/ Status EncapsulateXlaComputationsPass::BuildXlaLaunchOps(
+ Graph* graph) {
+ // Finds all of the XlaLaunch function calls, to avoid mutating the graph
+ // while iterating.
+ std::vector<Node*> launch_nodes;
+ for (Node* n : graph->nodes()) {
+ string name;
+ if (GetNodeAttr(n->attrs(), kXlaClusterAttr, &name).ok()) {
+ launch_nodes.push_back(n);
+ }
+ }
+
+ // Replaces each launch function call together with its neighboring
+ // XlaClusterOutput nodes with an XlaLaunch node.
+ for (Node* launch : launch_nodes) {
+ int variable_start_index;
+ TF_RETURN_IF_ERROR(GetNodeAttr(launch->attrs(), "_variable_start_index",
+ &variable_start_index));
+
+ std::vector<const Edge*> in_edges;
+ TF_RETURN_IF_ERROR(launch->input_edges(&in_edges));
+
+ const int num_inputs = in_edges.size();
+ const int num_variables = num_inputs - variable_start_index;
+ const int num_args = variable_start_index;
+
+ VLOG(4) << "Launch node '" << launch->name() << "'"
+ << " input edges: " << in_edges.size() << " num_args: " << num_args
+ << " num_variables: " << num_variables;
+
+ std::vector<Node*> nodes_to_remove = {launch};
+
+ // Data and control inputs to the new XlaLaunch node.
+ std::vector<std::pair<Node*, int>> data_inputs(num_inputs);
+ gtl::FlatSet<Node*> control_inputs;
+ DataTypeVector arg_types(num_args);
+
+ AddControlInputs(*launch, &control_inputs);
+
+ for (int i = 0; i < num_args; ++i) {
+ const Edge* edge = in_edges[i];
+ data_inputs[i] = {edge->src(), edge->src_output()};
+ arg_types[i] = EdgeType(edge);
+ }
+
+ // Appends the variable inputs.
+ for (int i = 0; i < num_variables; ++i) {
+ int pos = variable_start_index + i;
+ const Edge* edge = in_edges[pos];
+ data_inputs[pos] = {edge->src(), edge->src_output()};
+ }
+
+ // Outputs.
+ const int num_outputs = launch->output_types().size();
+ gtl::FlatSet<Node*> control_outputs;
+ std::vector<std::vector<std::pair<Node*, int>>> data_outputs(num_outputs);
+ DataTypeVector output_types(num_outputs);
+
+ for (const Edge* le : launch->out_edges()) {
+ if (le->IsControlEdge()) {
+ control_outputs.insert(le->dst());
+ } else {
+ TF_RET_CHECK(le->src_output() < num_outputs);
+ Node* output_node = le->dst();
+
+ TF_RET_CHECK(output_node->type_string() == kXlaClusterOutput)
+ << le->DebugString();
+ nodes_to_remove.push_back(output_node);
+
+ for (const Edge* oe : output_node->out_edges()) {
+ TF_RET_CHECK(!oe->IsControlEdge());
+ data_outputs[le->src_output()].push_back(
+ {oe->dst(), oe->dst_input()});
+ }
+ output_types[le->src_output()] = output_node->input_type(0);
+
+ AddControlOutputs(*output_node, &control_outputs);
+ }
+ }
+
+ NodeDef def;
+ def.set_name(launch->name());
+
+ // Target the XLA CPU/GPU backends.
+ VLOG(2) << "Replacing with XlaLaunch";
+ def.set_op("XlaLaunch");
+ AddNodeAttr("Tconstants", DataTypeVector{}, &def);
+ AddNodeAttr("Targs", arg_types, &def);
+ AddNodeAttr("Nresources", num_variables, &def);
+ AddNodeAttr("Tresults", output_types, &def);
+ NameAttrList function;
+ function.set_name(launch->type_string());
+ AddNodeAttr("function", function, &def);
+
+ for (Node* node : nodes_to_remove) {
+ VLOG(2) << "Deleting node " << node->DebugString();
+ // Ensure that we do not attempt to add control edges to nodes that are
+ // deleted.
+ control_inputs.erase(node);
+ control_outputs.erase(node);
+ graph->RemoveNode(node);
+ }
+
+ Status status;
+ Node* xla_launch = graph->AddNode(def, &status);
+ if (!status.ok()) {
+ return status;
+ }
+ for (int i = 0; i < data_inputs.size(); ++i) {
+ graph->AddEdge(data_inputs[i].first, data_inputs[i].second, xla_launch,
+ i);
+ }
+ for (Node* n : control_inputs) {
+ graph->AddControlEdge(n, xla_launch);
+ }
+ for (int i = 0; i < data_outputs.size(); ++i) {
+ for (const auto& successor : data_outputs[i]) {
+ graph->AddEdge(xla_launch, i, successor.first, successor.second);
+ }
+ }
+ for (Node* n : control_outputs) {
+ graph->AddControlEdge(xla_launch, n);
+ }
+ }
+ return Status::OK();
+}
+
+Status EncapsulateXlaComputationsPass::Run(
+ const GraphOptimizationPassOptions& options) {
+ VLOG(1) << "EncapsulateXlaComputations(): "
+ << dump_graph::DumpGraphToFile("encapsulate_xla_computations_before",
+ **options.graph, options.flib_def);
+
+ TF_RETURN_IF_ERROR(Encapsulate(options.graph, options.flib_def));
+ VLOG(1) << "EncapsulateXlaComputations() half-way: "
+ << dump_graph::DumpGraphToFile("encapsulate_xla_computations_halfway",
+ **options.graph, options.flib_def);
+
+ TF_RETURN_IF_ERROR(BuildXlaLaunchOps(options.graph->get()));
+ VLOG(1) << "EncapsulateXlaComputations() finished: "
+ << dump_graph::DumpGraphToFile("encapsulate_xla_computations_after",
+ **options.graph, options.flib_def);
+ return Status::OK();
+}
+
+} // namespace tensorflow
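RewriteSubgraph in the new pass sorts _Arg nodes so that non-resource arguments come before resource arguments, with the node name as a deterministic tiebreaker. A minimal sketch of that predicate, using a hypothetical Arg struct in place of Graph nodes:

#include <algorithm>
#include <iostream>
#include <string>
#include <tuple>
#include <vector>

struct Arg {
  std::string name;
  bool is_resource;
};

int main() {
  // Hypothetical arguments in arbitrary input order.
  std::vector<Arg> args = {{"v", true}, {"b", false}, {"u", true}, {"a", false}};
  std::sort(args.begin(), args.end(), [](const Arg& x, const Arg& y) {
    // Non-resources sort before resources; names break ties deterministically.
    return std::tie(x.is_resource, x.name) < std::tie(y.is_resource, y.name);
  });
  for (const auto& arg : args)
    std::cout << arg.name << (arg.is_resource ? " (resource)" : "") << "\n";
  // Prints: a, b, u (resource), v (resource)
  return 0;
}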
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h
new file mode 100644
index 0000000000..99e9dfd598
--- /dev/null
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h
@@ -0,0 +1,60 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+// Rewrites computations generated by the xla.compile() Python code into
+// XlaLaunch nodes.
+//
+// xla.compile() does two main things:
+// a) marks operators that make up an XLA computation with the attribute
+// _xla_compile_id=XYZ, where XYZ is a unique key.
+// b) adds XlaClusterOutput nodes to represent outputs of the computation.
+// These nodes are not marked with the _xla_compile_id attribute.
+
+#ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_
+#define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_
+
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/platform/env.h"
+
+ namespace tensorflow {
+
+// Encapsulates nodes marked with the _xla_compile_id attribute into
+// XlaLaunch operators.
+class EncapsulateXlaComputationsPass : public GraphOptimizationPass {
+ public:
+ static const char* const kXlaClusterAttr; // _xla_compile_id
+
+ Status Run(const GraphOptimizationPassOptions& options) override;
+
+ // The following methods are public only for unit tests.
+
+ // This pass has two stages:
+ // a) first, we call EncapsulateSubgraphsPass to encapsulate all nodes
+ // marked with the same _xla_compile_id attribute into functions. These
+ // functions contain the computations to be passed to XlaLaunch. During
+ // encapsulation, we sort the arguments into the order expected by
+ // XlaLaunch.
+ static Status Encapsulate(std::unique_ptr<Graph>* graph,
+ FunctionLibraryDefinition* flib_def);
+
+ // b) we rewrite the function calls generated in phase (a) into XlaLaunch
+ // operators. We also convert the XlaClusterOutput output nodes of the
+ // function call into the outputs of the XlaLaunch operator.
+ static Status BuildXlaLaunchOps(Graph* graph);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_
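The header describes two stages: encapsulating nodes that share an _xla_compile_id attribute into functions, then replacing each function call and its XlaClusterOutput consumers with one XlaLaunch op. The toy sketch below only illustrates the grouping idea behind the first stage, bucketing hypothetical node names by cluster attribute; it is not the EncapsulateSubgraphsInFunctions machinery itself:

#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Hypothetical nodes and their _xla_compile_id attribute ("" = unmarked).
  std::vector<std::pair<std::string, std::string>> nodes = {
      {"MatMul", "cluster_0"}, {"Relu", "cluster_0"},
      {"Placeholder", ""},     {"Add", "cluster_1"}};
  // Stage (a) in spirit: bucket marked nodes by cluster id; each bucket would
  // become one function, later replaced by a single XlaLaunch op in stage (b).
  std::map<std::string, std::vector<std::string>> clusters;
  for (const auto& [name, cluster] : nodes)
    if (!cluster.empty()) clusters[cluster].push_back(name);
  for (const auto& [cluster, members] : clusters) {
    std::cout << cluster << ":";
    for (const auto& m : members) std::cout << " " << m;
    std::cout << "\n";
  }
  return 0;
}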
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
new file mode 100644
index 0000000000..f643fb0cfe
--- /dev/null
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
@@ -0,0 +1,346 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
+
+#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/resource_variable_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_op.h"
+#include "tensorflow/compiler/tf2xla/test_util.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/lib/strings/proto_serialization.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/util/equal_graph_def.h"
+#include "tensorflow/core/util/ptr_util.h"
+
+namespace tensorflow {
+
+static std::unique_ptr<Graph> MakeOuterGraph(
+ const FunctionLibraryDefinition& flib_def, const string& function) {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib_def.ToProto()));
+
+ auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32);
+ auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT);
+ auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32);
+ auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT);
+ auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE);
+ auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE);
+ auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
+
+ NodeDef def;
+ TF_CHECK_OK(
+ NodeDefBuilder("launch0", function, &flib_def)
+ .Input(a.node()->name(), 0, DT_INT32)
+ .Input(b.node()->name(), 0, DT_FLOAT)
+ .Input(c.node()->name(), 0, DT_INT32)
+ .Input(d.node()->name(), 0, DT_FLOAT)
+ .Input(u.node()->name(), 0, DT_RESOURCE)
+ .Input(v.node()->name(), 0, DT_RESOURCE)
+ .Input(w.node()->name(), 0, DT_RESOURCE)
+ .Attr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0")
+ .Attr("_variable_start_index", 4)
+ .Finalize(&def));
+
+ Status status;
+ Node* launch = scope.graph()->AddNode(def, &status);
+ TF_CHECK_OK(status);
+ TF_CHECK_OK(scope.DoShapeInference(launch));
+ scope.graph()->AddEdge(a.node(), 0, launch, 0);
+ scope.graph()->AddEdge(b.node(), 0, launch, 1);
+ scope.graph()->AddEdge(c.node(), 0, launch, 2);
+ scope.graph()->AddEdge(d.node(), 0, launch, 3);
+ scope.graph()->AddEdge(u.node(), 0, launch, 4);
+ scope.graph()->AddEdge(v.node(), 0, launch, 5);
+ scope.graph()->AddEdge(w.node(), 0, launch, 6);
+
+ auto out0 =
+ ops::XlaClusterOutput(scope.WithOpName("Out0"), Output(launch, 0));
+ auto out1 =
+ ops::XlaClusterOutput(scope.WithOpName("Out1"), Output(launch, 1));
+ auto out2 =
+ ops::XlaClusterOutput(scope.WithOpName("Out2"), Output(launch, 2));
+ auto out3 =
+ ops::XlaClusterOutput(scope.WithOpName("Out3"), Output(launch, 3));
+
+ auto consumer0_a = ops::Identity(scope.WithOpName("consumer0_a"), out0);
+ auto consumer0_b = ops::Identity(scope.WithOpName("consumer0_b"), out0);
+ auto consumer0_c = ops::Identity(scope.WithOpName("consumer0_c"), out0);
+ auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1);
+ auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2);
+ auto consumer3 = ops::Identity(scope.WithOpName("consumer3"), out3);
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_CHECK_OK(scope.ToGraph(graph.get()));
+ return graph;
+}
+
+// Makes the body graph of an encapsulated computation for use in tests.
+static std::unique_ptr<Graph> MakeBodyGraph() {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+
+ auto arg0 = ops::_Arg(scope.WithOpName("a_0_arg"), DT_INT32, 0);
+ auto arg1 = ops::_Arg(scope.WithOpName("b_0_arg"), DT_FLOAT, 1);
+ auto arg2 = ops::_Arg(scope.WithOpName("c_0_arg"), DT_INT32, 2);
+ auto arg3 = ops::_Arg(scope.WithOpName("d_0_arg"), DT_FLOAT, 3);
+
+ auto arg4 = ops::_Arg(scope.WithOpName("u_0_arg"), DT_RESOURCE, 4);
+ auto arg5 = ops::_Arg(scope.WithOpName("v_0_arg"), DT_RESOURCE, 5);
+ auto arg6 = ops::_Arg(scope.WithOpName("w_0_arg"), DT_RESOURCE, 6);
+
+ auto add_attrs = [](Node* node) {
+ node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0");
+ };
+
+ auto b_identity = ops::Identity(scope.WithOpName("B_identity"), arg1);
+
+ auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), arg4, DT_FLOAT);
+ add_attrs(read_u.node());
+ auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), arg5, DT_FLOAT);
+ add_attrs(read_v.node());
+ auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), arg6, DT_FLOAT);
+ add_attrs(read_w.node());
+
+ auto e = ops::Add(scope.WithOpName("E"), arg0, arg2);
+ add_attrs(e.node());
+ auto f = ops::Add(scope.WithOpName("F"), read_v, read_w);
+ add_attrs(f.node());
+ auto g = ops::Add(scope.WithOpName("G"), f, arg3);
+ add_attrs(g.node());
+
+ auto out0 = ops::_Retval(scope.WithOpName("b_identity_0_retval_RetVal"),
+ b_identity, 0);
+ auto out1 = ops::_Retval(scope.WithOpName("e_0_retval_RetVal"), e, 1);
+ auto out2 = ops::_Retval(scope.WithOpName("g_0_retval_RetVal"), g, 2);
+ auto out3 =
+ ops::_Retval(scope.WithOpName("readu_0_retval_RetVal"), read_u, 3);
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_CHECK_OK(scope.ToGraph(graph.get()));
+ return graph;
+}
+
+TEST(EncapsulateXlaComputations, DeterministicEncapsulate) {
+ // Tests that the order of control edge insertion doesn't affect the cache
+ // key (cluster name) generated by the TPU encapsulate pass.
+ auto get_serialized_graph = [](bool control_input_reversed,
+ bool operand_reversed) -> string {
+ FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
+ std::unique_ptr<Graph> graph(new Graph(&flib_def));
+ {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto a0 = ops::Placeholder(scope.WithOpName("A0"), DT_INT32);
+ auto a1 = ops::Placeholder(scope.WithOpName("A1"), DT_INT32);
+
+ ops::Add e = operand_reversed ? ops::Add(scope.WithOpName("E"), a0, a1)
+ : ops::Add(scope.WithOpName("E"), a1, a0);
+
+ auto add_attrs = [](Node* node) {
+ node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr,
+ "launch0");
+ };
+ add_attrs(e.node());
+
+ TF_CHECK_OK(scope.ToGraph(graph.get()));
+ auto get_node_in_graph = [&graph](Node* node) {
+ return graph->FindNodeId(node->id());
+ };
+ // Insert the control edges in a different order. The order should not
+ // affect the encapsulated or serialized graph.
+ if (!control_input_reversed) {
+ graph->AddControlEdge(get_node_in_graph(a0.node()),
+ get_node_in_graph(e.node()), true);
+ graph->AddControlEdge(get_node_in_graph(a1.node()),
+ get_node_in_graph(e.node()), true);
+ } else {
+ graph->AddControlEdge(get_node_in_graph(a1.node()),
+ get_node_in_graph(e.node()), true);
+ graph->AddControlEdge(get_node_in_graph(a0.node()),
+ get_node_in_graph(e.node()), true);
+ }
+ }
+ TF_CHECK_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def));
+ GraphDef gdef;
+ graph->ToGraphDef(&gdef);
+ // Sort the control inputs before serialization to remove
+ // nondeterminism.
+ SortControlInputs(&gdef);
+ string serialized;
+ SerializeToStringDeterministic(gdef, &serialized);
+ return serialized;
+ };
+
+ // Changing the order of control inputs shouldn't affect the generated graph.
+ EXPECT_EQ(get_serialized_graph(/*control_input_reversed=*/true,
+ /*operand_reversed=*/false),
+ get_serialized_graph(/*control_input_reversed=*/false,
+ /*operand_reversed=*/false));
+
+ // Changing the order of data inputs should affect the generated graph.
+ EXPECT_NE(get_serialized_graph(/*control_input_reversed=*/false,
+ /*operand_reversed=*/true),
+ get_serialized_graph(/*control_input_reversed=*/false,
+ /*operand_reversed=*/false));
+}
+
+TEST(EncapsulateXlaComputations, Encapsulate) {
+ FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
+ std::unique_ptr<Graph> graph(new Graph(&flib_def));
+ {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32);
+ auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT);
+ auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32);
+ auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT);
+ auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE);
+ auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE);
+ auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
+
+ auto add_attrs = [](Node* node) {
+ node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0");
+ };
+
+ auto b_identity = ops::Identity(scope.WithOpName("B_identity"), b);
+ add_attrs(b_identity.node());
+
+ auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), u, DT_FLOAT);
+ add_attrs(read_u.node());
+ auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), v, DT_FLOAT);
+ add_attrs(read_v.node());
+ auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), w, DT_FLOAT);
+ add_attrs(read_w.node());
+
+ auto e = ops::Add(scope.WithOpName("E"), a, c);
+ add_attrs(e.node());
+ auto f = ops::Add(scope.WithOpName("F"), read_v, read_w);
+ add_attrs(f.node());
+ auto g = ops::Add(scope.WithOpName("G"), f, d);
+ add_attrs(g.node());
+
+ auto out0 = ops::XlaClusterOutput(scope.WithOpName("Out0"), b_identity);
+ auto out1 = ops::XlaClusterOutput(scope.WithOpName("Out1"), e);
+ auto out2 = ops::XlaClusterOutput(scope.WithOpName("Out2"), g);
+ auto out3 = ops::XlaClusterOutput(scope.WithOpName("Out3"), read_u);
+
+ auto consumer0_a = ops::Identity(scope.WithOpName("consumer0_a"), out0);
+ auto consumer0_b = ops::Identity(scope.WithOpName("consumer0_b"), out0);
+ auto consumer0_c = ops::Identity(scope.WithOpName("consumer0_c"), out0);
+ auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1);
+ auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2);
+ auto consumer3 = ops::Identity(scope.WithOpName("consumer3"), out3);
+ TF_ASSERT_OK(scope.ToGraph(graph.get()));
+ }
+
+ std::unique_ptr<Graph> graph_copy(new Graph(&flib_def));
+ CopyGraph(*graph, graph_copy.get());
+
+ TF_ASSERT_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def));
+
+ std::unordered_map<string, Node*> index = BuildNodeIndex(*graph);
+ string function = index.at("launch0")->type_string();
+
+ // Tests the outer graph is as expected.
+ {
+ std::unique_ptr<Graph> outer = MakeOuterGraph(flib_def, function);
+ GraphDef expected_def;
+ outer->ToGraphDef(&expected_def);
+
+ GraphDef actual_def;
+ graph->ToGraphDef(&actual_def);
+ TF_EXPECT_GRAPH_EQ_INTERNAL(expected_def, actual_def);
+ }
+
+ // Tests the encapsulated body graph is as expected.
+ {
+ std::unique_ptr<Graph> body = MakeBodyGraph();
+ GraphDef expected_body_def;
+ body->ToGraphDef(&expected_body_def);
+
+ InstantiationResultForTest result;
+ TF_EXPECT_OK(InstantiateFunctionForTest(function, flib_def, &result));
+
+ EXPECT_EQ((DataTypeVector{DT_INT32, DT_FLOAT, DT_INT32, DT_FLOAT,
+ DT_RESOURCE, DT_RESOURCE, DT_RESOURCE}),
+ result.arg_types);
+ EXPECT_EQ((DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}),
+ result.ret_types);
+ TF_EXPECT_GRAPH_EQ(expected_body_def, result.gdef);
+ }
+
+ // Encapsulates the same computation again and verifies that we reuse the
+ // same function. Encapsulation should be deterministic to avoid recompilation.
+ TF_ASSERT_OK(
+ EncapsulateXlaComputationsPass::Encapsulate(&graph_copy, &flib_def));
+ std::unordered_map<string, Node*> index_copy = BuildNodeIndex(*graph_copy);
+ string function_copy = index_copy.at("launch0")->type_string();
+ EXPECT_EQ(function, function_copy);
+}
+
+TEST(EncapsulateXlaComputations, BuildXlaLaunchOp) {
+ std::unique_ptr<Graph> body_graph = MakeBodyGraph();
+ FunctionDefLibrary flib;
+ TF_ASSERT_OK(GraphToFunctionDef(*body_graph, "launch0", flib.add_function()));
+
+ FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);
+
+ std::unique_ptr<Graph> graph = MakeOuterGraph(flib_def, "launch0");
+ TF_ASSERT_OK(EncapsulateXlaComputationsPass::BuildXlaLaunchOps(graph.get()));
+
+ Scope scope = Scope::DisabledShapeInferenceScope().ExitOnError();
+ TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib));
+
+ auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32);
+ auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT);
+ auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32);
+ auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT);
+ auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE);
+ auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE);
+ auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
+
+ NameAttrList function;
+ function.set_name("launch0");
+ auto launch = ops::XlaLaunch(
+ scope.WithOpName("launch0"), std::initializer_list<Input>{},
+ std::initializer_list<Input>{a, b, c, d},
+ std::initializer_list<Input>{u, v, w},
+ DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}, function);
+
+ auto consumer0_a =
+ ops::Identity(scope.WithOpName("consumer0_a"), launch.results[0]);
+ auto consumer0_b =
+ ops::Identity(scope.WithOpName("consumer0_b"), launch.results[0]);
+ auto consumer0_c =
+ ops::Identity(scope.WithOpName("consumer0_c"), launch.results[0]);
+ auto consumer1 =
+ ops::Identity(scope.WithOpName("consumer1"), launch.results[1]);
+ auto consumer2 =
+ ops::Identity(scope.WithOpName("consumer2"), launch.results[2]);
+ auto consumer3 =
+ ops::Identity(scope.WithOpName("consumer3"), launch.results[3]);
+
+ GraphDef expected_def;
+ TF_ASSERT_OK(scope.ToGraphDef(&expected_def));
+
+ GraphDef actual_def;
+ graph->ToGraphDef(&actual_def);
+ TF_EXPECT_GRAPH_EQ(expected_def, actual_def);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
index c37b6112cc..3770eea6d0 100644
--- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
+++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
@@ -15,12 +15,31 @@ limitations under the License.
#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/jit/partially_decluster_pass.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
namespace tensorflow {
+// PRE_PLACEMENT passes:
+
+// EncapsulateXlaComputationsPass rewrites computations generated by the
+// xla.compile() Python code into XlaLaunch nodes.
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 26,
+ EncapsulateXlaComputationsPass);
+
+// FunctionalizeControlFlowPass is registered with priority 27 in
+// third_party/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc.
+//
+// That pass looks at the graph and all associated FunctionDefs, and turns
+// traditional control flow structure (Switch/Merge/etc.) into functional
+// control flow structure (If/While). Subsequent passes must handle those
+// FunctionDefs correctly.
+
+// POST_REWRITE_FOR_EXEC passes that support auto-clustering to enable XLA:
+
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10,
MarkForCompilationPass);
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 44caf0be52..e6cc6e52ae 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -443,7 +443,7 @@ Status FindCompilationCandidates(
!registration->requires_compilation) {
const OpDef* op_def;
TF_RETURN_IF_ERROR(
- OpRegistry::Global()->LookUpOpDef(node->type_string(), &op_def));
+ graph.op_registry()->LookUpOpDef(node->type_string(), &op_def));
if (op_def->is_stateful()) {
// We need to be able to constant fold the nodes in
// compile_time_const_nodes given constant inputs (required by XLA) and
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index 807ab51fd3..c59770a4c8 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
+#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
@@ -633,7 +634,7 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
Scope root = Scope::NewRootScope().ExitOnError();
{
- auto BuildNoopNode = [](StringPiece name, Graph* graph) {
+ auto BuildNoopNode = [](absl::string_view name, Graph* graph) {
NodeDefBuilder builder(name, "NoOp");
NodeDef def;
TF_CHECK_OK(builder.Finalize(&def));
@@ -847,5 +848,51 @@ TEST(XlaCompilationTest, RandomShape) {
EXPECT_EQ(clusters["shape"], "");
}
+TEST(XlaCompilationTest, RandomShapeWithFunc) {
+ Scope root = Scope::DisabledShapeInferenceScope().ExitOnError();
+
+ FunctionDefLibrary flib_def;
+ FunctionDef func = FunctionDefHelper::Create(
+ /*function_name=*/"Stateful_func", /*in_def=*/{},
+ /*out_def=*/{"out: int32"},
+ /*attr_def=*/
+ {}, /*node_def=*/
+ {FunctionDefHelper::Const("shape_shape", 2),
+ FunctionDefHelper::Const("minval", 1),
+ FunctionDefHelper::Const("maxval", 20),
+ {{"shape"},
+ "RandomUniformInt",
+ {"shape_shape:output:0", "minval:output:0", "maxval:output:0"},
+ {{"Tout", DataType::DT_INT32}, {"T", DataType::DT_INT32}}}},
+ /*ret_def=*/{{"out", "shape:output:0"}});
+
+ func.mutable_signature()->set_is_stateful(true);
+ *flib_def.add_function() = std::move(func);
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
+ NodeDef call_node;
+ call_node.set_name("fn_call");
+ call_node.set_op("Stateful_func");
+ Status status;
+ Node* call = root.graph()->AddNode(call_node, &status);
+ TF_ASSERT_OK(status);
+
+ Output shape = Output(call, 0);
+ Output reshape_input =
+ ops::Placeholder(root.WithOpName("reshape_input"), DT_FLOAT,
+ ops::Placeholder::Shape(TensorShape({500, 500})));
+ Output reshape =
+ ops::Reshape(root.WithOpName("reshape"), reshape_input, shape);
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+ auto fld = absl::make_unique<FunctionLibraryDefinition>(OpRegistry::Global(),
+ flib_def);
+ TF_ASSERT_OK(
+ MarkForCompilationPassTestHelper::MarkForCompilation(&graph, fld.get()));
+
+ std::unordered_map<string, string> clusters = GetClusters(*graph);
+ EXPECT_EQ(clusters["fn_call"], "");
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc
index f2473d98ff..1a29c3caab 100644
--- a/tensorflow/compiler/jit/ops/xla_ops.cc
+++ b/tensorflow/compiler/jit/ops/xla_ops.cc
@@ -13,10 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
+using shape_inference::InferenceContext;
+
REGISTER_OP("XlaLaunch")
.Input("constants: Tconstants")
.Attr("Tconstants: list(type) >= 0")
@@ -32,4 +36,19 @@ REGISTER_OP("XlaLaunch")
.SetIsStateful()
.Doc("XLA Launch Op. For use by the XLA JIT only.");
+REGISTER_OP("XlaClusterOutput")
+ .Input("input: T")
+ // Note: when replication is supported, this op will have N outputs.
+ .Output("outputs: T")
+ .Attr("T: type")
+ .SetShapeFn([](InferenceContext* c) {
+ for (int i = 0; i < c->num_outputs(); ++i) {
+ c->set_output(i, c->input(0));
+ }
+ return Status::OK();
+ })
+ .Doc(
+ "Operator that connects the output of an XLA computation to other "
+ "consumer graph nodes.");
+
} // namespace tensorflow
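A rough usage sketch of the new op via the C++ ops wrapper included by the encapsulation test earlier in this change (tensorflow/compiler/tf2xla/cc/ops/xla_jit_op.h). The helper name and node names are made up for illustration; `launch` is assumed to be an existing node that produces the cluster's results.

#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_op.h"

namespace tensorflow {

// Hypothetical helper: expose output 0 of `launch` through XlaClusterOutput
// and feed it to an ordinary consumer node.
void WireFirstClusterOutput(const Scope& scope, Node* launch) {
  auto out0 =
      ops::XlaClusterOutput(scope.WithOpName("Out0"), Output(launch, 0));
  // The shape function registered above copies the input shape, so the
  // consumer sees the same shape as the wrapped computation output.
  auto consumer = ops::Identity(scope.WithOpName("consumer0"), out0);
}

}  // namespace tensorflow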
diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h
index 94c96ac7c5..ba218f3315 100644
--- a/tensorflow/compiler/jit/xla_cluster_util.h
+++ b/tensorflow/compiler/jit/xla_cluster_util.h
@@ -18,7 +18,6 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
#define TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
-#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/core/graph/algorithm.h"
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index 6d4160a968..af83c792e5 100644
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -339,11 +339,11 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
}
void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
- StringPiece tensor_name,
+ absl::string_view tensor_name,
Device* device, Tensor* cpu_tensor,
StatusCallback done) {
- manager_.CopyDeviceTensorToCPU(device_tensor, absl::string_view(tensor_name),
- device, cpu_tensor, done);
+ manager_.CopyDeviceTensorToCPU(device_tensor, tensor_name, device, cpu_tensor,
+ done);
}
void XlaDeviceContext::CopyDeviceTensorToDevice(const Tensor& src_tensor,
diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h
index 1effd6628f..df82421294 100644
--- a/tensorflow/compiler/jit/xla_device_context.h
+++ b/tensorflow/compiler/jit/xla_device_context.h
@@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
namespace tensorflow {
@@ -111,12 +110,9 @@ class XlaDeviceContext : public DeviceContext {
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
Tensor* device_tensor,
StatusCallback done) const override;
- // TODO(rlahaye): Replace StringPiece with absl::string_view when the
- // StringPiece->absl::string_view change is rolled forward.
void CopyDeviceTensorToCPU(const Tensor* device_tensor,
- StringPiece tensor_name, // non-ABSL OK
- Device* device, Tensor* cpu_tensor,
- StatusCallback done) override;
+ absl::string_view tensor_name, Device* device,
+ Tensor* cpu_tensor, StatusCallback done) override;
void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor,
const StatusCallback& done);
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 050d827a09..97ed554171 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -277,9 +277,10 @@ tf_xla_py_test(
],
)
+# This test is marked "large" because testConcatLargeNumberOfTensors can occasionally take a long time on the CPU backend.
tf_xla_py_test(
name = "concat_ops_test",
- size = "medium",
+ size = "large",
srcs = ["concat_ops_test.py"],
deps = [
":xla_test",
@@ -581,6 +582,7 @@ tf_xla_py_test(
"//tensorflow/python:array_ops",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
+ "@absl_py//absl/testing:parameterized",
],
)
@@ -1197,7 +1199,7 @@ tf_xla_py_test(
tf_xla_py_test(
name = "xla_ops_test",
- size = "small",
+ size = "medium",
srcs = ["xla_ops_test.py"],
disabled_backends = ["cpu_ondemand"],
deps = [
diff --git a/tensorflow/compiler/tests/adam_test.py b/tensorflow/compiler/tests/adam_test.py
index df0f21471a..058576b3d4 100644
--- a/tensorflow/compiler/tests/adam_test.py
+++ b/tensorflow/compiler/tests/adam_test.py
@@ -56,7 +56,7 @@ class AdamOptimizerTest(xla_test.XLATestCase):
# TODO: test fails for float16 due to excessive precision requirements.
if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]:
continue
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
variable_scope.get_variable_scope().set_use_resource(True)
# Initialize variables for numpy implementation.
@@ -98,7 +98,7 @@ class AdamOptimizerTest(xla_test.XLATestCase):
# TODO: test fails for float16 due to excessive precision requirements.
if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]:
continue
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
variable_scope.get_variable_scope().set_use_resource(True)
# Initialize variables for numpy implementation.
@@ -140,7 +140,7 @@ class AdamOptimizerTest(xla_test.XLATestCase):
# TODO: test fails for float16 due to excessive precision requirements.
if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]:
continue
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
variable_scope.get_variable_scope().set_use_resource(True)
# Initialize variables for numpy implementation.
diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl
index 7b114d4f85..a76f136736 100644
--- a/tensorflow/compiler/tests/build_defs.bzl
+++ b/tensorflow/compiler/tests/build_defs.bzl
@@ -4,88 +4,97 @@ load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured")
load("//tensorflow/compiler/tests:plugin.bzl", "plugins")
def all_backends():
- b = ["cpu"] + plugins.keys()
- if cuda_is_configured():
- return b + ["gpu"]
- else:
- return b
+ b = ["cpu"] + plugins.keys()
+ if cuda_is_configured():
+ return b + ["gpu"]
+ else:
+ return b
-def tf_xla_py_test(name, srcs=[], deps=[], tags=[], data=[], main=None,
- disabled_backends=None, **kwargs):
- """Generates py_test targets, one per XLA backend.
+def tf_xla_py_test(
+ name,
+ srcs = [],
+ deps = [],
+ tags = [],
+ data = [],
+ main = None,
+ disabled_backends = None,
+ **kwargs):
+ """Generates py_test targets, one per XLA backend.
- This rule generates py_test() targets named name_backend, for each backend
- in all_backends(). The rule also generates a test suite with named `name` that
- tests all backends for the test.
+ This rule generates py_test() targets named name_backend, for each backend
+ in all_backends(). The rule also generates a test suite named `name` that
+ tests all backends for the test.
- For example, the following rule generates test cases foo_test_cpu,
- foo_test_gpu, and a test suite name foo_test that tests both.
- tf_xla_py_test(
- name="foo_test",
- srcs="foo_test.py",
- deps=[...],
- )
+ For example, the following rule generates test cases foo_test_cpu,
+ foo_test_gpu, and a test suite named foo_test that tests both.
+ tf_xla_py_test(
+ name="foo_test",
+ srcs="foo_test.py",
+ deps=[...],
+ )
- Args:
- name: Name of the target.
- srcs: Sources for the target.
- deps: Dependencies of the target.
- tags: Tags to apply to the generated targets.
- data: Data dependencies of the target.
- main: Same as py_test's main attribute.
- disabled_backends: A list of backends that should not be tested. Supported
- values include "cpu" and "gpu". If not specified, defaults to None.
- **kwargs: keyword arguments passed onto the generated py_test() rules.
- """
- if disabled_backends == None:
- disabled_backends = []
+ Args:
+ name: Name of the target.
+ srcs: Sources for the target.
+ deps: Dependencies of the target.
+ tags: Tags to apply to the generated targets.
+ data: Data dependencies of the target.
+ main: Same as py_test's main attribute.
+ disabled_backends: A list of backends that should not be tested. Supported
+ values include "cpu" and "gpu". If not specified, defaults to None.
+ **kwargs: keyword arguments passed onto the generated py_test() rules.
+ """
+ if disabled_backends == None:
+ disabled_backends = []
- enabled_backends = [b for b in all_backends() if b not in disabled_backends]
- test_names = []
- for backend in enabled_backends:
- test_name = "{}_{}".format(name, backend)
- backend_tags = ["tf_xla_{}".format(backend)]
- backend_args = []
- backend_deps = []
- backend_data = []
- if backend == "cpu":
- backend_args += [
- "--test_device=XLA_CPU",
- "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64"
- ]
- elif backend == "gpu":
- backend_args += [
- "--test_device=XLA_GPU",
- "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16"
- ]
- backend_tags += ["requires-gpu-sm35"]
- elif backend in plugins:
- backend_args += ["--test_device=" + plugins[backend]["device"],
- "--types=" + plugins[backend]["types"]]
- backend_tags += plugins[backend]["tags"]
- backend_args += plugins[backend]["args"]
- backend_deps += plugins[backend]["deps"]
- backend_data += plugins[backend]["data"]
- else:
- fail("Unknown backend {}".format(backend))
+ enabled_backends = [b for b in all_backends() if b not in disabled_backends]
+ test_names = []
+ for backend in enabled_backends:
+ test_name = "{}_{}".format(name, backend)
+ backend_tags = ["tf_xla_{}".format(backend)]
+ backend_args = []
+ backend_deps = []
+ backend_data = []
+ if backend == "cpu":
+ backend_args += [
+ "--test_device=XLA_CPU",
+ "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64",
+ ]
+ elif backend == "gpu":
+ backend_args += [
+ "--test_device=XLA_GPU",
+ "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16",
+ ]
+ backend_tags += ["requires-gpu-sm35"]
+ elif backend in plugins:
+ backend_args += [
+ "--test_device=" + plugins[backend]["device"],
+ "--types=" + plugins[backend]["types"],
+ ]
+ backend_tags += plugins[backend]["tags"]
+ backend_args += plugins[backend]["args"]
+ backend_deps += plugins[backend]["deps"]
+ backend_data += plugins[backend]["data"]
+ else:
+ fail("Unknown backend {}".format(backend))
- native.py_test(
- name=test_name,
- srcs=srcs,
- srcs_version="PY2AND3",
- args=backend_args,
- main="{}.py".format(name) if main == None else main,
- data=data + backend_data,
- deps=deps + backend_deps,
- tags=tags + backend_tags,
- **kwargs
- )
- test_names.append(test_name)
- native.test_suite(name=name, tests=test_names)
+ native.py_test(
+ name = test_name,
+ srcs = srcs,
+ srcs_version = "PY2AND3",
+ args = backend_args,
+ main = "{}.py".format(name) if main == None else main,
+ data = data + backend_data,
+ deps = deps + backend_deps,
+ tags = tags + backend_tags,
+ **kwargs
+ )
+ test_names.append(test_name)
+ native.test_suite(name = name, tests = test_names)
-def generate_backend_suites(backends=[]):
- """Generates per-backend test_suites that run all tests for a backend."""
- if not backends:
- backends = all_backends()
- for backend in backends:
- native.test_suite(name="%s_tests" % backend, tags=["tf_xla_%s" % backend])
+def generate_backend_suites(backends = []):
+ """Generates per-backend test_suites that run all tests for a backend."""
+ if not backends:
+ backends = all_backends()
+ for backend in backends:
+ native.test_suite(name = "%s_tests" % backend, tags = ["tf_xla_%s" % backend])
diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py
index 37e5318bb5..2d225ad226 100644
--- a/tensorflow/compiler/tests/concat_ops_test.py
+++ b/tensorflow/compiler/tests/concat_ops_test.py
@@ -291,6 +291,41 @@ class ConcatTest(xla_test.XLATestCase):
ValueError, r"Can't concatenate scalars \(use tf\.stack instead\)"):
array_ops.concat([scalar, scalar, scalar], dim)
+ # The purpose of this test is to ensure that XLA on GPU does not run out of
+ # memory when concatenating a very large number of tensors.
+ def testConcatLargeNumberOfTensors(self):
+ with self.cached_session():
+ with self.test_scope():
+ for concat_dim in range(2):
+ params = {}
+ p = []
+ shape = np.array([7, 13])
+ num_tensors = 1001
+ for i in np.arange(num_tensors):
+ input_shape = shape
+ placeholder = array_ops.placeholder(
+ dtypes.float32, shape=input_shape)
+ p.append(placeholder)
+ params[placeholder] = np.random.rand(*input_shape).astype(
+ np.float32)
+
+ concat_inputs = p
+ c = array_ops.concat(concat_inputs, concat_dim)
+ result = c.eval(feed_dict=params)
+
+ self.assertEqual(result.shape, c.get_shape())
+ cur_offset = 0
+
+ for i in np.arange(num_tensors):
+ # The index into the result is the ':' along all dimensions
+ # except the concat_dim. slice(0, size) is used for ':', and
+ # a list of slices is used to index into result.
+ index = [slice(0, params[p[i]].shape[j]) for j in np.arange(2)]
+ index[concat_dim] = slice(
+ cur_offset, cur_offset + params[p[i]].shape[concat_dim])
+ cur_offset += params[p[i]].shape[concat_dim]
+ self.assertAllEqual(result[index], params[p[i]])
+
class ConcatOffsetTest(xla_test.XLATestCase):
diff --git a/tensorflow/compiler/tests/matrix_band_part_test.py b/tensorflow/compiler/tests/matrix_band_part_test.py
index 9222db4b7e..c61965b97f 100644
--- a/tensorflow/compiler/tests/matrix_band_part_test.py
+++ b/tensorflow/compiler/tests/matrix_band_part_test.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
from tensorflow.compiler.tests import xla_test
@@ -26,38 +27,167 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class MatrixBandPartTest(xla_test.XLATestCase):
+class MatrixBandPartTest(xla_test.XLATestCase, parameterized.TestCase):
- def _testMatrixBandPart(self, dtype, shape):
- with self.cached_session():
- batch_shape = shape[:-2]
- mat = np.ones(shape).astype(dtype)
- batch_mat = np.tile(mat, batch_shape + [1, 1])
- for lower in -1, 0, 1, shape[-2] - 1:
- for upper in -1, 0, 1, shape[-1] - 1:
- band_np = mat
- if lower >= 0:
- band_np = np.triu(band_np, -lower)
- if upper >= 0:
- band_np = np.tril(band_np, upper)
- if batch_shape:
- band_np = np.tile(band_np, batch_shape + [1, 1])
-
- placeholder = array_ops.placeholder(dtype)
- with self.test_scope():
- band = array_ops.matrix_band_part(
- placeholder,
- constant_op.constant(lower, dtype=dtypes.int32),
- constant_op.constant(upper, dtype=dtypes.int32))
- feed_dict = {placeholder: batch_mat}
- self.assertAllEqual(band_np, band.eval(feed_dict=feed_dict))
-
- def testMatrixBandPart(self):
+ @parameterized.parameters(
+ {
+ 'batch_shape': [],
+ 'rows': 1,
+ 'cols': 1
+ },
+ {
+ 'batch_shape': [],
+ 'rows': 1,
+ 'cols': 2
+ },
+ {
+ 'batch_shape': [],
+ 'rows': 1,
+ 'cols': 7
+ },
+ {
+ 'batch_shape': [],
+ 'rows': 2,
+ 'cols': 1
+ },
+ {
+ 'batch_shape': [],
+ 'rows': 2,
+ 'cols': 2
+ },
+ {
+ 'batch_shape': [],
+ 'rows': 2,
+ 'cols': 7
+ },
+ {
+ 'batch_shape': [],
+ 'rows': 7,
+ 'cols': 1
+ },
+ {
+ 'batch_shape': [],
+ 'rows': 7,
+ 'cols': 2
+ },
+ {
+ 'batch_shape': [],
+ 'rows': 7,
+ 'cols': 7
+ },
+ {
+ 'batch_shape': [2,],
+ 'rows': 1,
+ 'cols': 1
+ },
+ {
+ 'batch_shape': [2,],
+ 'rows': 1,
+ 'cols': 2
+ },
+ {
+ 'batch_shape': [2,],
+ 'rows': 1,
+ 'cols': 7
+ },
+ {
+ 'batch_shape': [2,],
+ 'rows': 2,
+ 'cols': 1
+ },
+ {
+ 'batch_shape': [2,],
+ 'rows': 2,
+ 'cols': 2
+ },
+ {
+ 'batch_shape': [2,],
+ 'rows': 2,
+ 'cols': 7
+ },
+ {
+ 'batch_shape': [2,],
+ 'rows': 7,
+ 'cols': 1
+ },
+ {
+ 'batch_shape': [2,],
+ 'rows': 7,
+ 'cols': 2
+ },
+ {
+ 'batch_shape': [2,],
+ 'rows': 7,
+ 'cols': 7
+ },
+ {
+ 'batch_shape': [1, 3, 2],
+ 'rows': 1,
+ 'cols': 1
+ },
+ {
+ 'batch_shape': [1, 3, 2],
+ 'rows': 1,
+ 'cols': 2
+ },
+ {
+ 'batch_shape': [1, 3, 2],
+ 'rows': 1,
+ 'cols': 7
+ },
+ {
+ 'batch_shape': [1, 3, 2],
+ 'rows': 2,
+ 'cols': 1
+ },
+ {
+ 'batch_shape': [1, 3, 2],
+ 'rows': 2,
+ 'cols': 2
+ },
+ {
+ 'batch_shape': [1, 3, 2],
+ 'rows': 2,
+ 'cols': 7
+ },
+ {
+ 'batch_shape': [1, 3, 2],
+ 'rows': 7,
+ 'cols': 1
+ },
+ {
+ 'batch_shape': [1, 3, 2],
+ 'rows': 7,
+ 'cols': 2
+ },
+ {
+ 'batch_shape': [1, 3, 2],
+ 'rows': 7,
+ 'cols': 7
+ },
+ )
+ def testMatrixBandPart(self, batch_shape, rows, cols):
for dtype in self.float_types:
- for batch_shape in [[], [2,], [1, 3, 2]]:
- for rows in 1, 2, 7:
- for cols in 1, 2, 7:
- self._testMatrixBandPart(dtype, batch_shape + [rows, cols])
+ with self.cached_session():
+ mat = np.ones(batch_shape + [rows, cols]).astype(dtype)
+ batch_mat = np.tile(mat, batch_shape + [1, 1])
+ for lower in -1, 0, 1, rows - 1:
+ for upper in -1, 0, 1, cols - 1:
+ band_np = mat
+ if lower >= 0:
+ band_np = np.triu(band_np, -lower)
+ if upper >= 0:
+ band_np = np.tril(band_np, upper)
+ if batch_shape:
+ band_np = np.tile(band_np, batch_shape + [1, 1])
+
+ placeholder = array_ops.placeholder(dtype)
+ with self.test_scope():
+ band = array_ops.matrix_band_part(
+ placeholder, constant_op.constant(lower, dtype=dtypes.int32),
+ constant_op.constant(upper, dtype=dtypes.int32))
+ feed_dict = {placeholder: batch_mat}
+ self.assertAllEqual(band_np, band.eval(feed_dict=feed_dict))
if __name__ == "__main__":
diff --git a/tensorflow/compiler/tests/reshape_op_test.py b/tensorflow/compiler/tests/reshape_op_test.py
index 84c6777940..96e0b07475 100644
--- a/tensorflow/compiler/tests/reshape_op_test.py
+++ b/tensorflow/compiler/tests/reshape_op_test.py
@@ -33,7 +33,7 @@ class ReshapeTest(xla_test.XLATestCase, parameterized.TestCase):
('64_bit_index', dtypes.int64))
def testBasic(self, index_dtype):
for dtype in self.numeric_types:
- with self.test_session():
+ with self.cached_session():
i = array_ops.placeholder(dtype, shape=[2, 3])
with self.test_scope():
shape = constant_op.constant([3, 2], dtype=index_dtype)
diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py
index 3f928a1bea..1e600c44e9 100644
--- a/tensorflow/compiler/tests/xla_ops_test.py
+++ b/tensorflow/compiler/tests/xla_ops_test.py
@@ -25,6 +25,7 @@ from tensorflow.compiler.tests import xla_test
from tensorflow.compiler.tf2xla.python import xla
from tensorflow.compiler.xla import xla_data_pb2
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import googletest
@@ -296,6 +297,44 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase):
self._assertOpOutputMatchesExpected(
lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T)
+ def testDynamicSlice(self):
+ for dtype in self.numeric_types:
+ self._assertOpOutputMatchesExpected(
+ xla.dynamic_slice,
+ args=(np.arange(1000,
+ dtype=np.int32).astype(dtype).reshape([10, 10, 10]),
+ np.array([5, 7, 3]), np.array([2, 3, 2])),
+ expected=np.array(
+ np.array([[[573, 574], [583, 584], [593, 594]],
+ [[673, 674], [683, 684], [693, 694]]]),
+ dtype=dtype))
+
+ def testDynamicSliceWithIncorrectStartIndicesShape(self):
+ with self.test_session() as session:
+ with self.test_scope():
+ output = xla.dynamic_slice(
+ np.arange(1000, dtype=np.int32).reshape([10, 10, 10]),
+ np.array([5, 7]), np.array([2, 3, 4]))
+ with self.assertRaises(errors.InvalidArgumentError) as invalid_arg_error:
+ session.run(output)
+ self.assertRegexpMatches(
+ invalid_arg_error.exception.message,
+ (r'^start_indices must be a vector with length equal to input rank, '
+ r'but input rank is 3 and start_indices has shape \[2\].*'))
+
+ def testDynamicSliceWithIncorrectSizeIndicesShape(self):
+ with self.test_session() as session:
+ with self.test_scope():
+ output = xla.dynamic_slice(
+ np.arange(1000, dtype=np.int32).reshape([10, 10, 10]),
+ np.array([5, 7, 3]), np.array([2, 3]))
+ with self.assertRaises(errors.InvalidArgumentError) as invalid_arg_error:
+ session.run(output)
+ self.assertRegexpMatches(
+ invalid_arg_error.exception.message,
+ (r'^size_indices must be a vector with length equal to input rank, '
+ r'but input rank is 3 and size_indices has shape \[2\].*'))
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 3821dced63..ba1e3b2b4f 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -76,6 +76,7 @@ cc_library(
deps = [
":common",
":dump_graph",
+ ":functionalize_control_flow",
":tf2xla_proto",
":tf2xla_util",
":xla_compiler",
@@ -188,7 +189,6 @@ cc_library(
deps = [
":common",
":dump_graph",
- ":functionalize_control_flow",
":host_compute_metadata_proto",
":sharding_util",
":side_effect_util",
@@ -215,7 +215,6 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/memory",
- "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
alwayslink = 1,
@@ -285,6 +284,7 @@ cc_library(
deps = [
":sharding_util",
":tf2xla_proto",
+ "//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
@@ -480,6 +480,7 @@ cc_library(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
+ "//tensorflow/core:lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
@@ -507,12 +508,24 @@ cc_library(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
+ "//tensorflow/core:lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:optional",
],
)
cc_library(
+ name = "functionalize_control_flow_pass_registration",
+ srcs = [
+ "functionalize_control_flow_pass_registration.cc",
+ ],
+ deps = [
+ ":functionalize_control_flow",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
name = "functionalize_while",
srcs = [
"functionalize_while.cc",
@@ -521,6 +534,7 @@ cc_library(
"functionalize_while.h",
],
deps = [
+ ":functionalize_cond",
":functionalize_control_flow_util",
":tf2xla_util",
"//tensorflow/compiler/jit:union_find",
@@ -531,6 +545,7 @@ cc_library(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
+ "//tensorflow/core:lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:optional",
],
@@ -545,6 +560,7 @@ tf_cc_test(
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops",
+ "//tensorflow/cc:functional_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:resource_variable_ops",
"//tensorflow/compiler/tf2xla/cc:xla_ops",
@@ -595,6 +611,7 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc
index 0911550f1f..db256e577a 100644
--- a/tensorflow/compiler/tf2xla/functionalize_cond.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc
@@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/strings/strcat.h"
using xla::StatusOr;
@@ -217,10 +218,6 @@ void StateMap::ResetAncestorId(const Node* node, StateMap::AncestorId id) {
added_node_ancestorid_mapping_[node->id()] = id;
}
-const StateMap::CondState& StateMap::LookupState(const Node* node) const {
- return *LookupCondId(node);
-}
-
void StateMap::MarkDead(const Node* node) { ResetCondId(node, dead_id_); }
string StateMap::CondStateToString(const Node* node) const {
@@ -642,7 +639,7 @@ Status Conditional::ExtractBodies(Graph* graph) {
Status Conditional::BuildIfNode(Graph* graph,
FunctionLibraryDefinition* library) {
VLOG(2) << "Build cond function for " << name();
- NodeDefBuilder builder(name(), "If");
+ NodeDefBuilder builder(name(), "If", library);
const string branch_name[] = {"else_branch", "then_branch"};
for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
int branch_index = static_cast<int>(branch);
@@ -791,7 +788,6 @@ Status Conditional::BuildAndReplace(Graph* graph,
TF_RETURN_IF_ERROR(AddInputEdges(graph));
TF_RETURN_IF_ERROR(AddOutputEdges(graph));
TF_RETURN_IF_ERROR(parent_->PropagateUpdatedState(if_node_));
- for (Node* m : merges_) state_map_->MarkDead(m);
// Check that the if_node doesn't feed into itself.
TF_RETURN_WITH_CONTEXT_IF_ERROR(
@@ -1056,7 +1052,6 @@ Status FunctionalizeCond::RemoveRedundantMerge(Node* node) {
" has no non-dead inputs.");
}
state_map_.MarkDead(node);
- delete_nodes_.push_back(node->id());
VLOG(5) << "removing redundant merge: " << node->name();
while (!node->out_edges().empty()) {
const Edge* oe = *node->out_edges().begin();
@@ -1132,7 +1127,6 @@ Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) {
}
} else if (BranchType(switch_branch) != b) {
state_map_.MarkDead(dst_node);
- delete_nodes_.push_back(dst_node->id());
continue;
}
graph_->AddEdge(
@@ -1154,7 +1148,7 @@ Status FunctionalizeCond::DetermineStates(std::vector<Node*> rev_topo_order) {
VLOG(5) << dst->name() << " :: " << state_map_.CondStateToString(dst)
<< " @ " << state_map_.AncestorStateToString(dst);
- if (VLOG_IS_ON(10)) DumpGraphWithCondState("cond_it");
+ if (VLOG_IS_ON(10)) DumpGraphWithCondState("it");
}
return Status::OK();
}
@@ -1184,23 +1178,62 @@ Status FunctionalizeCond::DetermineAncestorState(Node* dst) {
return Status::OK();
}
-void FunctionalizeCond::DeleteReachableNodes() {
+void FunctionalizeCond::DeleteReachableAndDeadNodes(
+ const std::vector<int>& switch_ids, const std::vector<Node*>& merge_order) {
// Delete all nodes that have been extracted or are reachable from
// deleted/dead nodes. The input and outgoing edges should have already been
// removed.
+ std::deque<int> delete_nodes;
std::vector<bool> deleted(graph_->num_node_ids(), false);
// Don't try to delete source or sink nodes.
deleted[graph_->kSourceId] = true;
deleted[graph_->kSinkId] = true;
- while (!delete_nodes_.empty()) {
- int d_id = delete_nodes_.front();
- delete_nodes_.pop_front();
+
+ // All remaining Switch nodes are not reachable from a Merge node and are
+ // removed here. This accounts for dead Switch nodes.
+ for (int s_id : switch_ids) {
+ Node* s = graph_->FindNodeId(s_id);
+ if (s == nullptr) continue;
+ for (const Edge* e : s->out_edges()) {
+ // Control outputs of switch nodes (which are unconditionally executed if
+ // the switch is) are not removed as they need not be part of a
+ // conditional.
+ if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id());
+ }
+ deleted[s_id] = true;
+ graph_->RemoveNode(s);
+ }
+
+ // All merge nodes should have been transformed at this point and we remove
+ // them from the graph here.
+ for (Node* m : merge_order) {
+ for (const Edge* e : m->out_edges()) {
+ // As with the control outputs of switch nodes, don't remove the control
+ // outputs of merge nodes.
+ // TODO(jpienaar): Check cases where output edges still exist here vs
+ // being removed in AddOutputEdges.
+ if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id());
+ }
+ deleted[m->id()] = true;
+ graph_->RemoveNode(m);
+ }
+
+ // Enqueue all the dead nodes.
+ for (Node* n : graph_->nodes()) {
+ if (state_map_.IsDead(state_map_.LookupCondId(n))) {
+ delete_nodes.push_back(n->id());
+ }
+ }
+
+ while (!delete_nodes.empty()) {
+ int d_id = delete_nodes.front();
+ delete_nodes.pop_front();
if (deleted[d_id]) continue;
Node* d = graph_->FindNodeId(d_id);
// Switch and Merge nodes could have been deleted already.
if (d == nullptr) continue;
for (const Edge* e : d->out_edges()) {
- delete_nodes_.push_back(e->dst()->id());
+ delete_nodes.push_back(e->dst()->id());
}
deleted[d_id] = true;
graph_->RemoveNode(d);
@@ -1274,7 +1307,7 @@ Status FunctionalizeCond::FunctionalizeInternal() {
}
TF_RETURN_IF_ERROR(DetermineStates(std::move(rev_topo_order)));
- if (VLOG_IS_ON(4)) DumpGraphWithCondState("cond_id");
+ if (VLOG_IS_ON(4)) DumpGraphWithCondState("id");
// Sort the merge nodes from innermost outwards.
SortMergeNodes(&merge_order);
@@ -1312,11 +1345,7 @@ Status FunctionalizeCond::FunctionalizeInternal() {
if (VLOG_IS_ON(4)) DumpGraphWithCondState("after_extract");
}
- // All remaining Switch nodes are not reachable from a Merge node and
- // removed. This is to account for dead Switch nodes.
- for (int s_id : switch_ids) delete_nodes_.push_back(s_id);
- for (Node* m : merge_order) delete_nodes_.push_back(m->id());
- DeleteReachableNodes();
+ DeleteReachableAndDeadNodes(switch_ids, merge_order);
return Status::OK();
}
@@ -1331,8 +1360,9 @@ void FunctionalizeCond::DumpGraphWithCondState(const string& name) {
state_map_.AncestorStateToString(n)));
}
LOG(INFO) << "FunctionalizeControlFlow (" << name << "): "
- << dump_graph::DumpGraphToFile(absl::StrCat("functionalize_", name),
- *graph_, library_);
+ << dump_graph::DumpGraphToFile(
+ absl::StrCat("functionalize_cond_", name), *graph_,
+ library_);
}
Status FunctionalizeCond::Functionalize(Graph* graph,
diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.h b/tensorflow/compiler/tf2xla/functionalize_cond.h
index 28301150ea..1899808940 100644
--- a/tensorflow/compiler/tf2xla/functionalize_cond.h
+++ b/tensorflow/compiler/tf2xla/functionalize_cond.h
@@ -91,10 +91,6 @@ class StateMap {
// Resets the AncestorId for a given node.
void ResetAncestorId(const Node* node, AncestorId id);
- // Returns the CondState for a Node.
- // REQUIRES: node has a non-empty CondState.
- const CondState& LookupState(const Node* node) const;
-
// Marks `node` as dead.
void MarkDead(const Node* node);
@@ -221,8 +217,10 @@ class FunctionalizeCond {
// nesting depth.
void SortMergeNodes(std::vector<Node*>* merge_order);
- // Deletes all nodes in/consumers of `delete_nodes_`.
- void DeleteReachableNodes();
+ // Deletes the given switch/merge nodes, all nodes marked as dead, and
+ // everything reachable from them via data edges.
+ void DeleteReachableAndDeadNodes(const std::vector<int>& switch_ids,
+ const std::vector<Node*>& merge_order);
// Member used to unique the CondState to a unique CondId (AncestorState to a
// unique AncestorId) and keep track of CondState/CondId
@@ -232,9 +230,6 @@ class FunctionalizeCond {
// Mapping from merge nodes to predicate.
std::unordered_map<Node*, OutputTensor> merge_to_predicate_;
- // Nodes to be deleted.
- std::deque<int> delete_nodes_;
-
FunctionLibraryDefinition* library_;
Graph* graph_;
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
index 5932be4e52..f792c52032 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
@@ -31,11 +31,16 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/public/version.h"
namespace tensorflow {
@@ -68,4 +73,146 @@ Status FunctionalizeControlFlow(Graph* graph,
return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library);
}
+Status FunctionalizeControlFlowForFunction(
+ const string& func_name, const string& new_func_name,
+ const protobuf::Map<string, tensorflow::AttrValue>& attrs,
+ FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr,
+ std::map<string, string>* canonicalized_name_to_new_name) {
+ // Convert the function to Graph.
+ FunctionLibraryRuntime::Handle handle;
+ TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle));
+ Status ret_status = Status::OK();
+ auto cleanup_handle = gtl::MakeCleanup([&]() {
+ auto s = flr->ReleaseHandle(handle);
+ if (!s.ok()) {
+ ret_status.Update(s);
+ }
+ });
+ const FunctionBody* body = flr->GetFunctionBody(handle);
+ const FunctionDef& fdef = body->fdef;
+
+ // If any node has associated functions, functionalize them first.
+ // Gather the nodes with associated functions up front, because rewriting
+ // those nodes might involve node deletion/addition. Avoid modifying the
+ // node set while iterating over it.
+ std::vector<std::pair<Node*, std::vector<AssociatedFunctionInfo>>>
+ nodes_to_associated_functions;
+ for (auto* n : body->graph->nodes()) {
+ auto associated_functions = GetAssociatedFunctions(*n, flr);
+ if (!associated_functions.empty()) {
+ nodes_to_associated_functions.push_back({n, associated_functions});
+ }
+ }
+ for (auto iter : nodes_to_associated_functions) {
+ Node* n = iter.first;
+ auto associated_functions = iter.second;
+ for (auto& associated_function : associated_functions) {
+ string name = associated_function.func_name();
+ string canonicalized_name = Canonicalize(name, AttrSlice(&attrs));
+ auto iter = canonicalized_name_to_new_name->find(canonicalized_name);
+ string new_name;
+ if (iter != canonicalized_name_to_new_name->end()) {
+ // If we already functionalized this function, skip functionalization
+ // but still rewrite the node.
+ new_name = iter->second;
+ } else {
+ new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_"));
+ TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
+ name, new_name, attrs, fld, flr, canonicalized_name_to_new_name));
+ (*canonicalized_name_to_new_name)[canonicalized_name] = new_name;
+ }
+ // Notice that if "n" is a function call, RewriteAssociatedFunction() will
+ // delete it and create a new node instead, making "n" an invalid pointer.
+ // That's fine because in that case, associated_functions will only have
+ // one member and the loop will only run once.
+ TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
+ body->graph, n, fld, associated_function, new_name));
+ }
+ }
+
+ // Functionalize the function body.
+ if (VLOG_IS_ON(4)) {
+ dump_graph::DumpGraphToFile(
+ absl::StrCat("functionalize_control_flow_before_fdef_", func_name),
+ *body->graph, fld);
+ }
+ TF_RETURN_IF_ERROR(FunctionalizeControlFlow(body->graph, fld));
+ if (VLOG_IS_ON(4)) {
+ dump_graph::DumpGraphToFile(
+ absl::StrCat("functionalize_control_flow_after_fdef_", func_name),
+ *body->graph, fld);
+ }
+ FunctionDef functionalized_fdef;
+ TF_RETURN_IF_ERROR(
+ GraphToFunctionDef(*body->graph, new_func_name, &functionalized_fdef));
+
+ // Copy signature and ret from original FunctionDef.
+ *functionalized_fdef.mutable_signature() = fdef.signature();
+ *functionalized_fdef.mutable_ret() = fdef.ret();
+ functionalized_fdef.mutable_signature()->set_name(new_func_name);
+
+ // Add rewritten FunctionDef into library.
+ if (func_name == new_func_name) {
+ VLOG(2) << "Replacing function " << func_name;
+ TF_RETURN_IF_ERROR(
+ fld->ReplaceFunction(new_func_name, functionalized_fdef));
+ } else {
+ VLOG(2) << "Adding function " << new_func_name;
+ TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef));
+ }
+
+ return ret_status;
+}
+
+Status FunctionalizeControlFlowPass::Run(
+ const GraphOptimizationPassOptions& options) {
+ Graph* graph = options.graph->get();
+ if (VLOG_IS_ON(4)) {
+ dump_graph::DumpGraphToFile("functionalize_control_flow_before", *graph,
+ options.flib_def);
+ }
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
+ new ProcessFunctionLibraryRuntime(
+ /*device_mgr=*/nullptr, options.session_options->env,
+ TF_GRAPH_DEF_VERSION, options.flib_def, OptimizerOptions()));
+ FunctionLibraryRuntime* flr =
+ pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
+
+ // Find XLA compile ops and their corresponding FunctionDefs.
+ static std::map<string, string>* kNodeTypeToFunctionAttrMapping =
+ new std::map<string, string>{
+ {"TPUCompile", "function"},
+ {"XlaLaunch", "function"},
+ };
+ std::map<string, string> canonicalized_name_to_new_name;
+ for (Node* n : graph->nodes()) {
+ auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string());
+ if (it == kNodeTypeToFunctionAttrMapping->end()) {
+ continue;
+ }
+ const string func_attr = it->second;
+ if (kNodeTypeToFunctionAttrMapping->find(n->type_string()) !=
+ kNodeTypeToFunctionAttrMapping->end()) {
+ NameAttrList func;
+ TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func));
+ VLOG(2) << "Graph has node " << n->type_string()
+ << ". Corresponding function: " << func.name();
+ string new_func_name = options.flib_def->UniqueFunctionName(
+ absl::StrCat(func.name(), "_f15n_"));
+ TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
+ func.name(), new_func_name, func.attr(), options.flib_def, flr,
+ &canonicalized_name_to_new_name));
+ n->ClearAttr(func_attr);
+ func.set_name(new_func_name);
+ n->AddAttr(func_attr, func);
+ }
+ }
+
+ if (VLOG_IS_ON(4)) {
+ dump_graph::DumpGraphToFile("functionalize_control_flow_after", *graph,
+ options.flib_def);
+ }
+ return Status::OK();
+}
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h
index 55600f2a8b..ba99205640 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_
#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
@@ -32,6 +33,14 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
Graph* graph,
FunctionLibraryDefinition* library);
+// This pass looks at the graph and all associated FunctionDefs, and turns
+// traditional control flow structure (Switch/Merge/etc.) into functional
+// control flow structure (If/While).
+class FunctionalizeControlFlowPass : public GraphOptimizationPass {
+ public:
+ Status Run(const GraphOptimizationPassOptions& options) override;
+};
+
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_
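For completeness, a minimal sketch of calling the functionalization directly instead of relying on the registered pass, assuming the caller already owns the Graph and its FunctionLibraryDefinition. The wrapper name FunctionalizeForCompilation is hypothetical.

#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"

namespace tensorflow {

// Hypothetical wrapper, for illustration only.
Status FunctionalizeForCompilation(Graph* graph,
                                   FunctionLibraryDefinition* library) {
  // Rewrites Switch/Merge (and the loops built from them) into If/While nodes
  // whose branch and body functions are added to `library`.
  return FunctionalizeControlFlow(graph, library);
}

}  // namespace tensorflow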
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc
new file mode 100644
index 0000000000..a10a9d0499
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc
@@ -0,0 +1,25 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
+
+namespace tensorflow {
+
+// This pass is required for some AOT backends and all JIT backends, so this
+// file exists as a separate lib and will be linked to both AOT and JIT.
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 27,
+ FunctionalizeControlFlowPass);
+
+} // namespace tensorflow
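
For orientation, any other GraphOptimizationPass is registered the same way; within a phase group, a larger phase number runs later. A minimal sketch with a hypothetical pass name (not part of this change):

    #include "tensorflow/core/common_runtime/optimization_registry.h"

    namespace tensorflow {

    // Hypothetical no-op pass; at phase 28 it would run after
    // FunctionalizeControlFlowPass (phase 27) within PRE_PLACEMENT.
    class NoOpDebugPass : public GraphOptimizationPass {
     public:
      Status Run(const GraphOptimizationPassOptions& options) override {
        return Status::OK();
      }
    };

    REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 28,
                          NoOpDebugPass);

    }  // namespace tensorflow
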
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
index c068a4110c..c3841f996f 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h"
@@ -112,16 +113,12 @@ TEST(FunctionalizeControlFlow, Conditional) {
auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32);
auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
auto less = ops::Less(scope.WithOpName("cond/Less"), y, x);
- auto if_op = ops::XlaIf(scope.WithOpName(op_name), less,
- std::initializer_list<Input>{less, y, x}, then_fn,
- else_fn, {DT_INT32});
+ auto if_op = ops::If(scope.WithOpName(op_name), less,
+ std::initializer_list<Input>{less, y, x}, {DT_INT32},
+ then_fn, else_fn);
auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]);
GraphDef expected;
TF_EXPECT_OK(scope.ToGraphDef(&expected));
- // TODO(jpienaar): Create wrapper for IfOp.
- for (NodeDef& n : *expected.mutable_node()) {
- if (n.op() == "XlaIf") n.set_op("If");
- }
TF_EXPECT_GRAPH_EQ(expected, graph_def);
}
@@ -177,7 +174,7 @@ TEST(FunctionalizeControlFlow, Conditional) {
Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond,
NameAttrList* body) {
for (const NodeDef& node : graph.node()) {
- if (node.op() == "XlaWhile") {
+ if (node.op() == "While") {
const NameAttrList* result;
TF_RETURN_IF_ERROR(GetNodeAttr(node, "cond", &result));
*cond = *result;
@@ -186,7 +183,7 @@ Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond,
return Status::OK();
}
}
- return errors::NotFound("No XlaWhile node found in graph");
+ return errors::NotFound("No While node found in graph");
}
// Graph:
@@ -255,8 +252,8 @@ TEST(FunctionalizeControlFlow, OneLoopVar) {
Scope scope = Scope::NewRootScope().ExitOnError();
auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
auto while_op =
- ops::XlaWhile(scope.WithOpName("while/LoopCond"),
- std::initializer_list<Input>{source}, cond_fn, body_fn);
+ ops::While(scope.WithOpName("while/LoopCond"),
+ std::initializer_list<Input>{source}, cond_fn, body_fn);
auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]);
GraphDef expected;
TF_EXPECT_OK(scope.ToGraphDef(&expected));
@@ -392,8 +389,8 @@ TEST(FunctionalizeControlFlow, NoinlineLoopBody) {
Scope scope = Scope::NewRootScope().ExitOnError();
auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
auto while_op =
- ops::XlaWhile(scope.WithOpName("while/LoopCond"),
- std::initializer_list<Input>{source}, cond_fn, body_fn);
+ ops::While(scope.WithOpName("while/LoopCond"),
+ std::initializer_list<Input>{source}, cond_fn, body_fn);
GraphDef expected;
TF_ASSERT_OK(scope.ToGraphDef(&expected));
TF_EXPECT_GRAPH_EQ(expected, graph_def);
@@ -483,8 +480,8 @@ TEST(FunctionalizeControlFlow, OneLoopVarWithoutExit) {
Scope scope = Scope::NewRootScope().ExitOnError();
auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
auto while_op =
- ops::XlaWhile(scope.WithOpName("while/LoopCond"),
- std::initializer_list<Input>{source}, cond_fn, body_fn);
+ ops::While(scope.WithOpName("while/LoopCond"),
+ std::initializer_list<Input>{source}, cond_fn, body_fn);
GraphDef expected;
TF_EXPECT_OK(scope.ToGraphDef(&expected));
TF_EXPECT_GRAPH_EQ(expected, graph_def);
@@ -625,8 +622,8 @@ TEST(FunctionalizeControlFlow, TwoLoopVars) {
auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32);
auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32);
auto while_op =
- ops::XlaWhile(scope.WithOpName("while/LoopCond"),
- std::initializer_list<Input>{x, y}, cond_fn, body_fn);
+ ops::While(scope.WithOpName("while/LoopCond"),
+ std::initializer_list<Input>{x, y}, cond_fn, body_fn);
auto sink_x = ops::Identity(scope.WithOpName("sink_x"), while_op[0]);
auto sink_y = ops::Identity(scope.WithOpName("sink_y"), while_op[1]);
GraphDef expected;
@@ -864,9 +861,9 @@ TEST(FunctionalizeControlFlow, Complex) {
auto zero = ops::Const<int32>(scope.WithOpName("outer/Const"), 0);
- auto while_op = ops::XlaWhile(scope.WithOpName("outer/LoopCond"),
- std::initializer_list<Input>{zero, y, x, var},
- outer_cond_fn, outer_body_fn);
+ auto while_op = ops::While(scope.WithOpName("outer/LoopCond"),
+ std::initializer_list<Input>{zero, y, x, var},
+ outer_cond_fn, outer_body_fn);
auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]);
GraphDef expected;
TF_EXPECT_OK(scope.ToGraphDef(&expected));
@@ -921,9 +918,9 @@ TEST(FunctionalizeControlFlow, Complex) {
auto one_j = ops::Const<int32>(
scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1);
auto while_op =
- ops::XlaWhile(scope.WithOpName("outer/LoopCond_1"),
- std::initializer_list<Input>{one_j, arg1, arg2, arg3},
- inner_cond_fn, inner_body_fn);
+ ops::While(scope.WithOpName("outer/LoopCond_1"),
+ std::initializer_list<Input>{one_j, arg1, arg2, arg3},
+ inner_cond_fn, inner_body_fn);
auto one_outer = ops::Const<int32>(
scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1);
diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc
index 7f45e3bffa..7c3ad448ef 100644
--- a/tensorflow/compiler/tf2xla/functionalize_while.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_while.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/functionalize_cond.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -34,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
namespace {
@@ -473,12 +475,19 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
}
}
- // Builds the condition and body functions.
+ // Builds the condition and body functions. Note that we call
+ // FunctionalizeCond() on cond_graph and body_graph because they might still
+ // contain unfunctionalized "If" constructs; functionalize them before they
+ // are encapsulated in FunctionDefs.
std::unique_ptr<Graph> cond_graph;
TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph));
+ FixupSourceAndSinkEdges(cond_graph.get());
+ TF_RETURN_IF_ERROR(FunctionalizeCond(cond_graph.get(), library));
DataTypeVector arg_types;
std::unique_ptr<Graph> body_graph;
TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph));
+ FixupSourceAndSinkEdges(body_graph.get());
+ TF_RETURN_IF_ERROR(FunctionalizeCond(body_graph.get(), library));
VLOG(2) << "Frame " << frame->name << " condition: "
<< dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library)
@@ -510,7 +519,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
// Builds a While operator.
NodeDef while_def;
- NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile");
+ NodeDefBuilder builder(frame->loop_cond->name(), "While", library);
builder.Attr("T", arg_types);
builder.Attr("cond", cond_name);
builder.Attr("body", body_name);
@@ -653,9 +662,9 @@ Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library,
// There should be no cycle at this point, since while loops have been removed
// from graph.
- // Check that the newly added XlaWhile nodes don't feed into themselves.
+ // Check that the newly added While nodes don't feed into themselves.
for (const Node* node : graph->op_nodes()) {
- if (node->def().op() == "XlaWhile") {
+ if (node->def().op() == "While") {
TF_RETURN_WITH_CONTEXT_IF_ERROR(
CheckNodeNotInCycle(node, graph->num_node_ids()),
"Functionalizing loop failed.");
diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc
index bc2e640559..c019a28e89 100644
--- a/tensorflow/compiler/tf2xla/graph_compiler.cc
+++ b/tensorflow/compiler/tf2xla/graph_compiler.cc
@@ -20,7 +20,6 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
-#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
@@ -81,7 +80,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
TF_ASSIGN_OR_RETURN(auto literal,
client->ComputeConstant(constant_graph));
TF_RETURN_IF_ERROR(
- LiteralToHostTensor(*literal, arg.type, &arg.constant_value));
+ LiteralToHostTensor(literal, arg.type, &arg.constant_value));
} else {
arg.kind = XlaCompiler::Argument::kParameter;
}
diff --git a/tensorflow/compiler/tf2xla/graph_compiler.h b/tensorflow/compiler/tf2xla/graph_compiler.h
index ab7cac7100..e9f02201cf 100644
--- a/tensorflow/compiler/tf2xla/graph_compiler.h
+++ b/tensorflow/compiler/tf2xla/graph_compiler.h
@@ -55,17 +55,17 @@ namespace tensorflow {
// op registration infrastructure instead of FunctionLibraryRuntime.
class GraphCompiler {
public:
- GraphCompiler(XlaContext* xla_context, XlaCompilationDevice* device,
- Graph* graph, FunctionLibraryRuntime* flib,
+ GraphCompiler(XlaCompilationDevice* device, Graph* graph,
+ FunctionLibraryRuntime* flib,
ScopedStepContainer* step_container)
- : xla_context_(xla_context),
- device_(device),
+ : device_(device),
graph_(graph),
flib_(flib),
step_container_(step_container) {}
- // Compiles the graph. The results are written in `xla_context` that is passed
- // into the compiler.
+ // Compiles the graph. The results are written to the XlaContext stored in the
+ // resource_manager of the 'XlaCompilationDevice' that is passed into the
+ // constructor.
Status Compile();
private:
@@ -82,7 +82,6 @@ class GraphCompiler {
// using `compiler_`.
Status CompileFunctionalNode(Node* n, OpKernelContext* op_context);
- XlaContext* xla_context_;
XlaCompilationDevice* device_;
Graph* graph_;
FunctionLibraryRuntime* flib_;
diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
index df17da4c1c..0d9a768a6f 100644
--- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
@@ -66,6 +66,9 @@ XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions));
static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
xla::XlaOp y, const BCast& broadcast_helper) {
std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+ if (DataTypeIsUnsigned(dtype)) {
+ return xla::Div(x, y);
+ }
auto zero = XlaHelpers::Zero(b, dtype);
auto one = XlaHelpers::One(b, dtype);
auto different_sign = xla::Ne(xla::Lt(x, zero), xla::Lt(y, zero));
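
The early return for unsigned types above relies on truncating division already being floor division when both operands are non-negative; the sign-correction path is only needed for signed inputs. A standalone illustration in plain C++ (not TF code; the final line shows one common fix-up, not the kernel's exact formula):

    #include <cassert>
    #include <cstdint>

    int main() {
      uint32_t ux = 7, uy = 2;
      assert(ux / uy == 3u);             // unsigned: truncation == floor
      int32_t x = -7, y = 2;
      assert(x / y == -3);               // signed: C++ truncates toward zero
      assert((x - (y - 1)) / y == -4);   // floor(-7 / 2) == -4 needs a correction
      return 0;
    }
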
diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
index f410605104..0ae23aa6df 100644
--- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
@@ -37,6 +37,16 @@ limitations under the License.
namespace tensorflow {
namespace {
+// The maximum number of Tensors allowed in a single Concat op, used to avoid
+// exceeding the max GPU parameter memory size. This matters because Concat is
+// variadic and can take an unbounded number of arguments when called.
+// Concat ops with more Tensors than this are split into multiple Concat ops.
+//
+// TODO(b/112613927): Remove the logic here and put it properly in an HLO pass
+// along with boxing large numbers of parameters.
+constexpr int64 kMaxConcatArgsPerOp = 500;
+
// --------------------------------------------------------------------------
class ConcatBaseOp : public XlaOpKernel {
public:
@@ -74,6 +84,7 @@ class ConcatBaseOp : public XlaOpKernel {
// Make a vector holding the XlaOp for each of the inputs that has non-zero
// elements.
std::vector<xla::XlaOp> input_data;
+ std::vector<xla::XlaOp> partial_concats;
int output_concat_dim = 0;
const bool input_is_scalar = IsLegacyScalar(input_shape);
for (int i = 0; i < N; ++i) {
@@ -94,10 +105,30 @@ class ConcatBaseOp : public XlaOpKernel {
input_data.push_back(handle);
}
output_concat_dim += in_shape.dims() > 0 ? in_shape.dim_size(axis) : 1;
+
+ // Concat is associative, so a single op with too many arguments can be split
+ // into several smaller ones. This is a temporary workaround for b/112613927,
+ // where too many parameters in an XlaLaunchOp later result in too many
+ // parameters to a single GPU kernel.
+ if (i && i % kMaxConcatArgsPerOp == 0) {
+ partial_concats.push_back(
+ xla::ConcatInDim(ctx->builder(), input_data, axis));
+ input_data.clear();
+ }
}
+ // Add any inputs that have not been put into another concat yet.
+ partial_concats.insert(partial_concats.end(), input_data.begin(),
+ input_data.end());
VLOG(1) << "Concat dim " << concat_dim << " equivalent to " << axis;
- ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), input_data, axis));
+ // Don't add an additional "identity" concatenate for better readability of
+ // the IR.
+ if (partial_concats.size() == 1) {
+ ctx->SetOutput(0, partial_concats.front());
+ } else {
+ ctx->SetOutput(0,
+ xla::ConcatInDim(ctx->builder(), partial_concats, axis));
+ }
}
private:
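
A minimal standalone sketch of the chunking idea (plain C++; std::vector<int> stands in for XlaOp, a local Concat() helper stands in for xla::ConcatInDim, and the group boundaries are simplified relative to the kernel above):

    #include <vector>

    using Piece = std::vector<int>;

    static Piece Concat(const std::vector<Piece>& parts) {
      Piece out;
      for (const Piece& p : parts) out.insert(out.end(), p.begin(), p.end());
      return out;
    }

    // Concatenates in groups of at most `max_args`, then concatenates the partial
    // results; a single partial is returned as-is (no extra "identity" concat).
    Piece ConcatChunked(const std::vector<Piece>& inputs, size_t max_args = 500) {
      std::vector<Piece> partials, pending;
      for (const Piece& in : inputs) {
        pending.push_back(in);
        if (pending.size() == max_args) {
          partials.push_back(Concat(pending));
          pending.clear();
        }
      }
      partials.insert(partials.end(), pending.begin(), pending.end());
      return partials.size() == 1 ? partials.front() : Concat(partials);
    }
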
diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc
index a3389d5b90..4af1e8b44c 100644
--- a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc
@@ -34,15 +34,12 @@ class DynamicUpdateSliceOp : public XlaOpKernel {
: XlaOpKernel(context) {}
void Compile(XlaOpKernelContext* ctx) override {
- VLOG(3) << "DynamicUpdateSliceOp::Compile";
+ DataType index_type = ctx->InputType("indices");
+ CHECK(index_type == DT_INT32 || index_type == DT_INT64);
- DataType index_type = input_type(2);
- OP_REQUIRES(ctx, index_type == DT_INT32 || index_type == DT_INT64,
- errors::InvalidArgument("index must be int32 or int64"));
-
- const TensorShape input_shape = ctx->InputShape(0);
- const TensorShape update_shape = ctx->InputShape(1);
- const TensorShape index_shape = ctx->InputShape(2);
+ const TensorShape input_shape = ctx->InputShape("input");
+ const TensorShape update_shape = ctx->InputShape("update");
+ const TensorShape index_shape = ctx->InputShape("indices");
OP_REQUIRES(
ctx,
@@ -57,13 +54,56 @@ class DynamicUpdateSliceOp : public XlaOpKernel {
input_shape.DebugString(), "; update shape is ",
update_shape.DebugString()));
- xla::XlaOp result =
- xla::DynamicUpdateSlice(ctx->Input(0), ctx->Input(1), ctx->Input(2));
+ xla::XlaOp result = xla::DynamicUpdateSlice(
+ ctx->Input("input"), ctx->Input("update"), ctx->Input("indices"));
ctx->SetOutput(0, result);
}
};
REGISTER_XLA_OP(Name("XlaDynamicUpdateSlice"), DynamicUpdateSliceOp);
+class DynamicSliceOp : public XlaOpKernel {
+ public:
+ explicit DynamicSliceOp(OpKernelConstruction* context)
+ : XlaOpKernel(context) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ DataType index_type = ctx->InputType("start_indices");
+ CHECK(index_type == DT_INT32 || index_type == DT_INT64);
+ CHECK(index_type == ctx->InputType("size_indices"));
+
+ const TensorShape input_shape = ctx->InputShape("input");
+ const TensorShape start_indices_shape = ctx->InputShape("start_indices");
+ const TensorShape size_indices_shape = ctx->InputShape("size_indices");
+
+ OP_REQUIRES(ctx,
+ TensorShapeUtils::IsVector(start_indices_shape) &&
+ start_indices_shape.num_elements() == input_shape.dims(),
+ errors::InvalidArgument(
+ "start_indices must be a vector with length equal to "
+ "input rank, but input rank is ",
+ input_shape.dims(), " and start_indices has shape ",
+ start_indices_shape.DebugString()));
+ OP_REQUIRES(ctx,
+ TensorShapeUtils::IsVector(size_indices_shape) &&
+ size_indices_shape.num_elements() == input_shape.dims(),
+ errors::InvalidArgument(
+ "size_indices must be a vector with length equal to "
+ "input rank, but input rank is ",
+ input_shape.dims(), " and size_indices has shape ",
+ size_indices_shape.DebugString()));
+
+ std::vector<int64> size_indices;
+ OP_REQUIRES_OK(
+ ctx, ctx->ConstantInputAsIntVector("size_indices", &size_indices));
+ xla::XlaOp result = xla::DynamicSlice(
+ ctx->Input("input"), ctx->Input("start_indices"), size_indices);
+ ctx->SetOutput(0, result);
+ }
+};
+
+REGISTER_XLA_OP(Name("XlaDynamicSlice").CompileTimeConstInput("size_indices"),
+ DynamicSliceOp);
+
} // namespace
} // namespace tensorflow
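
For reference, a standalone rank-1 illustration of the [start, start + size) semantics the new kernel enforces (plain C++, not XLA code):

    #include <cstdint>
    #include <vector>

    // start_indices and size_indices carry one entry per input dimension; for a
    // rank-1 input they reduce to a single start and a single size.
    std::vector<int> DynamicSlice1D(const std::vector<int>& input,
                                    int64_t start, int64_t size) {
      return std::vector<int>(input.begin() + start, input.begin() + start + size);
    }
    // DynamicSlice1D({10, 20, 30, 40, 50}, /*start=*/1, /*size=*/3) -> {20, 30, 40}
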
diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
index 22a45b2a11..3d81ae9eb8 100644
--- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
+++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
@@ -78,14 +78,14 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
std::vector<xla::XlaOp> args;
args.push_back(ctx->Input(0));
args.push_back(xla::ConstantLiteral(
- &b, *xla::LiteralUtil::CreateR1<int64>(input_shape.dim_sizes())));
+ &b, xla::LiteralUtil::CreateR1<int64>(input_shape.dim_sizes())));
if (input_shape.dims() > 1) {
// Don't bother passing the output shape and dim for the 1d case, since
// the shape is always a scalar and the dim is always 0.
args.push_back(xla::ConstantLiteral(
- &b, *xla::LiteralUtil::CreateR1<int64>(output_shape.dim_sizes())));
+ &b, xla::LiteralUtil::CreateR1<int64>(output_shape.dim_sizes())));
args.push_back(
- xla::ConstantLiteral(&b, *xla::LiteralUtil::CreateR0<int32>(dim)));
+ xla::ConstantLiteral(&b, xla::LiteralUtil::CreateR0<int32>(dim)));
}
xla::Shape xla_shape =
diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc
index c267848524..804671fbc7 100644
--- a/tensorflow/compiler/tf2xla/lib/util.cc
+++ b/tensorflow/compiler/tf2xla/lib/util.cc
@@ -64,31 +64,31 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
xla::Literal literal;
switch (type) {
case xla::U8:
- literal = std::move(*xla::LiteralUtil::CreateR0<uint8>(value));
+ literal = xla::LiteralUtil::CreateR0<uint8>(value);
break;
case xla::U32:
- literal = std::move(*xla::LiteralUtil::CreateR0<uint32>(value));
+ literal = xla::LiteralUtil::CreateR0<uint32>(value);
break;
case xla::U64:
- literal = std::move(*xla::LiteralUtil::CreateR0<uint64>(value));
+ literal = xla::LiteralUtil::CreateR0<uint64>(value);
break;
case xla::S8:
- literal = std::move(*xla::LiteralUtil::CreateR0<int8>(value));
+ literal = xla::LiteralUtil::CreateR0<int8>(value);
break;
case xla::S32:
- literal = std::move(*xla::LiteralUtil::CreateR0<int32>(value));
+ literal = xla::LiteralUtil::CreateR0<int32>(value);
break;
case xla::S64:
- literal = std::move(*xla::LiteralUtil::CreateR0<int64>(value));
+ literal = xla::LiteralUtil::CreateR0<int64>(value);
break;
case xla::F32:
- literal = std::move(*xla::LiteralUtil::CreateR0<float>(value));
+ literal = xla::LiteralUtil::CreateR0<float>(value);
break;
case xla::F64:
- literal = std::move(*xla::LiteralUtil::CreateR0<double>(value));
+ literal = xla::LiteralUtil::CreateR0<double>(value);
break;
case xla::C64:
- literal = std::move(*xla::LiteralUtil::CreateR0<complex64>(value));
+ literal = xla::LiteralUtil::CreateR0<complex64>(value);
break;
case xla::PRED:
LOG(FATAL) << "pred element type is not integral";
@@ -96,12 +96,12 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
case xla::U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
case xla::BF16:
- literal = std::move(
- *xla::LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(value)));
+ literal =
+ xla::LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(value));
break;
case xla::F16:
- literal = std::move(*xla::LiteralUtil::CreateR0<xla::half>(
- static_cast<xla::half>(value)));
+ literal =
+ xla::LiteralUtil::CreateR0<xla::half>(static_cast<xla::half>(value));
break;
case xla::TUPLE:
LOG(FATAL) << "tuple element type is not integral";
diff --git a/tensorflow/compiler/tf2xla/literal_util_test.cc b/tensorflow/compiler/tf2xla/literal_util_test.cc
index 7dc16b5a46..15f4c38da2 100644
--- a/tensorflow/compiler/tf2xla/literal_util_test.cc
+++ b/tensorflow/compiler/tf2xla/literal_util_test.cc
@@ -22,51 +22,61 @@ limitations under the License.
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
+namespace {
TEST(LiteralUtil, LiteralToHostTensor) {
// int64 literal can only be converted to an int64 host tensor.
- {
- std::vector<int64> int64_values = {1, 2, 3};
- std::unique_ptr<xla::Literal> int64_values_literal =
- xla::LiteralUtil::CreateR1(absl::Span<const int64>(int64_values));
- Tensor host_tensor;
- EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32",
- LiteralToHostTensor(*int64_values_literal, DT_INT32, &host_tensor)
- .error_message());
- EXPECT_EQ(
- "Cannot convert literal of type S64 to tensor of type qint32",
- LiteralToHostTensor(*int64_values_literal, DT_QINT32, &host_tensor)
- .error_message());
- EXPECT_TRUE(
- LiteralToHostTensor(*int64_values_literal, DT_INT64, &host_tensor)
- .ok());
- test::ExpectTensorEqual<int64>(host_tensor,
- test::AsTensor<int64>(int64_values));
- }
+ std::vector<int64> int64_values = {1, 2, 3};
+ xla::Literal int64_values_literal =
+ xla::LiteralUtil::CreateR1(absl::Span<const int64>(int64_values));
+ Tensor host_tensor;
+ EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32",
+ LiteralToHostTensor(int64_values_literal, DT_INT32, &host_tensor)
+ .error_message());
+ EXPECT_EQ("Cannot convert literal of type S64 to tensor of type qint32",
+ LiteralToHostTensor(int64_values_literal, DT_QINT32, &host_tensor)
+ .error_message());
+ EXPECT_TRUE(
+ LiteralToHostTensor(int64_values_literal, DT_INT64, &host_tensor).ok());
+ test::ExpectTensorEqual<int64>(host_tensor,
+ test::AsTensor<int64>(int64_values));
+}
+
+template <class T>
+using LiteralUtilTest = ::testing::Test;
+using Types =
+ ::testing::Types<std::pair<int8, qint8>, std::pair<uint8, quint8>,
+ std::pair<int16, qint16>, std::pair<uint16, quint16>,
+ std::pair<int32, qint32>>;
+
+TYPED_TEST_CASE(LiteralUtilTest, Types);
+
+TYPED_TEST(LiteralUtilTest, LiteralToQuantizedHostTensor) {
+ using int_type = typename TypeParam::first_type;
+ using qint_type = typename TypeParam::second_type;
- {
- // Repeat tests with int32.
- Tensor host_tensor;
- std::vector<int32> int32_values = {10, 11};
- std::unique_ptr<xla::Literal> int32_values_literal =
- xla::LiteralUtil::CreateR1(absl::Span<const int32>(int32_values));
- EXPECT_TRUE(
- LiteralToHostTensor(*int32_values_literal, DT_INT32, &host_tensor)
- .ok());
- test::ExpectTensorEqual<int32>(host_tensor,
- test::AsTensor<int32>(int32_values));
+ Tensor host_tensor;
+ std::vector<int_type> int_values = {10, 11};
+ xla::Literal int_values_literal =
+ xla::LiteralUtil::CreateR1(absl::Span<const int_type>(int_values));
+ EXPECT_TRUE(LiteralToHostTensor(int_values_literal,
+ DataTypeToEnum<int_type>::value, &host_tensor)
+ .ok());
+ test::ExpectTensorEqual<int_type>(host_tensor,
+ test::AsTensor<int_type>(int_values));
- EXPECT_TRUE(
- LiteralToHostTensor(*int32_values_literal, DT_QINT32, &host_tensor)
- .ok());
- std::vector<qint32> qint32_values = {10, 11};
- test::ExpectTensorEqual<qint32>(host_tensor,
- test::AsTensor<qint32>(qint32_values));
+ EXPECT_TRUE(LiteralToHostTensor(int_values_literal,
+ DataTypeToEnum<qint_type>::value,
+ &host_tensor)
+ .ok());
+ std::vector<qint_type> qint_values = {10, 11};
+ test::ExpectTensorEqual<qint_type>(host_tensor,
+ test::AsTensor<qint_type>(qint_values));
- EXPECT_EQ("Cannot convert literal of type S32 to tensor of type int64",
- LiteralToHostTensor(*int32_values_literal, DT_INT64, &host_tensor)
- .error_message());
- }
+ EXPECT_EQ(
+ error::INVALID_ARGUMENT,
+ LiteralToHostTensor(int_values_literal, DT_INT64, &host_tensor).code());
}
+} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
index 68cfdc1785..02363500ef 100644
--- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc
+++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
@@ -105,6 +105,35 @@ dimension_numbers: a serialized xla::DotDimensionNumbers proto.
precision_config: a serialized xla::PrecisionConfig proto.
)doc");
+REGISTER_OP("XlaDynamicSlice")
+ .Input("input: T")
+ .Input("start_indices: Tindices")
+ .Input("size_indices: Tindices")
+ .Output("output: T")
+ .Attr("T: type")
+ .Attr("Tindices: {int32, int64}")
+ .SetShapeFn(shape_inference::UnknownShape)
+ .Doc(R"doc(
+Wraps the XLA DynamicSlice operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#dynamicslice
+.
+
+DynamicSlice extracts a sub-array from the input array at dynamic
+start_indices. The size of the slice in each dimension is passed in
+size_indices; together they define half-open slice intervals
+[start, start + size) in each dimension. start_indices must be a rank-1
+tensor whose length equals the rank of the operand.
+
+input: A `Tensor` of type T.
+
+start_indices: Rank 1 tensor of N integers containing the starting indices of
+ the slice for each dimension. Each value must be greater than or equal to
+ zero.
+
+size_indices: List of N integers containing the slice size for each
+ dimension. Each value must be strictly greater than zero, and start + size
+ must be less than or equal to the size of the corresponding dimension.
+)doc");
+
REGISTER_OP("XlaDynamicUpdateSlice")
.Input("input: T")
.Input("update: T")
diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py
index 3626de375e..27dd18a9bb 100644
--- a/tensorflow/compiler/tf2xla/python/xla.py
+++ b/tensorflow/compiler/tf2xla/python/xla.py
@@ -291,13 +291,7 @@ def dot_general(lhs, rhs, dimension_numbers, precision_config=None, name=None):
name=name)
-def dynamic_slice(x, starts, sizes, name=None):
- # TODO(phawkins): the Slice operator lowers to DynamicSlice if `starts` is not
- # a compile-time constant. This doesn't exactly mimic the semantics of dynamic
- # slice if the slice is out of bounds.
- return array_ops.slice(x, starts, sizes, name=name)
-
-
+dynamic_slice = gen_xla_ops.xla_dynamic_slice
dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice
# TODO(phawkins): generalize tf.pad to support interior padding, and then remove
diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc
index 92577b5bc8..20f2ce2919 100644
--- a/tensorflow/compiler/tf2xla/resource_operation_table.cc
+++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc
@@ -15,7 +15,6 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "absl/algorithm/container.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
namespace tensorflow {
@@ -31,10 +30,11 @@ namespace tensorflow {
}
}
-static gtl::FlatMap<StringPiece, XlaResourceOpInfo>* CreateResourceOpInfoMap() {
- auto* result = new gtl::FlatMap<StringPiece, XlaResourceOpInfo>;
+static gtl::FlatMap<absl::string_view, XlaResourceOpInfo>*
+CreateResourceOpInfoMap() {
+ auto* result = new gtl::FlatMap<absl::string_view, XlaResourceOpInfo>;
- auto add = [&](StringPiece op, XlaResourceOpKind op_kind,
+ auto add = [&](absl::string_view op, XlaResourceOpKind op_kind,
XlaResourceKind resource_kind) {
auto insert_result =
result->insert({op, XlaResourceOpInfo(op_kind, resource_kind)});
@@ -103,17 +103,17 @@ static gtl::FlatMap<StringPiece, XlaResourceOpInfo>* CreateResourceOpInfoMap() {
return result;
}
-static const gtl::FlatMap<StringPiece, XlaResourceOpInfo>&
+static const gtl::FlatMap<absl::string_view, XlaResourceOpInfo>&
GetStaticResourceOpInfoMap() {
- static gtl::FlatMap<StringPiece, XlaResourceOpInfo>* op_info_map =
+ static gtl::FlatMap<absl::string_view, XlaResourceOpInfo>* op_info_map =
CreateResourceOpInfoMap();
return *op_info_map;
}
const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op) {
- const gtl::FlatMap<StringPiece, XlaResourceOpInfo>& op_infos =
+ const gtl::FlatMap<absl::string_view, XlaResourceOpInfo>& op_infos =
GetStaticResourceOpInfoMap();
- auto it = op_infos.find(StringPiece(op.data(), op.length()));
+ auto it = op_infos.find(op);
return it == op_infos.end() ? nullptr : &it->second;
}
@@ -121,7 +121,7 @@ namespace resource_op_table_internal {
std::vector<absl::string_view> GetKnownResourceOps() {
std::vector<absl::string_view> result;
for (const auto& p : GetStaticResourceOpInfoMap()) {
- result.push_back(absl::string_view(p.first));
+ result.push_back(p.first);
}
absl::c_sort(result);
return result;
diff --git a/tensorflow/compiler/tf2xla/test_util.cc b/tensorflow/compiler/tf2xla/test_util.cc
index 3c6c9a91b6..f31bfb45a2 100644
--- a/tensorflow/compiler/tf2xla/test_util.cc
+++ b/tensorflow/compiler/tf2xla/test_util.cc
@@ -40,4 +40,12 @@ Status InstantiateFunctionForTest(const string& name,
return Status::OK();
}
+std::unordered_map<string, Node*> BuildNodeIndex(const Graph& graph) {
+ std::unordered_map<string, Node*> index;
+ for (Node* node : graph.nodes()) {
+ index[node->name()] = node;
+ }
+ return index;
+}
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/test_util.h b/tensorflow/compiler/tf2xla/test_util.h
index e6e4ae92ed..350a868568 100644
--- a/tensorflow/compiler/tf2xla/test_util.h
+++ b/tensorflow/compiler/tf2xla/test_util.h
@@ -24,8 +24,10 @@ limitations under the License.
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/util/equal_graph_def.h"
namespace tensorflow {
@@ -42,6 +44,20 @@ Status InstantiateFunctionForTest(const string& name,
const FunctionLibraryDefinition& library,
InstantiationResultForTest* result);
+// Builds a map from node name to Node* for `graph`.
+std::unordered_map<string, Node*> BuildNodeIndex(const Graph& graph);
+
} // namespace tensorflow
+// Variant of TF_EXPECT_GRAPH_EQ that also compares internal attributes for
+// equality.
+#define TF_EXPECT_GRAPH_EQ_INTERNAL(expected, actual) \
+ do { \
+ string diff; \
+ EqualGraphDefOptions eq_options; \
+ eq_options.ignore_internal_attrs = false; \
+ EXPECT_TRUE(EqualGraphDef(actual, expected, &diff, eq_options)) \
+ << diff << "\nActual: " << SummarizeGraphDef(actual); \
+ } while (false)
+
#endif // TENSORFLOW_COMPILER_TF2XLA_TEST_UTIL_H_
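
A hypothetical test fragment showing how the new helpers might be used together (`graph`, `expected`, and `actual` are assumed to have been built earlier in the test; this is a sketch, not code from the change):

    std::unordered_map<string, Node*> index = BuildNodeIndex(graph);
    ASSERT_NE(index.find("while/LoopCond"), index.end());
    EXPECT_EQ("While", index["while/LoopCond"]->type_string());
    TF_EXPECT_GRAPH_EQ_INTERNAL(expected, actual);
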
diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc
index 7dbe3a0b58..b22d53805d 100644
--- a/tensorflow/compiler/tf2xla/tf2xla.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
@@ -340,6 +341,13 @@ Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config,
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
second_copy_def, g.get()));
TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, feed_remapping));
+
+ // Functionalize control flow.
+ TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g.get(), &flib_def));
+ // After control flow functionalization, we might have more FunctionDef's
+ // (then/else branch, loop body). Add them to the graph.
+ TF_RETURN_IF_ERROR(g->AddFunctionLibrary(flib_def.ToProto()));
+
*graph = std::move(g);
return Status::OK();
}
diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc
index 56f7045a98..ab26d939cc 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_test.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc
@@ -77,8 +77,8 @@ TEST(ConvertGraphDefToXla, Sum) {
// Set up arguments.
auto x_literal = xla::LiteralUtil::CreateR0<int32>(10);
auto y_literal = xla::LiteralUtil::CreateR0<int32>(32);
- auto x_global_or = client->TransferToServer(*x_literal);
- auto y_global_or = client->TransferToServer(*y_literal);
+ auto x_global_or = client->TransferToServer(x_literal);
+ auto y_global_or = client->TransferToServer(y_literal);
TF_EXPECT_OK(x_global_or.status());
TF_EXPECT_OK(y_global_or.status());
std::unique_ptr<xla::GlobalData> x_global =
@@ -90,8 +90,8 @@ TEST(ConvertGraphDefToXla, Sum) {
auto result_or =
client->ExecuteAndTransfer(computation, {x_global.get(), y_global.get()});
TF_EXPECT_OK(result_or.status());
- std::unique_ptr<xla::Literal> result = std::move(result_or.ValueOrDie());
- EXPECT_EQ("(s32[]) (\n42\n)", result->ToString());
+ xla::Literal result = std::move(result_or.ValueOrDie());
+ EXPECT_EQ("(s32[]) (\n42\n)", result.ToString());
config.mutable_feed(0)->mutable_id()->set_output_index(
123); /* invalid output_index */
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc
index 211caf8736..d6f42bac86 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc
@@ -25,9 +25,12 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
@@ -75,6 +78,8 @@ Status CheckFeedFetchNameConflicts(const string& kind,
} // namespace
+const char kXlaOutsideCompilationAttrName[] = "_xla_outside_compilation";
+
Status ValidateConfig(const tf2xla::Config& config) {
std::set<string> names;
for (const tf2xla::Feed& feed : config.feed()) {
@@ -323,4 +328,101 @@ uint32 GetXLARandomSeed() {
return counter.fetch_add(2);
}
+// TODO(b/77601805): add tests for associated function related stuff.
+bool HasAssociatedFunction(const NodeDef& node_def,
+ FunctionLibraryRuntime* flr) {
+ if (flr->GetFunctionLibraryDefinition()->Contains(node_def.op())) {
+ return true;
+ }
+
+ if (node_def.op() == FunctionLibraryDefinition::kGradientOp) {
+ // Skip gradient op. A gradient op has an "f" attr, which is set to the
+ // function whose gradient is being computed; that function is not associated
+ // with the op.
+ return false;
+ }
+
+ for (const auto& iter : node_def.attr()) {
+ if (iter.second.has_func()) {
+ return true;
+ }
+ }
+
+ return false;
+}
+
+std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
+ const Node& node, FunctionLibraryRuntime* flr) {
+ std::vector<AssociatedFunctionInfo> results;
+ const string& op = node.type_string();
+ if (flr->GetFunctionLibraryDefinition()->Contains(op)) {
+ // This is a function call node.
+ AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
+ results.emplace_back(AssociatedFunctionInfo(op, attrs));
+ } else if (node.type_string() == FunctionLibraryDefinition::kGradientOp) {
+ // Skip gradient op. A gradient op has an "f" attr, which is set to the
+ // function whose gradient is being computed; that function is not associated
+ // with the op.
+ } else {
+ // Collect all function attrs for the node.
+ for (auto& iter : node.attrs()) {
+ if (iter.second.has_func()) {
+ VLOG(2) << "Found function attr for node " << node.name() << ": "
+ << iter.first << " = " << iter.second.func().name();
+ results.emplace_back(AssociatedFunctionInfo(
+ iter.second.func().name(), iter.second.func().attr(), iter.first));
+ }
+ }
+ }
+ return results;
+}
+
+Status RewriteAssociatedFunction(
+ Graph* graph, Node* node, FunctionLibraryDefinition* fld,
+ const AssociatedFunctionInfo& associated_function,
+ const string& rewritten_function_name) {
+ switch (associated_function.type()) {
+ case AssociatedFunctionInfo::kFunctionCallNode: {
+ // Change this node to call the new function.
+ NodeDefBuilder builder(node->name(), rewritten_function_name, fld);
+ for (auto attr : node->attrs()) {
+ builder.Attr(attr.first, attr.second);
+ }
+ for (int i = 0; i < node->num_inputs(); i++) {
+ Node* input_node;
+ TF_RETURN_IF_ERROR(node->input_node(i, &input_node));
+ builder.Input(input_node->name(), i, node->input_type(i));
+ }
+ builder.Device(node->assigned_device_name().empty()
+ ? node->requested_device()
+ : node->assigned_device_name());
+ NodeDef node_def;
+ TF_RETURN_IF_ERROR(builder.Finalize(&node_def));
+ Status s;
+ Node* new_node = graph->AddNode(node_def, &s);
+ TF_RETURN_IF_ERROR(s);
+ for (auto edge : node->in_edges()) {
+ graph->AddEdge(edge->src(), edge->src_output(), new_node,
+ edge->dst_input());
+ }
+ for (auto edge : node->out_edges()) {
+ graph->AddEdge(new_node, edge->src_output(), edge->dst(),
+ edge->dst_input());
+ }
+ graph->RemoveNode(node);
+ break;
+ }
+ case AssociatedFunctionInfo::kFunctionAttr: {
+ // Change the function attr to point to the rewritten function.
+ NameAttrList func;
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(node->attrs(), associated_function.attr_name(), &func));
+ node->ClearAttr(associated_function.attr_name());
+ func.set_name(rewritten_function_name);
+ node->AddAttr(associated_function.attr_name(), func);
+ break;
+ }
+ }
+
+ return Status::OK();
+}
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h
index dcddef8418..6065d0bb9a 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.h
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.h
@@ -18,8 +18,8 @@ limitations under the License.
#include <unordered_map>
-#include "absl/strings/string_view.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
+#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/op.h"
@@ -60,6 +60,67 @@ void AddDtypeToKernalDefConstraint(absl::string_view name, DataType dtype,
// Returns the next random seed to use for seeding xla rng.
uint32 GetXLARandomSeed();
+// Indicates how a FunctionDef is associated with a graph node (e.g. the node is
+// a function call, or the node has function attrs).
+class AssociatedFunctionInfo {
+ public:
+ enum AssociatedFunctionType {
+ kFunctionCallNode = 0,
+ kFunctionAttr = 1,
+ };
+
+ // The node is a function call.
+ AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs)
+ : type_(kFunctionCallNode), func_name_(func_name), attrs_(attrs) {}
+
+ // The function is an attr of the node.
+ AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs,
+ const string& attr_name)
+ : type_(kFunctionAttr),
+ func_name_(func_name),
+ attrs_(attrs),
+ attr_name_(attr_name) {}
+
+ AssociatedFunctionType type() const { return type_; }
+
+ const string& func_name() const { return func_name_; }
+
+ const string& attr_name() const { return attr_name_; }
+
+ const AttrValueMap& attrs() const { return attrs_; }
+
+ private:
+ // Available for all instances.
+ AssociatedFunctionType type_;
+ string func_name_;
+ AttrValueMap attrs_;
+
+ // Only available if the function is defined in an attr.
+ string attr_name_;
+};
+
+// Returns true if the NodeDef has an associated function.
+bool HasAssociatedFunction(const NodeDef& node_def,
+ FunctionLibraryRuntime* flr);
+
+// Gets the functions associated with the node. Current cases:
+// 1. For a function call node, the function it calls;
+// 2. For nodes like XlaWhile/XlaIf, all of their function attributes.
+std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
+ const Node& node, FunctionLibraryRuntime* flr);
+
+// Changes the associated function for the node. Current cases:
+// 1. For a function call node, creates a new node that calls the new function
+// and removes the old node;
+// 2. For nodes like XlaWhile/XlaIf, modifies their function attributes.
+Status RewriteAssociatedFunction(
+ Graph* graph, Node* node, FunctionLibraryDefinition* fld,
+ const AssociatedFunctionInfo& associated_function,
+ const string& rewritten_function_name);
+
+// Attribute to mark nodes to be executed on host.
+extern const char kXlaOutsideCompilationAttrName[];
+
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
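
A hedged sketch of how these helpers compose (`graph`, `node`, `flr`, and `fld` are assumed to exist, and the step that actually writes a rewritten FunctionDef under the new name is elided):

    // Repoint every function associated with `node` at a freshly named copy.
    // Safe for both cases: a call node yields a single entry, and attr-based
    // nodes are modified in place rather than replaced.
    for (const AssociatedFunctionInfo& info : GetAssociatedFunctions(*node, flr)) {
      string new_name = fld->UniqueFunctionName(info.func_name() + "_rewritten");
      // ... produce the rewritten FunctionDef for info.func_name() under
      // `new_name` in `fld` before repointing the node at it ...
      TF_RETURN_IF_ERROR(
          RewriteAssociatedFunction(graph, node, fld, info, new_name));
    }
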
diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc
index c969212a1b..d00b137662 100644
--- a/tensorflow/compiler/tf2xla/type_util.cc
+++ b/tensorflow/compiler/tf2xla/type_util.cc
@@ -26,21 +26,26 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) {
*type = xla::PRED;
return Status::OK();
case tensorflow::DT_INT8:
+ case tensorflow::DT_QINT8:
*type = xla::S8;
return Status::OK();
case tensorflow::DT_INT16:
+ case tensorflow::DT_QINT16:
*type = xla::S16;
return Status::OK();
case tensorflow::DT_INT32:
+ case tensorflow::DT_QINT32:
*type = xla::S32;
return Status::OK();
case tensorflow::DT_INT64:
*type = xla::S64;
return Status::OK();
case tensorflow::DT_UINT8:
+ case tensorflow::DT_QUINT8:
*type = xla::U8;
return Status::OK();
case tensorflow::DT_UINT16:
+ case tensorflow::DT_QUINT16:
*type = xla::U16;
return Status::OK();
case tensorflow::DT_UINT32:
@@ -64,12 +69,6 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) {
case tensorflow::DT_COMPLEX64:
*type = xla::C64;
return Status::OK();
- case tensorflow::DT_QUINT8:
- *type = xla::U8;
- return Status::OK();
- case tensorflow::DT_QINT32:
- *type = xla::S32;
- return Status::OK();
default:
return errors::InvalidArgument(
"Unsupported type in DataTypeToPrimitiveType ",
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index dcb455779d..739e47778a 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -20,7 +20,6 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
-#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
#include "tensorflow/compiler/tf2xla/graph_compiler.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
@@ -150,6 +149,9 @@ Status XlaCompiler::FindFunctionBody(const NameAttrList& function,
TF_RETURN_WITH_CONTEXT_IF_ERROR(
GetFunctionBody(function, flib_runtime_, fbody),
"Local lookup failed with: ", status.error_message());
+ VLOG(4) << "Function " << function.name() << " in flib_runtime_";
+ } else {
+ VLOG(4) << "Function " << function.name() << " in local_flib_runtime_";
}
return Status::OK();
}
@@ -323,8 +325,7 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
step_container->name(), XlaContext::kXlaContextResourceName,
xla_context));
- GraphCompiler graph_compiler(xla_context, device, graph.get(), flib,
- step_container.get());
+ GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get());
TF_RETURN_IF_ERROR(graph_compiler.Compile());
// Explicitly clean up the step container, to capture the cleanup status.
step_container.reset();
@@ -743,18 +744,13 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
if (VLOG_IS_ON(2)) {
VLOG(2) << "XlaCompiler::CompileGraph: "
<< dump_graph::DumpGraphToFile(
- absl::StrCat("xla_compile_graph_", name), *graph);
+ absl::StrCat("xla_compile_graph_", name), *graph,
+ flib_runtime_->GetFunctionLibraryDefinition());
}
// Report the error here if initialization failed.
TF_RETURN_IF_ERROR(initialization_status_);
- // Converts Tensorflow's graph control-flow constructs into functional
- // control-flow that can be compiled into XLA code.
- TF_RETURN_IF_ERROR(
- FunctionalizeControlFlow(flib_runtime_->GetFunctionLibraryDefinition(),
- graph.get(), local_flib_def_.get()));
-
// Detect invalid nodes.
// FunctionalizeControlFlow may remove some nodes from the graph.
TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def,
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index 40ce9fb41c..72b17d04fc 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -208,27 +208,22 @@ TEST_F(XlaCompilerTest, Simple) {
std::move(graph), args, &result));
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
- std::unique_ptr<xla::Literal> param1_literal =
- xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
-
- std::unique_ptr<xla::Literal> expected0 =
- xla::LiteralUtil::CreateR1<int32>({4, 143});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({expected0.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
+
+ xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({4, 143});
+ xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
// Tests compilation of a graph where the _Retval node is not necessarily last
@@ -264,23 +259,20 @@ TEST_F(XlaCompilerTest, OutOfOrderGraph) {
args, &result));
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
- std::unique_ptr<xla::Literal> param1_literal =
- xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*param0_literal, *actual_literal));
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(param0_literal, actual_literal));
}
// Tests that the compiler doesn't reorder the parameters.
@@ -408,23 +400,19 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
EXPECT_FALSE(result.outputs[1].is_constant);
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_->Execute(*result.computation, {param0_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
+ xla::Literal actual_literal =
client_->Transfer(*actual).ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> expected0 =
- xla::LiteralUtil::CreateR1<int32>({-7, -42});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({expected0.get()});
- EXPECT_TRUE(
- xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({-7, -42});
+ xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
{
@@ -443,24 +431,21 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
EXPECT_FALSE(result.outputs[1].is_constant);
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_->Execute(*result.computation, {param0_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
+ xla::Literal actual_literal =
client_->Transfer(*actual).ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> expected0 =
- xla::LiteralUtil::CreateR0<int32>(7);
- std::unique_ptr<xla::Literal> expected1 =
- xla::LiteralUtil::CreateR1<int32>({-7, -42});
- std::unique_ptr<xla::Literal> expected =
- xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual_literal));
+ xla::Literal expected0 = xla::LiteralUtil::CreateR0<int32>(7);
+ xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({-7, -42});
+ xla::Literal expected =
+ xla::LiteralUtil::MakeTuple({&expected0, &expected1});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected, actual_literal));
}
}
@@ -619,10 +604,17 @@ TEST_F(XlaCompilerTest, DeterministicCompilation) {
auto instr1 = c1.instructions(j);
auto instr2 = c2.instructions(j);
instr1.clear_name();
+ instr1.clear_id();
+ instr1.clear_operand_ids();
instr2.clear_name();
- // The names of instructions were uniquified by the XlaBuilder, the rest
- // of the fields should be identical.
+ instr2.clear_id();
+ instr2.clear_operand_ids();
+ // The names of instructions were uniquified by the XlaBuilder and their
+ // unique ids may differ; the rest of the fields should be identical.
string str1, str2;
+ LOG(INFO) << "instr1 = " << instr1.DebugString();
+ LOG(INFO) << "instr2 = " << instr2.DebugString();
instr1.AppendPartialToString(&str1);
instr2.AppendPartialToString(&str2);
EXPECT_EQ(str1, str2);
@@ -672,34 +664,26 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
update.tensor_array_gradients_accessed);
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> input_base =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
- std::unique_ptr<xla::Literal> input_grad2 =
- xla::LiteralUtil::CreateR1<int32>({-3, 101});
- std::unique_ptr<xla::Literal> input =
- xla::LiteralUtil::MakeTuple({input_base.get(), input_grad2.get()});
+ xla::Literal input_base = xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal input_grad2 = xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ xla::Literal input = xla::LiteralUtil::MakeTuple({&input_base, &input_grad2});
std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*input).ConsumeValueOrDie();
+ client_->TransferToServer(input).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_->Execute(*result.computation, {param0_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
-
- std::unique_ptr<xla::Literal> output_read =
- xla::LiteralUtil::CreateR0<int32>(42);
- std::unique_ptr<xla::Literal> output_base =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
- std::unique_ptr<xla::Literal> output_grad1 =
- xla::LiteralUtil::CreateR1<int32>({0, 1});
- std::unique_ptr<xla::Literal> output_grad2 =
- xla::LiteralUtil::CreateR1<int32>({-3, 101});
- std::unique_ptr<xla::Literal> output_resource = xla::LiteralUtil::MakeTuple(
- {output_base.get(), output_grad1.get(), output_grad2.get()});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({output_read.get(), output_resource.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
+
+ xla::Literal output_read = xla::LiteralUtil::CreateR0<int32>(42);
+ xla::Literal output_base = xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal output_grad1 = xla::LiteralUtil::CreateR1<int32>({0, 1});
+ xla::Literal output_grad2 = xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ xla::Literal output_resource =
+ xla::LiteralUtil::MakeTuple({&output_base, &output_grad1, &output_grad2});
+ xla::Literal expected_literal =
+ xla::LiteralUtil::MakeTuple({&output_read, &output_resource});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
// Tests compilation and execution of a graph that adds two tensors.
@@ -866,29 +850,24 @@ TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) {
void RunAndCheckVariablesComputation(
xla::Client* client, const XlaCompiler::CompilationResult& result) {
- std::unique_ptr<xla::Literal> param0_literal =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
- std::unique_ptr<xla::Literal> param1_literal =
- xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::GlobalData> param0_data =
- client->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
- client->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client->Transfer(*actual).ConsumeValueOrDie();
-
- std::unique_ptr<xla::Literal> expected0 =
- xla::LiteralUtil::CreateR1<int32>({5, 144});
- std::unique_ptr<xla::Literal> expected1 =
- xla::LiteralUtil::CreateR1<int32>({4, 143});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal actual_literal = client->Transfer(*actual).ConsumeValueOrDie();
+
+ xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({5, 144});
+ xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({4, 143});
+ xla::Literal expected_literal =
+ xla::LiteralUtil::MakeTuple({&expected0, &expected1});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
// Tests a simple graph that reads and writes a variable.
@@ -952,20 +931,17 @@ TEST_F(XlaCompilerTest, ReturnResourceHandleOnly) {
std::move(graph), args, &result));
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param1_literal =
- xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_->Execute(*result.computation, {param1_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
TEST_F(XlaCompilerTest, ReturnResourceHandle) {
@@ -1069,29 +1045,27 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
xla::ShapeUtil::MakeShape(xla::S32, {4})})));
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
+ xla::Literal param0_literal =
xla::LiteralUtil::CreateR2<int32>({{4, 55}, {1, -3}});
- std::unique_ptr<xla::Literal> param1_literal =
+ xla::Literal param1_literal =
xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> expected0 =
+ xla::Literal expected0 =
xla::LiteralUtil::CreateR2<int32>({{27, 67}, {35, 402}});
- std::unique_ptr<xla::Literal> expected1 =
- xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
+ xla::Literal expected_literal =
+ xla::LiteralUtil::MakeTuple({&expected0, &expected1});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
@@ -1138,29 +1112,26 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
xla::ShapeUtil::MakeShape(xla::S32, {4})})));
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
+ xla::Literal param0_literal =
xla::LiteralUtil::CreateR1<int32>({4, 55, 1, -3});
- std::unique_ptr<xla::Literal> param1_literal =
+ xla::Literal param1_literal =
xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
-
- std::unique_ptr<xla::Literal> expected0 =
- xla::LiteralUtil::CreateR1<int32>({27, 67, 35, 402});
- std::unique_ptr<xla::Literal> expected1 =
- xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
+
+ xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({27, 67, 35, 402});
+ xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
+ xla::Literal expected_literal =
+ xla::LiteralUtil::MakeTuple({&expected0, &expected1});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
// Tests a graph which has a function with an invalid op.
@@ -1255,25 +1226,8 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) {
std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
CopyGraph(*graph, graph_copy.get());
XlaCompiler::CompilationResult result;
- status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp",
- std::move(graph_copy), args, &result);
- ASSERT_FALSE(status.ok());
- EXPECT_TRUE(
- absl::StrContains(status.error_message(),
- "The following nodes are unreachable "
- "from the source in the graph: {{node NoOp}}"))
- << status.error_message();
- }
-
- // Fix control edges for NoOp.
- {
- std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
- CopyGraph(*graph, graph_copy.get());
- EXPECT_TRUE(FixupSourceAndSinkEdges(graph_copy.get()));
- XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp",
std::move(graph_copy), args, &result));
- EXPECT_EQ(0, result.resource_updates.size());
}
}
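
For readers tracking the API migration in the tests above: LiteralUtil factory functions now return xla::Literal by value, LiteralUtil::MakeTuple takes raw pointers to the element literals, and comparisons operate on the values directly. A minimal sketch of that pattern follows; it is illustration only, not part of the patch, and the function name and exact include set are assumptions.

    #include "tensorflow/compiler/xla/literal_util.h"
    #include "tensorflow/core/platform/logging.h"

    // Hypothetical example of the post-change value semantics.
    void TupleLiteralSketch() {
      // Factories return xla::Literal by value, not std::unique_ptr<Literal>.
      xla::Literal read = xla::LiteralUtil::CreateR0<xla::int32>(42);
      xla::Literal base = xla::LiteralUtil::CreateR1<xla::int32>({7, 42});
      // MakeTuple takes pointers to the elements and copies them into a
      // new tuple-shaped literal, also returned by value.
      xla::Literal tuple = xla::LiteralUtil::MakeTuple({&read, &base});
      // Literal is copyable via Clone() and directly comparable.
      xla::Literal copy = tuple.Clone();
      CHECK(copy == tuple);
    }
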
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 636cb71e21..2a9eaeee14 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -83,6 +83,10 @@ DataType XlaOpKernelContext::input_type(int index) const {
return context_->input(index).dtype();
}
+DataType XlaOpKernelContext::InputType(absl::string_view name) {
+ return GetInputTensorByName(name).dtype();
+}
+
xla::PrimitiveType XlaOpKernelContext::input_xla_type(int index) {
xla::PrimitiveType type;
Status status = DataTypeToPrimitiveType(input_type(index), &type);
@@ -102,8 +106,7 @@ Status XlaOpKernelContext::ConstantInput(int index,
static xla::StatusOr<int> InputIndex(XlaOpKernelContext* context,
absl::string_view name) {
int start, stop;
- TF_RETURN_IF_ERROR(context->op_kernel().InputRange(
- StringPiece(name.data(), name.length()), &start, &stop));
+ TF_RETURN_IF_ERROR(context->op_kernel().InputRange(name, &start, &stop));
if (stop != start + 1) {
return errors::InvalidArgument("OpKernel used list-valued input name '",
name,
@@ -214,16 +217,15 @@ Status XlaOpKernelContext::ConstantInputReshaped(
context_->op_kernel().name(), " input ", index,
".\nError: ", constant_graph.status().error_message());
}
- xla::StatusOr<std::unique_ptr<xla::Literal>> computed =
- compiler()->client()->ComputeConstant(constant_graph.ValueOrDie(),
- &layout);
+ xla::StatusOr<xla::Literal> computed = compiler()->client()->ComputeConstant(
+ constant_graph.ValueOrDie(), &layout);
if (!computed.ok()) {
return errors::Internal("Error evaluating ", context_->op_kernel().name(),
" input ", index,
- "as a compile-time constant.\nError: ",
+ " as a compile-time constant.\nError: ",
computed.status().error_message());
}
- *constant_literal = std::move(*computed.ValueOrDie());
+ *constant_literal = std::move(computed).ValueOrDie();
return Status::OK();
}
@@ -366,8 +368,7 @@ Status XlaOpKernelContext::InputList(absl::string_view name,
std::vector<xla::XlaOp>* handles,
std::vector<TensorShape>* shapes) {
OpInputList inputs;
- TF_RETURN_IF_ERROR(
- context_->input_list(StringPiece(name.data(), name.size()), &inputs));
+ TF_RETURN_IF_ERROR(context_->input_list(name, &inputs));
handles->clear();
shapes->clear();
for (const Tensor& input : inputs) {
@@ -380,8 +381,7 @@ Status XlaOpKernelContext::InputList(absl::string_view name,
Status XlaOpKernelContext::ConstantInputList(
absl::string_view name, std::vector<xla::Literal>* outputs) {
int start, stop;
- TF_RETURN_IF_ERROR(op_kernel().InputRange(
- StringPiece(name.data(), name.size()), &start, &stop));
+ TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop));
outputs->resize(stop - start);
for (int i = start; i < stop; ++i) {
TF_RETURN_IF_ERROR(ConstantInput(i, &(*outputs)[i]));
@@ -615,7 +615,7 @@ const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul(
const Tensor& XlaOpKernelContext::GetInputTensorByName(absl::string_view name) {
const Tensor* tensor;
- CHECK(context_->input(StringPiece(name.data(), name.length()), &tensor).ok());
+ CHECK(context_->input(name, &tensor).ok());
return *tensor;
}
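
The ConstantInputReshaped hunk above also shows the new way to move a literal out of a StatusOr: calling ValueOrDie() on an rvalue StatusOr<Literal> transfers ownership of the buffers instead of dereferencing a unique_ptr. A small sketch of that idiom (illustration only; TakeLiteral is a hypothetical helper):

    #include <utility>
    #include "tensorflow/compiler/xla/literal.h"
    #include "tensorflow/compiler/xla/status.h"
    #include "tensorflow/compiler/xla/statusor.h"

    // Moves the Literal out of a StatusOr<Literal> without copying buffers.
    xla::Status TakeLiteral(xla::StatusOr<xla::Literal> computed,
                            xla::Literal* out) {
      if (!computed.ok()) {
        return computed.status();
      }
      // ValueOrDie() on an rvalue StatusOr moves the contained Literal out.
      *out = std::move(computed).ValueOrDie();
      return xla::Status::OK();
    }
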
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index 962c86d3a5..a3a0d10cc0 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -71,6 +71,9 @@ class XlaOpKernelContext {
// Returns the type of input `index`.
DataType input_type(int index) const;
+ // Returns the type of input `name`.
+ DataType InputType(absl::string_view name);
+
// Returns the type of input `index` as an xla::PrimitiveType. If the type
// is not representable as an XLA type, sets an error status and returns
// xla::PRIMITIVE_TYPE_INVALID.
@@ -79,7 +82,7 @@ class XlaOpKernelContext {
// Returns the shape of input `index`.
TensorShape InputShape(int index);
- // Returns the shape of input `name`.
+ // Returns the shape of input with name `name`.
TensorShape InputShape(absl::string_view name);
// Returns input `index` as an XlaOp. Unlike
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h
index 5d53169f68..74a4885f1f 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.h
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.h
@@ -22,7 +22,6 @@ limitations under the License.
#include <unordered_map>
#include <vector>
-#include "absl/strings/string_view.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/device_base.h"
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index 76e36f3c46..ef70c1f8ac 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -193,6 +193,7 @@ cc_library(
":types",
":util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/synchronization",
],
)
diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc
index 8818f81312..5dde5b432f 100644
--- a/tensorflow/compiler/xla/client/client.cc
+++ b/tensorflow/compiler/xla/client/client.cc
@@ -37,8 +37,8 @@ Client::Client(ServiceInterface* stub) : stub_(stub) {}
Client::~Client() = default;
-StatusOr<std::unique_ptr<Literal>> Client::Transfer(
- const GlobalData& data, const Shape* shape_with_layout) {
+StatusOr<Literal> Client::Transfer(const GlobalData& data,
+ const Shape* shape_with_layout) {
TransferToClientRequest request;
*request.mutable_data() = data.handle();
if (shape_with_layout != nullptr) {
@@ -114,7 +114,7 @@ Status Client::TransferToInfeed(const LiteralSlice& literal, int64 replica_id,
return Status::OK();
}
-StatusOr<std::unique_ptr<Literal>> Client::TransferFromOutfeed(
+StatusOr<Literal> Client::TransferFromOutfeed(
const Shape* shape_with_layout, int64 replica_id,
const DeviceHandle* device_handle) {
TransferFromOutfeedRequest request;
@@ -162,7 +162,7 @@ Status Client::ResetDevice() {
return Status::OK();
}
-StatusOr<std::unique_ptr<Literal>> Client::ExecuteAndTransfer(
+StatusOr<Literal> Client::ExecuteAndTransfer(
const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
const ExecutionOptions* execution_options,
ExecutionProfile* execution_profile) {
@@ -177,8 +177,8 @@ StatusOr<std::unique_ptr<Literal>> Client::ExecuteAndTransfer(
return Transfer(*data, shape_with_output_layout);
}
-StatusOr<std::unique_ptr<Literal>> Client::ComputeConstant(
- const XlaComputation& computation, const Layout* output_layout) const {
+StatusOr<Literal> Client::ComputeConstant(const XlaComputation& computation,
+ const Layout* output_layout) const {
ComputeConstantGraphRequest request;
*request.mutable_computation() = computation.proto();
if (output_layout != nullptr) {
diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h
index 7960b07868..6f4d33c469 100644
--- a/tensorflow/compiler/xla/client/client.h
+++ b/tensorflow/compiler/xla/client/client.h
@@ -96,8 +96,8 @@ class Client {
//
// If shape_with_layout is not nullptr, it points to a shape whose layout will
// be the layout of the returned literal.
- StatusOr<std::unique_ptr<Literal>> Transfer(
- const GlobalData& data, const Shape* shape_with_layout = nullptr);
+ StatusOr<Literal> Transfer(const GlobalData& data,
+ const Shape* shape_with_layout = nullptr);
// Transfer the given literal to the server. This allocates memory on the
// device and copies the literal's contents over. Returns a global data handle
@@ -122,7 +122,7 @@ class Client {
// device_handle and replica_id together specify a particular device; a device
// assigned for the given replica_id among the replicas that the given device
// handle belongs to.
- StatusOr<std::unique_ptr<Literal>> TransferFromOutfeed(
+ StatusOr<Literal> TransferFromOutfeed(
const Shape* shape_with_layout, int64 replica_id = 0,
const DeviceHandle* device_handle = nullptr);
@@ -132,7 +132,7 @@ class Client {
// Executes the computation with the given arguments and transfers the result
// to the client as a literal. Parameters are defined the same as for
// Execute() and Transfer().
- StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
+ StatusOr<Literal> ExecuteAndTransfer(
const XlaComputation& computation,
absl::Span<GlobalData* const> arguments,
const ExecutionOptions* execution_options = nullptr,
@@ -153,7 +153,7 @@ class Client {
//
// If output_layout is non-null, then the output of the computation will be
// stored using that layout.
- StatusOr<std::unique_ptr<Literal>> ComputeConstant(
+ StatusOr<Literal> ComputeConstant(
const XlaComputation& computation,
const Layout* output_layout = nullptr) const;
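
With these declarations, callers receive the literal by value; there is no longer a unique_ptr to dereference after Transfer, ExecuteAndTransfer, or ComputeConstant. The sketch below shows the calling pattern and is illustration only; RunAndPrint and the exact include set are assumptions.

    #include "absl/types/span.h"
    #include "tensorflow/compiler/xla/client/client.h"
    #include "tensorflow/compiler/xla/status_macros.h"
    #include "tensorflow/core/platform/logging.h"

    // Executes `computation` and logs the result literal.
    xla::Status RunAndPrint(xla::Client* client,
                            const xla::XlaComputation& computation,
                            absl::Span<xla::GlobalData* const> arguments) {
      // ExecuteAndTransfer now yields StatusOr<Literal>; the literal owns
      // its buffers, so it can be bound directly.
      TF_ASSIGN_OR_RETURN(xla::Literal result,
                          client->ExecuteAndTransfer(computation, arguments));
      LOG(INFO) << result.ToString();
      return xla::Status::OK();
    }
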
diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc
index 6861521acc..25cc37edc4 100644
--- a/tensorflow/compiler/xla/client/lib/testing.cc
+++ b/tensorflow/compiler/xla/client/lib/testing.cc
@@ -76,7 +76,7 @@ std::unique_ptr<GlobalData> MakeFakeDataViaDeviceOrDie(const Shape& shape,
std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
Client* client) {
if (DataSizeOfShape(shape) < (1LL << 20)) {
- StatusOr<std::unique_ptr<Literal>> literal_status = MakeFakeLiteral(shape);
+ StatusOr<Literal> literal_status = MakeFakeLiteral(shape);
if (!literal_status.ok()) {
// If we got an Unimplemented error, fall back to making the fake data via
// an on-device computation.
@@ -84,7 +84,7 @@ std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
tensorflow::error::UNIMPLEMENTED);
return MakeFakeDataViaDeviceOrDie(shape, client);
}
- return client->TransferToServer(*literal_status.ValueOrDie()).ValueOrDie();
+ return client->TransferToServer(literal_status.ValueOrDie()).ValueOrDie();
}
// If the data is large, generate it on-device.
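
The same shape applies to callers of MakeFakeLiteral: the fake literal arrives as a StatusOr<Literal> and is passed to TransferToServer by reference. A compressed sketch of the host-side path (illustration only; UploadFakeData is hypothetical, and the test_utils.h include is an assumption):

    #include <memory>
    #include "tensorflow/compiler/xla/client/client.h"
    #include "tensorflow/compiler/xla/tests/test_utils.h"
    #include "tensorflow/core/platform/logging.h"

    // Generates fake host data for `shape` and uploads it to the service.
    std::unique_ptr<xla::GlobalData> UploadFakeData(const xla::Shape& shape,
                                                    xla::Client* client) {
      xla::StatusOr<xla::Literal> literal = xla::MakeFakeLiteral(shape);
      CHECK(literal.ok()) << literal.status();
      // TransferToServer takes the literal by reference; no dereference needed.
      return client->TransferToServer(literal.ValueOrDie()).ValueOrDie();
    }
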
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index 4402ba8762..f96b6c9c26 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -195,9 +195,8 @@ Status LocalExecutable::RecordArguments(
HloSnapshot* hlo_snapshot) {
hlo_snapshot->clear_arguments();
for (const ShapedBuffer* argument : arguments) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
- LiteralFromShapedBuffer(*argument));
- *hlo_snapshot->add_arguments() = literal->ToProto();
+ TF_ASSIGN_OR_RETURN(Literal literal, LiteralFromShapedBuffer(*argument));
+ *hlo_snapshot->add_arguments() = literal.ToProto();
}
return Status::OK();
}
@@ -205,13 +204,12 @@ Status LocalExecutable::RecordArguments(
Status LocalExecutable::RecordResult(const ShapedBuffer* result,
HloSnapshot* hlo_snapshot) {
hlo_snapshot->clear_result();
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
- LiteralFromShapedBuffer(*result));
- *hlo_snapshot->mutable_result() = literal->ToProto();
+ TF_ASSIGN_OR_RETURN(Literal literal, LiteralFromShapedBuffer(*result));
+ *hlo_snapshot->mutable_result() = literal.ToProto();
return Status::OK();
}
-StatusOr<std::unique_ptr<Literal>> LocalExecutable::LiteralFromShapedBuffer(
+StatusOr<Literal> LocalExecutable::LiteralFromShapedBuffer(
const ShapedBuffer& shaped_buffer) {
TF_ASSIGN_OR_RETURN(auto stream,
backend_->BorrowStream(shaped_buffer.device_ordinal()));
@@ -277,7 +275,7 @@ StatusOr<ScopedShapedBuffer> LocalClient::LiteralToShapedBuffer(
return std::move(scoped_buffer);
}
-StatusOr<std::unique_ptr<Literal>> LocalClient::ShapedBufferToLiteral(
+StatusOr<Literal> LocalClient::ShapedBufferToLiteral(
const ShapedBuffer& shaped_buffer) {
TF_ASSIGN_OR_RETURN(auto stream, mutable_backend()->BorrowStream(
shaped_buffer.device_ordinal()));
@@ -298,13 +296,13 @@ Status LocalClient::TransferToInfeedLocal(const Literal& literal,
literal);
}
-StatusOr<std::unique_ptr<Literal>> LocalClient::TransferFromOutfeedLocal(
- const Shape& shape, int device_ordinal) {
+StatusOr<Literal> LocalClient::TransferFromOutfeedLocal(const Shape& shape,
+ int device_ordinal) {
TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
backend().stream_executor(device_ordinal));
auto literal = Literal::CreateFromShape(shape);
TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed(
- executor, shape, literal.get()));
+ executor, shape, &literal));
return std::move(literal);
}
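
Literal::CreateFromShape, used above as the outfeed destination, likewise returns a zero-initialized Literal by value; callers that need an output buffer can simply pass its address. A tiny sketch (illustration only; MakeZeroR2 is a hypothetical name):

    #include "tensorflow/compiler/xla/literal.h"
    #include "tensorflow/compiler/xla/shape_util.h"

    // Builds a 2x3 f32 literal, zero-filled, and sets one element in place.
    xla::Literal MakeZeroR2() {
      xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 3});
      xla::Literal literal = xla::Literal::CreateFromShape(shape);
      literal.Set<float>({0, 0}, 1.0f);
      return literal;  // Moves out; no unique_ptr involved.
    }
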
diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h
index 56c3a3da02..feb2f8ec9d 100644
--- a/tensorflow/compiler/xla/client/local_client.h
+++ b/tensorflow/compiler/xla/client/local_client.h
@@ -84,8 +84,7 @@ class LocalExecutable {
Status RecordResult(const ShapedBuffer* result, HloSnapshot* hlo_snapshot);
// Returns a literal containing the contents of the given ShapedBuffer.
- StatusOr<std::unique_ptr<Literal>> LiteralFromShapedBuffer(
- const ShapedBuffer& shaped_buffer);
+ StatusOr<Literal> LiteralFromShapedBuffer(const ShapedBuffer& shaped_buffer);
// The ordinal of the device which this executable was compiled for. The
// executable can run on all equivalent devices (as determined by
@@ -132,8 +131,7 @@ class LocalClient : public Client {
// Copy the data from the device contained in the given ShapedBuffer and
// return as a Literal.
- StatusOr<std::unique_ptr<Literal>> ShapedBufferToLiteral(
- const ShapedBuffer& shaped_buffer);
+ StatusOr<Literal> ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer);
// Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid
// as long as the handle is valid.
@@ -151,8 +149,8 @@ class LocalClient : public Client {
// TODO(b/69670845): Remove the 'Local' from the name when LocalClient does
// not inherit from Client and there is no possibility of confusion with
// Client::TransferFromOutfeed.
- StatusOr<std::unique_ptr<Literal>> TransferFromOutfeedLocal(
- const Shape& shape, int device_ordinal);
+ StatusOr<Literal> TransferFromOutfeedLocal(const Shape& shape,
+ int device_ordinal);
// Returns the device ordinal that corresponds to the given replica number.
//
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 887b970661..95ff6432a5 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -134,11 +134,12 @@ XlaOp XlaBuilder::ReportErrorOrReturn(
StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64 root_id) const {
TF_RETURN_IF_ERROR(first_error_);
- TF_RET_CHECK((root_id >= 0) && (root_id < instructions_.size()));
+ TF_ASSIGN_OR_RETURN(const HloInstructionProto* root_proto,
+ LookUpInstructionByHandle(root_id));
ProgramShape program_shape;
- *program_shape.mutable_result() = instructions_[root_id].shape();
+ *program_shape.mutable_result() = root_proto->shape();
// Check that the parameter numbers are continuous from 0, and add parameter
// shapes and names to the program shape.
@@ -181,9 +182,8 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle,
return;
}
- CHECK(op_handle < instructions_.size() && op_handle >= 0);
-
- const HloInstructionProto& instr = instructions_[op_handle];
+ const HloInstructionProto& instr =
+ *(LookUpInstructionByHandle(op_handle).ValueOrDie());
const HloOpcode opcode = StringToHloOpcode(instr.opcode()).ValueOrDie();
switch (opcode) {
default:
@@ -283,6 +283,7 @@ StatusOr<XlaComputation> XlaBuilder::Build(int64 root_id) {
// Clear data held by this builder.
this->instructions_.clear();
+ this->handle_to_index_.clear();
this->embedded_.clear();
this->parameter_numbers_.clear();
@@ -738,7 +739,7 @@ void XlaBuilder::Trace(const string& tag, const XlaOp& operand) {
ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
*instr.mutable_shape() = ShapeUtil::MakeNil();
- *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag)->ToProto();
+ *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag).ToProto();
return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand});
});
}
@@ -2285,7 +2286,7 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
*program_shape->mutable_result() = root->shape();
// We use std::set to keep the instruction ids in ascending order (which is
- // also a valid denpendency order). The related ops will be added to the
+ // also a valid dependency order). The related ops will be added to the
// subgraph in the same order.
std::set<int64> related_ops;
tensorflow::gtl::FlatSet<int64> related_calls; // Related computations.
@@ -2293,14 +2294,16 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
worklist.push(root->id());
related_ops.insert(root->id());
while (!worklist.empty()) {
- int64 node = worklist.front();
+ int64 handle = worklist.front();
worklist.pop();
- for (int64 id : instructions_[node].operand_ids()) {
+ TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
+ LookUpInstructionByHandle(handle));
+ for (int64 id : instr_proto->operand_ids()) {
if (related_ops.insert(id).second) {
worklist.push(id);
}
}
- for (int64 called_id : instructions_[node].called_computation_ids()) {
+ for (int64 called_id : instr_proto->called_computation_ids()) {
related_calls.insert(called_id);
}
}
@@ -2308,7 +2311,9 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
// Add related ops to the computation.
for (int64 id : related_ops) {
auto* instr = entry.add_instructions();
- *instr = instructions_[id];
+ TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_src,
+ LookUpInstructionByHandle(id));
+ *instr = *instr_src;
// Ensures that the instruction names are unique among the graph.
const string& new_name =
StrCat(instr->name(), ".", entry.id(), ".", instr->id());
@@ -2415,11 +2420,11 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
absl::Span<const XlaOp> operands) {
TF_RETURN_IF_ERROR(first_error_);
- const int64 handle = instructions_.size();
+ const int64 handle = GetUniqueId();
instr.set_id(handle);
instr.set_opcode(HloOpcodeString(opcode));
if (instr.name().empty()) {
- instr.set_name(StrCat(instr.opcode()));
+ instr.set_name(instr.opcode());
}
for (const auto& operand : operands) {
if (operand.builder_ == nullptr) {
@@ -2437,7 +2442,8 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
*instr.mutable_sharding() = *sharding_;
}
- instructions_.push_back(instr);
+ handle_to_index_[handle] = instructions_.size();
+ instructions_.push_back(std::move(instr));
XlaOp op(handle, this);
return op;
@@ -2467,10 +2473,16 @@ StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
op.handle(), op.builder_->name(), this->name());
}
- if (op.handle() >= instructions_.size() || op.handle() < 0) {
- return InvalidArgument("no XlaOp value %d", op.handle());
+ return LookUpInstructionByHandle(op.handle());
+}
+
+StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstructionByHandle(
+ int64 handle) const {
+ auto it = handle_to_index_.find(handle);
+ if (it == handle_to_index_.end()) {
+ return InvalidArgument("No XlaOp with handle %d", handle);
}
- return &instructions_[op.handle()];
+ return &instructions_[it->second];
}
// Enqueues a "retrieve parameter value" instruction for a parameter that was
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 58e8f4e7fa..d0c59fa6f2 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stacktrace.h"
@@ -955,6 +956,8 @@ class XlaBuilder {
HloInstructionProto* instr);
StatusOr<const HloInstructionProto*> LookUpInstruction(const XlaOp& op) const;
+ StatusOr<const HloInstructionProto*> LookUpInstructionByHandle(
+ int64 handle) const;
// Internal helper method that does the building for an arbitrary unary op.
XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand);
@@ -1024,6 +1027,10 @@ class XlaBuilder {
// The instructions of this computation.
std::vector<HloInstructionProto> instructions_;
+ // A map from XlaOp::Handle to the index in the instructions_ vector where the
+ // instruction is held.
+ tensorflow::gtl::FlatMap<int64, int64> handle_to_index_;
+
// The embedded computations used by this computation. Each computation was
// the entry computation of some XlaComputation, the key is the unique id of
// that XlaComputation.
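
Because instruction handles now come from GetUniqueId() rather than the vector index, the builder keeps this side map from handle to position in instructions_. A standalone sketch of that lookup structure, outside the class and with hypothetical names, assuming only FlatMap and the XLA status utilities:

    #include <string>
    #include <utility>
    #include <vector>
    #include "tensorflow/compiler/xla/statusor.h"
    #include "tensorflow/compiler/xla/util.h"
    #include "tensorflow/core/lib/gtl/flatmap.h"
    #include "tensorflow/core/platform/types.h"

    // Dense storage plus a handle->index map, mirroring the builder's scheme.
    struct InstructionIndex {
      std::vector<std::string> instructions;
      tensorflow::gtl::FlatMap<tensorflow::int64, tensorflow::int64> handle_to_index;

      void Add(tensorflow::int64 handle, std::string instr) {
        handle_to_index[handle] = instructions.size();
        instructions.push_back(std::move(instr));
      }

      xla::StatusOr<const std::string*> Lookup(tensorflow::int64 handle) const {
        auto it = handle_to_index.find(handle);
        if (it == handle_to_index.end()) {
          return xla::InvalidArgument("No instruction with handle %d", handle);
        }
        return &instructions[it->second];
      }
    };
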
@@ -2112,12 +2119,12 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
template <typename NativeT>
XlaOp XlaBuilder::ConstantR0(NativeT value) {
- return ConstantLiteral(*LiteralUtil::CreateR0<NativeT>(value));
+ return ConstantLiteral(LiteralUtil::CreateR0<NativeT>(value));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR1(absl::Span<const NativeT> values) {
- return ConstantLiteral(*LiteralUtil::CreateR1<NativeT>(values));
+ return ConstantLiteral(LiteralUtil::CreateR1<NativeT>(values));
}
template <typename NativeT>
@@ -2129,44 +2136,44 @@ XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) {
}
inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) {
- return ConstantLiteral(*LiteralUtil::CreateR1(values));
+ return ConstantLiteral(LiteralUtil::CreateR1(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR2(
std::initializer_list<std::initializer_list<NativeT>> values) {
- return ConstantLiteral(*LiteralUtil::CreateR2<NativeT>(values));
+ return ConstantLiteral(LiteralUtil::CreateR2<NativeT>(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array<NativeT>& values,
const Layout& layout) {
return ConstantLiteral(
- *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
+ LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantFromArray(const Array<NativeT>& values) {
- return ConstantLiteral(*LiteralUtil::CreateFromArray<NativeT>(values));
+ return ConstantLiteral(LiteralUtil::CreateFromArray<NativeT>(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout(
const Array2D<NativeT>& values, const Layout& layout) {
return ConstantLiteral(
- *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
+ LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D<NativeT>& values) {
- return ConstantLiteral(*LiteralUtil::CreateR2FromArray2D<NativeT>(values));
+ return ConstantLiteral(LiteralUtil::CreateR2FromArray2D<NativeT>(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout(
const Array3D<NativeT>& values, const Layout& layout) {
return ConstantLiteral(
- *LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
+ LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
@@ -2189,12 +2196,12 @@ XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D<NativeT>& values) {
template <typename NativeT>
XlaOp ConstantR0(XlaBuilder* builder, NativeT value) {
- return ConstantLiteral(builder, *LiteralUtil::CreateR0<NativeT>(value));
+ return ConstantLiteral(builder, LiteralUtil::CreateR0<NativeT>(value));
}
template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values) {
- return ConstantLiteral(builder, *LiteralUtil::CreateR1<NativeT>(values));
+ return ConstantLiteral(builder, LiteralUtil::CreateR1<NativeT>(values));
}
template <typename NativeT>
@@ -2207,13 +2214,13 @@ XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) {
inline XlaOp ConstantR1(XlaBuilder* builder,
const tensorflow::core::Bitmap& values) {
- return ConstantLiteral(builder, *LiteralUtil::CreateR1(values));
+ return ConstantLiteral(builder, LiteralUtil::CreateR1(values));
}
template <typename NativeT>
XlaOp ConstantR2(XlaBuilder* builder,
std::initializer_list<std::initializer_list<NativeT>> values) {
- return ConstantLiteral(builder, *LiteralUtil::CreateR2<NativeT>(values));
+ return ConstantLiteral(builder, LiteralUtil::CreateR2<NativeT>(values));
}
template <typename NativeT>
@@ -2221,14 +2228,13 @@ XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
const Array<NativeT>& values,
const Layout& layout) {
return ConstantLiteral(
- builder,
- *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
+ builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values) {
return ConstantLiteral(builder,
- *LiteralUtil::CreateFromArray<NativeT>(values));
+ LiteralUtil::CreateFromArray<NativeT>(values));
}
template <typename NativeT>
@@ -2236,15 +2242,14 @@ XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
const Array2D<NativeT>& values,
const Layout& layout) {
return ConstantLiteral(
- builder,
- *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
+ builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
const Array2D<NativeT>& values) {
return ConstantLiteral(builder,
- *LiteralUtil::CreateR2FromArray2D<NativeT>(values));
+ LiteralUtil::CreateR2FromArray2D<NativeT>(values));
}
template <typename NativeT>
@@ -2253,7 +2258,7 @@ XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
const Layout& layout) {
return ConstantLiteral(
builder,
- *LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
+ LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc
index 3f7635bd40..5035f41988 100644
--- a/tensorflow/compiler/xla/literal.cc
+++ b/tensorflow/compiler/xla/literal.cc
@@ -174,9 +174,9 @@ Literal& Literal::operator=(Literal&& other) {
return *this;
}
-std::unique_ptr<Literal> LiteralBase::CreateFromShape(const Shape& shape) {
- auto literal = absl::make_unique<Literal>(shape);
- literal->root_piece_->ForEachMutableSubpiece(
+Literal LiteralBase::CreateFromShape(const Shape& shape) {
+ Literal literal(shape);
+ literal.root_piece_->ForEachMutableSubpiece(
[&](const ShapeIndex& index, Piece* piece) {
if (ShapeUtil::IsArray(piece->subshape())) {
memset(piece->untyped_data(), 0, piece->size_bytes());
@@ -278,8 +278,8 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal,
return Status::OK();
}
-/* static */ StatusOr<std::unique_ptr<Literal>>
-MutableLiteralBase::CreateFromProto(const LiteralProto& proto) {
+/* static */ StatusOr<Literal> MutableLiteralBase::CreateFromProto(
+ const LiteralProto& proto) {
if (!proto.has_shape()) {
return InvalidArgument("LiteralProto has no shape");
}
@@ -287,9 +287,9 @@ MutableLiteralBase::CreateFromProto(const LiteralProto& proto) {
return InvalidArgument("LiteralProto has no layout");
}
- auto literal = absl::make_unique<Literal>(proto.shape());
+ Literal literal(proto.shape());
- TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus(
+ TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus(
[&](const ShapeIndex& index, Piece* piece) {
const LiteralProto* proto_element = &proto;
for (int64 i : index) {
@@ -556,38 +556,37 @@ void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) {
}
}
-std::unique_ptr<Literal> LiteralBase::Relayout(
- const Layout& new_layout, const ShapeIndex& shape_index) const {
+Literal LiteralBase::Relayout(const Layout& new_layout,
+ const ShapeIndex& shape_index) const {
// Create new shape with 'new_layout' set at the given shape index.
Shape new_shape = shape();
Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index);
TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape));
*subshape->mutable_layout() = new_layout;
- auto result = absl::make_unique<Literal>(new_shape);
- TF_CHECK_OK(result->CopyFrom(*this));
+ Literal result(new_shape);
+ TF_CHECK_OK(result.CopyFrom(*this));
return result;
}
-std::unique_ptr<Literal> LiteralBase::Relayout(
- const Shape& shape_with_layout) const {
+Literal LiteralBase::Relayout(const Shape& shape_with_layout) const {
CHECK(ShapeUtil::Compatible(shape_with_layout, shape()))
<< "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout)
<< " not compatible with literal shape "
<< ShapeUtil::HumanString(shape());
- std::unique_ptr<Literal> result = CreateFromShape(shape_with_layout);
+ Literal result = CreateFromShape(shape_with_layout);
ShapeUtil::ForEachSubshape(
- result->shape(),
+ result.shape(),
[this, &result](const Shape& subshape, const ShapeIndex& index) {
if (ShapeUtil::IsArray(subshape)) {
- TF_CHECK_OK(result->CopyFrom(*this,
- /*dest_shape_index=*/index,
- /*src_shape_index=*/index));
+ TF_CHECK_OK(result.CopyFrom(*this,
+ /*dest_shape_index=*/index,
+ /*src_shape_index=*/index));
}
});
return result;
}
-StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
+StatusOr<Literal> LiteralBase::Broadcast(
const Shape& result_shape, absl::Span<const int64> dimensions) const {
if (!ShapeUtil::IsArray(shape())) {
return InvalidArgument("Broadcast only supports arrays.");
@@ -598,14 +597,14 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
result_shape.dimensions(dimensions[i]));
}
- std::unique_ptr<Literal> result = absl::make_unique<Literal>(result_shape);
+ Literal result(result_shape);
// scratch_source_index is temporary storage space for the computed index into
// the input literal. We put it here to avoid allocating an std::vector in
// every iteration of ShapeUtil::ForEachIndex.
std::vector<int64> scratch_source_index(shape().dimensions_size());
- char* dest_data = static_cast<char*>(result->untyped_data());
+ char* dest_data = static_cast<char*>(result.untyped_data());
const char* source_data = static_cast<const char*>(untyped_data());
const int64 primitive_size =
ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
@@ -627,37 +626,36 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
return std::move(result);
}
-StatusOr<std::unique_ptr<Literal>> LiteralBase::Reshape(
+StatusOr<Literal> LiteralBase::Reshape(
absl::Span<const int64> dimensions) const {
if (!ShapeUtil::IsArray(shape())) {
return InvalidArgument("Reshape does not support tuples.");
}
- std::unique_ptr<Literal> output;
+ Literal output;
if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) {
output =
Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape())));
} else {
- output = CloneToUnique();
+ output = Clone();
}
// Because the layout is monotonic, we can simply reuse the same sequence of
// values without changing their order.
- *output->mutable_shape_do_not_use() =
+ *output.mutable_shape_do_not_use() =
ShapeUtil::MakeShape(shape().element_type(), dimensions);
int64 elements_before = ShapeUtil::ElementsIn(shape());
- int64 elements_after = ShapeUtil::ElementsIn(output->shape());
+ int64 elements_after = ShapeUtil::ElementsIn(output.shape());
if (elements_before != elements_after) {
return InvalidArgument(
"Shapes before and after Literal::Reshape have different numbers "
"of elements: %s vs %s.",
ShapeUtil::HumanString(shape()),
- ShapeUtil::HumanString(output->shape()));
+ ShapeUtil::HumanString(output.shape()));
}
return std::move(output);
}
-std::unique_ptr<Literal> LiteralBase::Transpose(
- absl::Span<const int64> permutation) const {
+Literal LiteralBase::Transpose(absl::Span<const int64> permutation) const {
CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose";
CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape())))
<< "Given permutation is not a permutation of dimension numbers";
@@ -687,32 +685,31 @@ std::unique_ptr<Literal> LiteralBase::Transpose(
for (auto index : LayoutUtil::MinorToMajor(shape())) {
layout->add_minor_to_major(inverse_permutation[index]);
}
- auto new_literal = absl::make_unique<Literal>(permuted_shape);
- DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()),
+ Literal new_literal(permuted_shape);
+ DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal.shape()),
ShapeUtil::ByteSizeOf(shape()));
- std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes());
+ std::memcpy(new_literal.untyped_data(), untyped_data(), size_bytes());
return new_literal;
}
template <typename NativeT>
-std::unique_ptr<Literal> LiteralBase::SliceInternal(
+Literal LiteralBase::SliceInternal(
const Shape& result_shape, absl::Span<const int64> start_indices) const {
- auto result_literal = absl::make_unique<Literal>(result_shape);
+ Literal result_literal(result_shape);
DimensionVector new_indices(ShapeUtil::Rank(result_shape));
- result_literal->EachCell<NativeT>(
+ result_literal.EachCell<NativeT>(
[&](absl::Span<const int64> indices, NativeT /*value*/) {
for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
new_indices[i] = indices[i] + start_indices[i];
}
NativeT value = Get<NativeT>(new_indices);
- result_literal->Set<NativeT>(indices, value);
+ result_literal.Set<NativeT>(indices, value);
});
return result_literal;
}
-std::unique_ptr<Literal> LiteralBase::Slice(
- absl::Span<const int64> start_indices,
- absl::Span<const int64> limit_indices) const {
+Literal LiteralBase::Slice(absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices) const {
CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice";
DimensionVector result_dimensions;
@@ -750,12 +747,6 @@ Literal LiteralBase::Clone() const {
return result;
}
-std::unique_ptr<Literal> LiteralBase::CloneToUnique() const {
- auto result = absl::make_unique<Literal>(shape());
- TF_CHECK_OK(result->CopyFrom(*this));
- return result;
-}
-
string LiteralBase::GetAsString(absl::Span<const int64> multi_index,
const ShapeIndex& shape_index) const {
const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
@@ -1191,14 +1182,14 @@ void LiteralBase::EachCellAsString(
namespace {
template <typename NativeSrcT, typename NativeDestT, typename ConverterType>
-std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter(
- const LiteralBase& src_literal, const ConverterType& converter) {
+Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal,
+ const ConverterType& converter) {
CHECK(ShapeUtil::IsArray(src_literal.shape()));
- auto result_literal = absl::make_unique<Literal>(ShapeUtil::ChangeElementType(
+ Literal result_literal(ShapeUtil::ChangeElementType(
src_literal.shape(),
primitive_util::NativeToPrimitiveType<NativeDestT>()));
auto src_data = src_literal.data<NativeSrcT>();
- auto dest_data = result_literal->template data<NativeDestT>();
+ auto dest_data = result_literal.template data<NativeDestT>();
int64 num_elements = src_literal.element_count();
for (int64 i = 0; i < num_elements; ++i) {
@@ -1208,8 +1199,7 @@ std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter(
}
template <typename NativeSrcT, typename NativeDestT>
-std::unique_ptr<Literal> ConvertBetweenNativeTypes(
- const LiteralBase& src_literal) {
+Literal ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
auto converter = [](NativeSrcT src) { return static_cast<NativeDestT>(src); };
return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
src_literal, converter);
@@ -1217,7 +1207,7 @@ std::unique_ptr<Literal> ConvertBetweenNativeTypes(
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)),
- std::unique_ptr<Literal>>::type
+ Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
auto converter = [](NativeSrcT src) {
return tensorflow::bit_cast<NativeDestT>(src);
@@ -1232,20 +1222,20 @@ BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
// identical sizes higher up.
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)),
- std::unique_ptr<Literal>>::type
+ Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
LOG(FATAL) << "Invalid bitcast between types of different sizes.";
}
template <PrimitiveType primitive_src_type>
-std::unique_ptr<Literal> ConvertToC64(const LiteralBase& src_literal) {
+Literal ConvertToC64(const LiteralBase& src_literal) {
CHECK(ShapeUtil::IsArray(src_literal.shape()));
- auto result_literal = absl::make_unique<Literal>(
+ Literal result_literal(
ShapeUtil::ChangeElementType(src_literal.shape(), C64));
using NativeSrcT =
typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type;
absl::Span<const NativeSrcT> src_data = src_literal.data<NativeSrcT>();
- absl::Span<complex64> dest_data = result_literal->data<complex64>();
+ absl::Span<complex64> dest_data = result_literal.data<complex64>();
int64 num_elements = src_literal.element_count();
for (int64 i = 0; i < num_elements; ++i) {
dest_data[i] = complex64(static_cast<float>(src_data[i]), 0);
@@ -1254,8 +1244,7 @@ std::unique_ptr<Literal> ConvertToC64(const LiteralBase& src_literal) {
}
template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
-std::unique_ptr<Literal> ConvertIfTypesMatch(const LiteralBase& src_literal,
- bool bitcast) {
+Literal ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) {
CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
if (bitcast) {
return BitcastBetweenNativeTypes<
@@ -1273,9 +1262,9 @@ std::unique_ptr<Literal> ConvertIfTypesMatch(const LiteralBase& src_literal,
}
template <PrimitiveType primitive_src_type>
-StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
- const LiteralBase& src_literal, PrimitiveType primitive_dest_type,
- bool bitcast) {
+StatusOr<Literal> ConvertIfDestTypeMatches(const LiteralBase& src_literal,
+ PrimitiveType primitive_dest_type,
+ bool bitcast) {
switch (primitive_dest_type) {
#define CONVERT_IF_TYPES_MATCH(type) \
case (type): \
@@ -1307,12 +1296,12 @@ StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
PrimitiveType_Name(primitive_dest_type));
}
-StatusOr<std::unique_ptr<Literal>> ConvertSwitch(
- const LiteralBase& literal, PrimitiveType primitive_dest_type,
- bool bitcast) {
+StatusOr<Literal> ConvertSwitch(const LiteralBase& literal,
+ PrimitiveType primitive_dest_type,
+ bool bitcast) {
TF_RET_CHECK(ShapeUtil::IsArray(literal.shape()));
if (literal.shape().element_type() == primitive_dest_type) {
- return literal.CloneToUnique();
+ return literal.Clone();
}
switch (literal.shape().element_type()) {
#define CONVERT_IF_DEST_TYPE_MATCHES(type) \
@@ -1342,12 +1331,12 @@ StatusOr<std::unique_ptr<Literal>> ConvertSwitch(
} // namespace
-StatusOr<std::unique_ptr<Literal>> LiteralBase::Convert(
+StatusOr<Literal> LiteralBase::Convert(
PrimitiveType primitive_dest_type) const {
return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false);
}
-StatusOr<std::unique_ptr<Literal>> LiteralBase::BitcastConvert(
+StatusOr<Literal> LiteralBase::BitcastConvert(
PrimitiveType primitive_dest_type) const {
if (primitive_util::BitWidth(shape().element_type()) !=
primitive_util::BitWidth(primitive_dest_type)) {
@@ -1362,17 +1351,8 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::BitcastConvert(
return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true);
}
-StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape(
- const Shape& dest_shape, bool round_f32_to_bf16) const {
+StatusOr<Literal> LiteralBase::ConvertToShape(const Shape& dest_shape) const {
if (!ShapeUtil::IsTuple(dest_shape)) {
- if (round_f32_to_bf16 && shape().element_type() == F32 &&
- dest_shape.element_type() == BF16) {
- auto converter = [](float src) {
- return tensorflow::bfloat16::round_to_bfloat16(src);
- };
- return ConvertBetweenNativeTypesWithConverter<float, bfloat16>(*this,
- converter);
- }
return Convert(dest_shape.element_type());
}
std::vector<Literal> elements;
@@ -1381,11 +1361,9 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape(
TF_ASSIGN_OR_RETURN(
auto new_element,
element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i})));
- elements.push_back(std::move(*new_element));
+ elements.push_back(std::move(new_element));
}
- auto converted = absl::make_unique<Literal>();
- *converted = MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements));
- return std::move(converted);
+ return MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements));
}
/* static */ Literal MutableLiteralBase::MoveIntoTuple(
@@ -1782,6 +1760,10 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const {
case PRED:
CopyToRepeatedField(proto->mutable_preds(), data<bool>());
break;
+ case S8:
+ proto->set_s8s(static_cast<const signed char*>(data<int8>().data()),
+ element_count());
+ break;
case U8:
proto->set_u8s(static_cast<const unsigned char*>(data<uint8>().data()),
element_count());
@@ -1872,6 +1854,11 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
case PRED:
TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<bool>(), proto.preds()));
break;
+ case S8: {
+ auto s8_data = data<int8>();
+ TF_RET_CHECK(proto.s8s().size() == s8_data.size());
+ std::copy(proto.s8s().begin(), proto.s8s().end(), s8_data.begin());
+ } break;
case U8: {
auto u8_data = data<uint8>();
TF_RET_CHECK(proto.u8s().size() == u8_data.size());
diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h
index b928cb6374..1e0a2ad0dd 100644
--- a/tensorflow/compiler/xla/literal.h
+++ b/tensorflow/compiler/xla/literal.h
@@ -217,31 +217,20 @@ class LiteralBase {
// Converts this literal to the given shape. Returns an error if the
// conversion is not possible.
- //
- // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding
- // instead of truncation; otherwise, truncation is used.
- //
- // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes
- // the default behavior.
- StatusOr<std::unique_ptr<Literal>> ConvertToShape(
- const Shape& dest_shape, bool round_f32_to_bf16 = false) const;
+ StatusOr<Literal> ConvertToShape(const Shape& dest_shape) const;
// Converts this literal to another primitive type using a bitcast
// conversion. The to and from primitive types must have the same bit
// width. Returns an error if the conversion is not possible. This literal
// must be array-shaped.
- StatusOr<std::unique_ptr<Literal>> BitcastConvert(
- PrimitiveType primitive_dest_type) const;
+ StatusOr<Literal> BitcastConvert(PrimitiveType primitive_dest_type) const;
// Converts this literal to another primitive type. Returns an error if the
// conversion is not possible. This literal must be array-shaped.
- StatusOr<std::unique_ptr<Literal>> Convert(
- PrimitiveType primitive_dest_type) const;
+ StatusOr<Literal> Convert(PrimitiveType primitive_dest_type) const;
- // Clones the underlying buffers into a new Literal, or new
- // std::unique_ptr<Literal>.
+ // Clones the underlying buffers into a new Literal.
Literal Clone() const;
- std::unique_ptr<Literal> CloneToUnique() const;
// TODO(b/67651157): The methods below which perform computation on Literals
// (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with
@@ -259,24 +248,23 @@ class LiteralBase {
// Note: this is useful when the client wants to ensure that a value placed in
// the XLA allocation tracker has a particular layout; for efficiency
// purposes or avoiding unimplemented operation/layout combinations.
- std::unique_ptr<Literal> Relayout(const Layout& new_layout,
- const ShapeIndex& shape_index = {}) const;
+ Literal Relayout(const Layout& new_layout,
+ const ShapeIndex& shape_index = {}) const;
// An overload of Relayout which changes the layout of the entire shape rather
// than being limited to a single array within the shape.
- std::unique_ptr<Literal> Relayout(const Shape& shape_with_layout) const;
+ Literal Relayout(const Shape& shape_with_layout) const;
// Creates a new literal by reshaping this literal to have the given
// dimensions. The total number of elements must not change; the
// implementation currently only supports monotonic dim0-major layouts.
// This literal must be an array.
- StatusOr<std::unique_ptr<Literal>> Reshape(
- absl::Span<const int64> dimensions) const;
+ StatusOr<Literal> Reshape(absl::Span<const int64> dimensions) const;
// Creates a new literal by broadcasting this literal with `dimensions` to
// yield a literal of shape `result_shape`.
- StatusOr<std::unique_ptr<Literal>> Broadcast(
- const Shape& result_shape, absl::Span<const int64> dimensions) const;
+ StatusOr<Literal> Broadcast(const Shape& result_shape,
+ absl::Span<const int64> dimensions) const;
// Creates a new literal by reordering the dimensions of this literal.
// The given `permutation` must be a permutation of the dimension numbers
@@ -285,7 +273,7 @@ class LiteralBase {
// For example, a transpose call on a literal of shape [3 x 8 x 4] and
// `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
// This literal must be an array.
- std::unique_ptr<Literal> Transpose(absl::Span<const int64> permutation) const;
+ Literal Transpose(absl::Span<const int64> permutation) const;
// Creates a sub-array from this literal by extracting the indices
// [start_index, limit_index) of each dimension. The result literal has the
@@ -293,15 +281,15 @@ class LiteralBase {
// start_indices and limit_indices must be the rank of the literal, and the
// indices follow the order of the dimensions.
// This literal must be an array.
- std::unique_ptr<Literal> Slice(absl::Span<const int64> start_indices,
- absl::Span<const int64> limit_indices) const;
+ Literal Slice(absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices) const;
// Creates a literal with a prepended dimension with bound "times"; e.g. a
// f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
// literal replicated four times.
// This literal must be an array.
template <typename NativeT>
- std::unique_ptr<Literal> Replicate(int64 times) const;
+ Literal Replicate(int64 times) const;
// Creates a new Literal object with the shape specified as parameter.
// The content of the literal values is the default value of the primitive
@@ -312,7 +300,7 @@ class LiteralBase {
// initialization, then reinitialization. Consider if a call to
// absl::make_unique<Literal>(shape), followed by the call to
// MutableLiteralBase::Populate can be used instead.
- static std::unique_ptr<Literal> CreateFromShape(const Shape& shape);
+ static Literal CreateFromShape(const Shape& shape);
protected:
// A data structure representing a subshape at a particular ShapeIndex within
@@ -539,8 +527,8 @@ class LiteralBase {
private:
template <typename NativeT>
- std::unique_ptr<Literal> SliceInternal(
- const Shape& result_shape, absl::Span<const int64> start_indices) const;
+ Literal SliceInternal(const Shape& result_shape,
+ absl::Span<const int64> start_indices) const;
};
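
Taken together, the transformation methods above (Relayout, Reshape, Broadcast, Transpose, Slice, Replicate) now produce Literal or StatusOr<Literal> values. A short sketch of chaining two of them (illustration only; ReshapeAndTranspose is a hypothetical name):

    #include "tensorflow/compiler/xla/literal_util.h"
    #include "tensorflow/compiler/xla/status_macros.h"

    // Reshapes a 2x3 f32 literal to 3x2, then transposes it back to 2x3.
    xla::StatusOr<xla::Literal> ReshapeAndTranspose() {
      xla::Literal input =
          xla::LiteralUtil::CreateR2<float>({{1, 2, 3}, {4, 5, 6}});
      TF_ASSIGN_OR_RETURN(xla::Literal reshaped, input.Reshape({3, 2}));
      // Transpose also returns by value; `input` and `reshaped` are unchanged.
      return reshaped.Transpose({1, 0});
    }
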
// Abstract base class representing a mutable literal in XLA.
@@ -687,8 +675,7 @@ class MutableLiteralBase : public LiteralBase {
static Literal MoveIntoTuple(absl::Span<Literal> elements);
// Serialize from a proto.
- static StatusOr<std::unique_ptr<Literal>> CreateFromProto(
- const LiteralProto& proto);
+ static StatusOr<Literal> CreateFromProto(const LiteralProto& proto);
protected:
// Returns the piece at the given ShapeIndex.
@@ -1137,15 +1124,14 @@ void MutableLiteralBase::PopulateWithValue(NativeT value) {
}
template <typename NativeT>
-std::unique_ptr<Literal> LiteralBase::Replicate(int64 times) const {
+Literal LiteralBase::Replicate(int64 times) const {
DimensionVector bounds = {times};
bounds.reserve(shape().dimensions_size() + 1);
for (int64 bound : shape().dimensions()) {
bounds.push_back(bound);
}
- auto literal = absl::make_unique<Literal>(
- ShapeUtil::MakeShape(shape().element_type(), bounds));
- int64 elements = ShapeUtil::ElementsIn(literal->shape());
+ Literal literal(ShapeUtil::MakeShape(shape().element_type(), bounds));
+ int64 elements = ShapeUtil::ElementsIn(literal.shape());
if (elements == 0) {
return literal;
}
@@ -1157,7 +1143,7 @@ std::unique_ptr<Literal> LiteralBase::Replicate(int64 times) const {
bool done = false;
while (!done) {
const auto element = Get<NativeT>(input_indices);
- literal->Set<NativeT>(output_indices, element);
+ literal.Set<NativeT>(output_indices, element);
done = true;
for (int n = 0; n < output_indices.size(); ++n) {
diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc
index 1a64594db8..7ad287c897 100644
--- a/tensorflow/compiler/xla/literal_test.cc
+++ b/tensorflow/compiler/xla/literal_test.cc
@@ -92,48 +92,48 @@ class LiteralUtilTest : public ::testing::Test {
Layout layout_r3_dim0minor_;
Layout layout_r4_dim0major_;
Layout layout_r4_dim0minor_;
- std::unique_ptr<Literal> literal_r4_2x2x3x3_dim0major_;
- std::unique_ptr<Literal> literal_r4_2x2x3x3_dim0minor_;
+ Literal literal_r4_2x2x3x3_dim0major_;
+ Literal literal_r4_2x2x3x3_dim0minor_;
};
TEST_F(LiteralUtilTest, LiteralScalarToString) {
auto true_lit = LiteralUtil::CreateR0<bool>(true);
- EXPECT_EQ("true", true_lit->ToString());
+ EXPECT_EQ("true", true_lit.ToString());
auto false_lit = LiteralUtil::CreateR0<bool>(false);
- EXPECT_EQ("false", false_lit->ToString());
+ EXPECT_EQ("false", false_lit.ToString());
auto u32_lit = LiteralUtil::CreateR0<uint32>(42);
- EXPECT_EQ("42", u32_lit->ToString());
+ EXPECT_EQ("42", u32_lit.ToString());
auto s32_lit = LiteralUtil::CreateR0<int32>(-999);
- EXPECT_EQ("-999", s32_lit->ToString());
+ EXPECT_EQ("-999", s32_lit.ToString());
auto f32_lit = LiteralUtil::CreateR0<float>(3.14f);
- EXPECT_EQ("3.14", f32_lit->ToString());
+ EXPECT_EQ("3.14", f32_lit.ToString());
auto f16_lit = LiteralUtil::CreateR0<half>(static_cast<half>(0.5f));
- EXPECT_EQ("0.5", f16_lit->ToString());
+ EXPECT_EQ("0.5", f16_lit.ToString());
auto c64_lit = LiteralUtil::CreateR0<complex64>({3.14f, 2.78f});
- EXPECT_EQ("(3.14, 2.78)", c64_lit->ToString());
+ EXPECT_EQ("(3.14, 2.78)", c64_lit.ToString());
auto bf16_lit = LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.5f));
- EXPECT_EQ("0.5", bf16_lit->ToString());
+ EXPECT_EQ("0.5", bf16_lit.ToString());
// 3.14 will be rounded to 3.14062 in bfloat16 format.
auto bf16_lit_truncated =
LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(3.14f));
- ASSERT_EQ("3.14062", bf16_lit_truncated->ToString());
+ ASSERT_EQ("3.14062", bf16_lit_truncated.ToString());
auto bf16_lit_truncated2 =
LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(9.001f));
- EXPECT_EQ("9", bf16_lit_truncated2->ToString());
+ EXPECT_EQ("9", bf16_lit_truncated2.ToString());
}
TEST_F(LiteralUtilTest, LiteralVectorToString) {
auto pred_vec = LiteralUtil::CreateR1<bool>({true, false, true});
- EXPECT_EQ("{101}", pred_vec->ToString());
+ EXPECT_EQ("{101}", pred_vec.ToString());
}
TEST_F(LiteralUtilTest, R2ToString) {
@@ -143,7 +143,7 @@ TEST_F(LiteralUtilTest, R2ToString) {
{ 3, 4 },
{ 5, 6 }
})";
- EXPECT_EQ(expected, literal->ToString());
+ EXPECT_EQ(expected, literal.ToString());
}
TEST_F(LiteralUtilTest, R3ToString) {
@@ -157,13 +157,13 @@ TEST_F(LiteralUtilTest, R3ToString) {
{ { 5 },
{ 6 } }
})";
- EXPECT_EQ(expected, literal->ToString());
+ EXPECT_EQ(expected, literal.ToString());
}
TEST_F(LiteralUtilTest, TupleToString) {
auto scalar = LiteralUtil::CreateR0<float>(1.0);
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
+ auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
const string expected = R"((f32[], f32[2,2]) (
1,
f32[2,2] {
@@ -171,7 +171,7 @@ f32[2,2] {
{ 3, 4 }
}
))";
- EXPECT_EQ(expected, tuple->ToString());
+ EXPECT_EQ(expected, tuple.ToString());
}
TEST_F(LiteralUtilTest, CreateR3FromArray3d) {
@@ -187,8 +187,8 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) {
// clang-format on
auto literal = LiteralUtil::CreateR3FromArray3D(array_3d);
- EXPECT_THAT(literal->shape().dimensions(), ElementsAre(2, 3, 2));
- string result = literal->ToString();
+ EXPECT_THAT(literal.shape().dimensions(), ElementsAre(2, 3, 2));
+ string result = literal.ToString();
const string expected = R"(f32[2,3,2] {
{ { 1, 2 },
{ 3, 4 },
@@ -220,10 +220,10 @@ TEST_F(LiteralUtilTest, CreateSparse) {
};
std::vector<int64> expected_values = {8, 9, 7, 10};
- EXPECT_EQ(literal->sparse_indices()->data(),
+ EXPECT_EQ(literal.sparse_indices()->data(),
absl::Span<const int64>(expected_indices.data(),
expected_indices.num_elements()));
- EXPECT_EQ(literal->data<int64>(), absl::Span<const int64>(expected_values));
+ EXPECT_EQ(literal.data<int64>(), absl::Span<const int64>(expected_values));
}
TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
@@ -234,8 +234,8 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
{2001, 2002},
}, /*projection_p=*/1, /*projection_z=*/2);
// clang-format on
- EXPECT_THAT(literal->shape().dimensions(), ElementsAre(1, 2, 3, 2));
- string result = literal->ToString();
+ EXPECT_THAT(literal.shape().dimensions(), ElementsAre(1, 2, 3, 2));
+ string result = literal.ToString();
const string expected = R"(f32[1,2,3,2] {
{ /*i0=0*/
{ /*i1=0*/
@@ -254,9 +254,9 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
}
TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) {
- EXPECT_THAT(literal_r4_2x2x3x3_dim0major_->shape().dimensions(),
+ EXPECT_THAT(literal_r4_2x2x3x3_dim0major_.shape().dimensions(),
ElementsAre(2, 2, 3, 3));
- string result = literal_r4_2x2x3x3_dim0major_->ToString();
+ string result = literal_r4_2x2x3x3_dim0major_.ToString();
const string expected = R"(f32[2,2,3,3] {
{ /*i0=0*/
{ /*i1=0*/
@@ -294,7 +294,7 @@ TEST_F(LiteralUtilTest, EachCellR2F32) {
});
// clang-format on
std::vector<std::tuple<int64, int64, string>> seen;
- literal->EachCellAsString(
+ literal.EachCellAsString(
[&seen](absl::Span<const int64> indices, const string& value) {
seen.emplace_back(indices[0], indices[1], value);
});
@@ -310,14 +310,14 @@ TEST_F(LiteralUtilTest, ScalarEquality) {
auto f32_42 = LiteralUtil::CreateR0<float>(42.0);
auto f32_42_clone = LiteralUtil::CreateR0<float>(42.0);
- EXPECT_EQ(*f32_42, *f32_42);
- EXPECT_EQ(*f32_42, *f32_42_clone);
+ EXPECT_EQ(f32_42, f32_42);
+ EXPECT_EQ(f32_42, f32_42_clone);
auto f32_123 = LiteralUtil::CreateR0<float>(123.0);
- EXPECT_NE(*f32_42, *f32_123);
+ EXPECT_NE(f32_42, f32_123);
auto f64_42 = LiteralUtil::CreateR0<double>(42.0);
- EXPECT_NE(*f32_42, *f64_42);
+ EXPECT_NE(f32_42, f64_42);
}
TEST_F(LiteralUtilTest, NonScalarEquality) {
@@ -330,12 +330,12 @@ TEST_F(LiteralUtilTest, NonScalarEquality) {
auto scalar = LiteralUtil::CreateR0<float>(1.0);
Literal nil(ShapeUtil::MakeNil());
- EXPECT_EQ(*matrix, *matrix);
- EXPECT_EQ(*matrix, *matrix_clone);
- EXPECT_NE(*matrix, *matrix_different);
- EXPECT_NE(*matrix, *vector_literal);
- EXPECT_NE(*matrix, *scalar);
- EXPECT_NE(*matrix, nil);
+ EXPECT_EQ(matrix, matrix);
+ EXPECT_EQ(matrix, matrix_clone);
+ EXPECT_NE(matrix, matrix_different);
+ EXPECT_NE(matrix, vector_literal);
+ EXPECT_NE(matrix, scalar);
+ EXPECT_NE(matrix, nil);
EXPECT_EQ(nil, nil);
}
@@ -344,57 +344,54 @@ TEST_F(LiteralUtilTest, TokenEquality) {
auto token1 = LiteralUtil::CreateToken();
auto scalar = LiteralUtil::CreateR0<float>(1.0);
- EXPECT_EQ(*token0, *token1);
- EXPECT_NE(*token0, *scalar);
+ EXPECT_EQ(token0, token1);
+ EXPECT_NE(token0, scalar);
- EXPECT_EQ(*LiteralUtil::MakeTuple({token0.get()}),
- *LiteralUtil::MakeTuple({token0.get()}));
- EXPECT_EQ(*LiteralUtil::MakeTuple({token0.get(), scalar.get()}),
- *LiteralUtil::MakeTuple({token1.get(), scalar.get()}));
- EXPECT_NE(*LiteralUtil::MakeTuple({token0.get(), scalar.get()}),
- *LiteralUtil::MakeTuple({scalar.get(), token1.get()}));
+ EXPECT_EQ(LiteralUtil::MakeTuple({&token0}),
+ LiteralUtil::MakeTuple({&token0}));
+ EXPECT_EQ(LiteralUtil::MakeTuple({&token0, &scalar}),
+ LiteralUtil::MakeTuple({&token1, &scalar}));
+ EXPECT_NE(LiteralUtil::MakeTuple({&token0, &scalar}),
+ LiteralUtil::MakeTuple({&scalar, &token1}));
}
TEST_F(LiteralUtilTest, DifferentLayoutEquality) {
// Test equality with literals which have different layouts.
- auto colmajor = absl::make_unique<Literal>(
- ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1}));
- colmajor->Set<float>({0, 0}, 1.0);
- colmajor->Set<float>({0, 1}, 2.0);
- colmajor->Set<float>({1, 0}, 3.0);
- colmajor->Set<float>({1, 1}, 4.0);
+ Literal colmajor(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1}));
+ colmajor.Set<float>({0, 0}, 1.0);
+ colmajor.Set<float>({0, 1}, 2.0);
+ colmajor.Set<float>({1, 0}, 3.0);
+ colmajor.Set<float>({1, 1}, 4.0);
- auto rowmajor = absl::make_unique<Literal>(
- ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}));
- rowmajor->Set<float>({0, 0}, 1.0);
- rowmajor->Set<float>({0, 1}, 2.0);
- rowmajor->Set<float>({1, 0}, 3.0);
- rowmajor->Set<float>({1, 1}, 4.0);
+ Literal rowmajor(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}));
+ rowmajor.Set<float>({0, 0}, 1.0);
+ rowmajor.Set<float>({0, 1}, 2.0);
+ rowmajor.Set<float>({1, 0}, 3.0);
+ rowmajor.Set<float>({1, 1}, 4.0);
- EXPECT_EQ(*rowmajor, *colmajor);
+ EXPECT_EQ(rowmajor, colmajor);
}
TEST_F(LiteralUtilTest, TupleEquality) {
// Test equality with tuples.
auto scalar = LiteralUtil::CreateR0<float>(1.0);
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto tuple1 = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
+ auto tuple1 = LiteralUtil::MakeTuple({&scalar, &matrix});
// Tuple with the same elements. One element is shared with the original
  // tuple; the other is a clone of the element in the original tuple.
auto scalar_clone = LiteralUtil::CreateR0<float>(1.0);
- auto tuple2 = LiteralUtil::MakeTuple({scalar_clone.get(), matrix.get()});
- EXPECT_EQ(*tuple1, *tuple2);
+ auto tuple2 = LiteralUtil::MakeTuple({&scalar_clone, &matrix});
+ EXPECT_EQ(tuple1, tuple2);
// Tuple with elements reversed.
- auto reversed_tuple = LiteralUtil::MakeTuple({matrix.get(), scalar.get()});
- EXPECT_NE(*tuple1, *reversed_tuple);
+ auto reversed_tuple = LiteralUtil::MakeTuple({&matrix, &scalar});
+ EXPECT_NE(tuple1, reversed_tuple);
// Tuple with different value.
auto scalar_42 = LiteralUtil::CreateR0<float>(42.0);
- auto different_tuple =
- LiteralUtil::MakeTuple({scalar_42.get(), matrix.get()});
- EXPECT_NE(*tuple1, *different_tuple);
+ auto different_tuple = LiteralUtil::MakeTuple({&scalar_42, &matrix});
+ EXPECT_NE(tuple1, different_tuple);
}
TEST_F(LiteralUtilTest, C64Equality) {
@@ -405,162 +402,161 @@ TEST_F(LiteralUtilTest, C64Equality) {
  // tuple; the other is a clone of the element in the original tuple.
auto vector_clone =
LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
- EXPECT_EQ(*vector, *vector_clone);
+ EXPECT_EQ(vector, vector_clone);
auto vector_reversed =
LiteralUtil::CreateR1<complex64>({{3.0, 4.0}, {1.0, 2.0}});
- EXPECT_NE(*vector, *vector_reversed);
+ EXPECT_NE(vector, vector_reversed);
}
TEST_F(LiteralUtilTest, IsAllTuple) {
auto element1 = LiteralUtil::CreateR0<float>(0.0);
auto element2 = LiteralUtil::CreateR2<float>({{0.0, 0.0}, {0.0, 0.0}});
- auto tuple = LiteralUtil::MakeTuple({element1.get(), element1.get()});
+ auto tuple = LiteralUtil::MakeTuple({&element1, &element1});
// Tuples should always return false for IsAll.
- EXPECT_FALSE(tuple->IsAll(0));
- EXPECT_FALSE(tuple->IsAll(1));
+ EXPECT_FALSE(tuple.IsAll(0));
+ EXPECT_FALSE(tuple.IsAll(1));
}
// Verifies that CreateFromShape works for tuples.
TEST_F(LiteralUtilTest, CreateFromShapeTuple) {
auto scalar = LiteralUtil::CreateR0<float>(0.0);
auto matrix = LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}});
- auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
+ auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
- auto x = Literal::CreateFromShape(tuple->shape());
- EXPECT_EQ(*tuple, *x);
+ auto x = Literal::CreateFromShape(tuple.shape());
+ EXPECT_EQ(tuple, x);
}
TEST_F(LiteralUtilTest, IsAll) {
- EXPECT_TRUE(LiteralUtil::CreateR0<bool>(false)->IsAll(0));
- EXPECT_TRUE(LiteralUtil::CreateR0<bool>(true)->IsAll(1));
- EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false)->IsAll(1));
- EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false)->IsAll(2));
- EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true)->IsAll(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true)->IsAll(2));
- EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true)->IsAll(-1));
+ EXPECT_TRUE(LiteralUtil::CreateR0<bool>(false).IsAll(0));
+ EXPECT_TRUE(LiteralUtil::CreateR0<bool>(true).IsAll(1));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAll(1));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAll(2));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true).IsAll(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true).IsAll(2));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true).IsAll(-1));
// We shouldn't reinterpret int8_min as an unsigned type and then decide that
// it is equal to 255.
auto int8_min = std::numeric_limits<int8>::min();
- EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(255)->IsAll(int8_min));
+ EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(255).IsAll(int8_min));
- EXPECT_TRUE(LiteralUtil::CreateR0<float>(42.0)->IsAll(42));
- EXPECT_FALSE(LiteralUtil::CreateR0<float>(42.0001)->IsAll(42));
+ EXPECT_TRUE(LiteralUtil::CreateR0<float>(42.0).IsAll(42));
+ EXPECT_FALSE(LiteralUtil::CreateR0<float>(42.0001).IsAll(42));
- EXPECT_TRUE(LiteralUtil::CreateR1<int>({100, 100, 100})->IsAll(100));
- EXPECT_FALSE(LiteralUtil::CreateR1<double>({100, 100, 100.001})->IsAll(100));
+ EXPECT_TRUE(LiteralUtil::CreateR1<int>({100, 100, 100}).IsAll(100));
+ EXPECT_FALSE(LiteralUtil::CreateR1<double>({100, 100, 100.001}).IsAll(100));
- EXPECT_TRUE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 8}})->IsAll(8));
- EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 9}})->IsAll(8));
- EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{9, 8}, {8, 8}})->IsAll(8));
+ EXPECT_TRUE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 8}}).IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 9}}).IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{9, 8}, {8, 8}}).IsAll(8));
half h8(8.0f);
half h9(9.0f);
- EXPECT_TRUE(LiteralUtil::CreateR2<half>({{h8}, {h8}})->IsAll(8));
- EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h8}, {h9}})->IsAll(8));
- EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h9}, {h8}})->IsAll(8));
+ EXPECT_TRUE(LiteralUtil::CreateR2<half>({{h8}, {h8}}).IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h8}, {h9}}).IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h9}, {h8}}).IsAll(8));
bfloat16 b8(8.0f);
bfloat16 b9(9.0f);
- EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b8}})->IsAll(8));
- EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b9}})->IsAll(8));
- EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b9}, {b8}})->IsAll(8));
+ EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b8}}).IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b9}}).IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b9}, {b8}}).IsAll(8));
// 9.001 will be truncated to 9.0
bfloat16 b91(9.001f);
bfloat16 b90(9.00f);
- EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b91}, {b90}})->IsAll(9.0));
+ EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b91}, {b90}}).IsAll(9.0));
complex64 c8_9 = {8, 9};
- EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}})->IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}}).IsAll(8));
auto uint64_max = std::numeric_limits<uint64>::max();
EXPECT_FALSE(LiteralUtil::CreateR2<uint64>(
{{uint64_max, uint64_max}, {uint64_max, uint64_max}})
- ->IsAll(-1));
+ .IsAll(-1));
}
TEST_F(LiteralUtilTest, IsAllFloat) {
// IsAllFloat always returns false when the literal is not floating-point.
- EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false)->IsAllFloat(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0)->IsAllFloat(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0)->IsAllFloat(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<int>(0)->IsAllFloat(0));
-
- EXPECT_TRUE(LiteralUtil::CreateR0<float>(0)->IsAllFloat(0));
- EXPECT_TRUE(LiteralUtil::CreateR0<float>(.5)->IsAllFloat(.5));
- EXPECT_TRUE(LiteralUtil::CreateR0<float>(-.5)->IsAllFloat(-.5));
- EXPECT_FALSE(LiteralUtil::CreateR0<float>(-.5)->IsAllFloat(-.49));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAllFloat(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0).IsAllFloat(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0).IsAllFloat(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<int>(0).IsAllFloat(0));
+
+ EXPECT_TRUE(LiteralUtil::CreateR0<float>(0).IsAllFloat(0));
+ EXPECT_TRUE(LiteralUtil::CreateR0<float>(.5).IsAllFloat(.5));
+ EXPECT_TRUE(LiteralUtil::CreateR0<float>(-.5).IsAllFloat(-.5));
+ EXPECT_FALSE(LiteralUtil::CreateR0<float>(-.5).IsAllFloat(-.49));
EXPECT_FALSE(
- LiteralUtil::CreateR2<float>({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0));
+ LiteralUtil::CreateR2<float>({{0, 0, 0}, {0, .1, 0}}).IsAllFloat(0));
EXPECT_TRUE(LiteralUtil::CreateR2<float>({{.5, .5, .5}, {.5, .5, .5}})
- ->IsAllFloat(.5));
+ .IsAllFloat(.5));
- EXPECT_TRUE(LiteralUtil::CreateR0<double>(0)->IsAllFloat(0));
- EXPECT_TRUE(LiteralUtil::CreateR0<double>(.5)->IsAllFloat(.5));
- EXPECT_TRUE(LiteralUtil::CreateR0<double>(-.5)->IsAllFloat(-.5));
- EXPECT_FALSE(LiteralUtil::CreateR0<double>(-.5)->IsAllFloat(-.49));
+ EXPECT_TRUE(LiteralUtil::CreateR0<double>(0).IsAllFloat(0));
+ EXPECT_TRUE(LiteralUtil::CreateR0<double>(.5).IsAllFloat(.5));
+ EXPECT_TRUE(LiteralUtil::CreateR0<double>(-.5).IsAllFloat(-.5));
+ EXPECT_FALSE(LiteralUtil::CreateR0<double>(-.5).IsAllFloat(-.49));
EXPECT_FALSE(
- LiteralUtil::CreateR2<double>({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0));
+ LiteralUtil::CreateR2<double>({{0, 0, 0}, {0, .1, 0}}).IsAllFloat(0));
}
TEST_F(LiteralUtilTest, IsAllComplex) {
// IsAllComplex always returns false when the literal is not complex.
- EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false)->IsAllComplex(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0)->IsAllComplex(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0)->IsAllComplex(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<int>(0)->IsAllComplex(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<float>(0)->IsAllComplex(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<double>(0)->IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0).IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0).IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<int>(0).IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<float>(0).IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<double>(0).IsAllComplex(0));
complex64 c8_9 = {8, 9};
complex64 c7_9 = {7, 9};
EXPECT_TRUE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}})
- ->IsAllComplex({8.0f, 9.0f}));
+ .IsAllComplex({8.0f, 9.0f}));
EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c7_9}, {c8_9}})
- ->IsAllComplex({8.0f, 9.0f}));
+ .IsAllComplex({8.0f, 9.0f}));
EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c7_9}})
- ->IsAllComplex({8.0f, 9.0f}));
+ .IsAllComplex({8.0f, 9.0f}));
}
TEST_F(LiteralUtilTest, IsAllFirst) {
  // IsAllFirst returns true only if all elements match the first element.
- EXPECT_FALSE(LiteralUtil::CreateR1<bool>({false, true})->IsAllFirst());
- EXPECT_TRUE(LiteralUtil::CreateR1<bool>({false, false})->IsAllFirst());
- EXPECT_FALSE(LiteralUtil::CreateR1<int8>({1, 1, 2})->IsAllFirst());
- EXPECT_TRUE(LiteralUtil::CreateR1<int8>({5, 5, 5, 5})->IsAllFirst());
- EXPECT_FALSE(LiteralUtil::CreateR1<uint8>({1, 1, 2})->IsAllFirst());
- EXPECT_TRUE(LiteralUtil::CreateR1<int32>({5, 5, 5, 5})->IsAllFirst());
- EXPECT_FALSE(LiteralUtil::CreateR1<int32>({1, 1, 2})->IsAllFirst());
- EXPECT_TRUE(LiteralUtil::CreateR1<uint32>({5, 5, 5, 5})->IsAllFirst());
- EXPECT_FALSE(LiteralUtil::CreateR1<uint32>({1, 1, 2})->IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR1<bool>({false, true}).IsAllFirst());
+ EXPECT_TRUE(LiteralUtil::CreateR1<bool>({false, false}).IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR1<int8>({1, 1, 2}).IsAllFirst());
+ EXPECT_TRUE(LiteralUtil::CreateR1<int8>({5, 5, 5, 5}).IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR1<uint8>({1, 1, 2}).IsAllFirst());
+ EXPECT_TRUE(LiteralUtil::CreateR1<int32>({5, 5, 5, 5}).IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR1<int32>({1, 1, 2}).IsAllFirst());
+ EXPECT_TRUE(LiteralUtil::CreateR1<uint32>({5, 5, 5, 5}).IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR1<uint32>({1, 1, 2}).IsAllFirst());
complex64 c8_9 = {8, 9};
complex64 c7_9 = {7, 9};
- EXPECT_TRUE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}})->IsAllFirst());
- EXPECT_FALSE(
- LiteralUtil::CreateR2<complex64>({{c7_9}, {c8_9}})->IsAllFirst());
+ EXPECT_TRUE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}}).IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c7_9}, {c8_9}}).IsAllFirst());
}
TEST_F(LiteralUtilTest, IsZero) {
auto scalar_zero = LiteralUtil::CreateR0<float>(0.0f);
auto scalar_one = LiteralUtil::CreateR0<float>(1.0f);
- EXPECT_TRUE(scalar_zero->IsZero({}));
- EXPECT_FALSE(scalar_one->IsZero({}));
+ EXPECT_TRUE(scalar_zero.IsZero({}));
+ EXPECT_FALSE(scalar_one.IsZero({}));
auto array = LiteralUtil::CreateR2<uint32>({{1, 2, 0, 3}, {1, 0, 1, 2}});
- EXPECT_FALSE(array->IsZero({0, 1}));
- EXPECT_TRUE(array->IsZero({0, 2}));
- EXPECT_TRUE(array->IsZero({1, 1}));
- EXPECT_FALSE(array->IsZero({1, 2}));
+ EXPECT_FALSE(array.IsZero({0, 1}));
+ EXPECT_TRUE(array.IsZero({0, 2}));
+ EXPECT_TRUE(array.IsZero({1, 1}));
+ EXPECT_FALSE(array.IsZero({1, 2}));
auto complex_zero = LiteralUtil::CreateR0<complex64>(0.0f);
auto complex_nonzero = LiteralUtil::CreateR0<complex64>(0.5f);
- EXPECT_TRUE(complex_zero->IsZero({}));
- EXPECT_FALSE(complex_nonzero->IsZero({}));
+ EXPECT_TRUE(complex_zero.IsZero({}));
+ EXPECT_FALSE(complex_nonzero.IsZero({}));
}
template <typename T>
@@ -576,19 +572,19 @@ TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) {
const Layout layout01 = LayoutUtil::MakeLayout({0, 1});
const Layout layout10 = LayoutUtil::MakeLayout({1, 0});
- auto data01 = data->Relayout(layout01);
- EXPECT_TRUE(LayoutUtil::Equal(data01->shape().layout(), layout01));
- EXPECT_EQ(*data, *data01);
+ auto data01 = data.Relayout(layout01);
+ EXPECT_TRUE(LayoutUtil::Equal(data01.shape().layout(), layout01));
+ EXPECT_EQ(data, data01);
- auto data10 = data->Relayout(layout10);
- EXPECT_TRUE(LayoutUtil::Equal(data10->shape().layout(), layout10));
- EXPECT_EQ(*data, *data10);
+ auto data10 = data.Relayout(layout10);
+ EXPECT_TRUE(LayoutUtil::Equal(data10.shape().layout(), layout10));
+ EXPECT_EQ(data, data10);
}
TEST_F(LiteralUtilTest, ReshapeR0) {
auto original = LiteralUtil::CreateR0<float>(1.7f);
- auto reshape = original->Reshape(/*dimensions=*/{}).ConsumeValueOrDie();
- EXPECT_EQ(*original, *reshape);
+ auto reshape = original.Reshape(/*dimensions=*/{}).ConsumeValueOrDie();
+ EXPECT_EQ(original, reshape);
}
TEST_F(LiteralUtilTest, ReshapeR4) {
@@ -606,9 +602,9 @@ TEST_F(LiteralUtilTest, ReshapeR4) {
{{26, 27}, {28, 29}, {30, 31}, {32, 33}},
}, layout_r3_dim0major_);
// clang-format on
- auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie();
+ auto reshape = original.Reshape({3, 4, 2}).ConsumeValueOrDie();
- EXPECT_EQ(*expected, *reshape);
+ EXPECT_EQ(expected, reshape);
}
TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) {
@@ -626,15 +622,15 @@ TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) {
{{26, 27}, {28, 29}, {30, 31}, {32, 33}},
}, layout_r3_dim0major_);
// clang-format on
- auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie();
+ auto reshape = original.Reshape({3, 4, 2}).ConsumeValueOrDie();
- EXPECT_EQ(*expected, *reshape);
+ EXPECT_EQ(expected, reshape);
}
TEST_F(LiteralUtilTest, TransposeR0) {
auto original = LiteralUtil::CreateR0<float>(1.7f);
- auto reshape = original->Transpose(/*permutation=*/{});
- EXPECT_EQ(*original, *reshape);
+ auto reshape = original.Transpose(/*permutation=*/{});
+ EXPECT_EQ(original, reshape);
}
TEST_F(LiteralUtilTest, TransposeR4) {
@@ -646,10 +642,10 @@ TEST_F(LiteralUtilTest, TransposeR4) {
{{26, 27, 28, 29}, {30, 31, 32, 33}},
}});
// clang-format on
- auto reshape = original->Transpose(/*permutation=*/{2, 3, 0, 1});
+ auto reshape = original.Transpose(/*permutation=*/{2, 3, 0, 1});
- reshape->EachCell<float>([&](absl::Span<const int64> indices, float value) {
- EXPECT_EQ(value, original->Get<float>(
+ reshape.EachCell<float>([&](absl::Span<const int64> indices, float value) {
+ EXPECT_EQ(value, original.Get<float>(
{indices[2], indices[3], indices[0], indices[1]}));
});
}
@@ -658,35 +654,35 @@ TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) {
// Tests that using Relayout on an array is equivalent to creating it in the
// target layout in the first place.
auto dim0minor_relaid_to_dim0major =
- literal_r4_2x2x3x3_dim0minor_->Relayout(layout_r4_dim0major_);
- EXPECT_EQ(*literal_r4_2x2x3x3_dim0major_, *dim0minor_relaid_to_dim0major);
+ literal_r4_2x2x3x3_dim0minor_.Relayout(layout_r4_dim0major_);
+ EXPECT_EQ(literal_r4_2x2x3x3_dim0major_, dim0minor_relaid_to_dim0major);
auto dim0major_relaid_to_dim0minor =
- literal_r4_2x2x3x3_dim0major_->Relayout(layout_r4_dim0minor_);
- EXPECT_EQ(*literal_r4_2x2x3x3_dim0minor_, *dim0major_relaid_to_dim0minor);
+ literal_r4_2x2x3x3_dim0major_.Relayout(layout_r4_dim0minor_);
+ EXPECT_EQ(literal_r4_2x2x3x3_dim0minor_, dim0major_relaid_to_dim0minor);
}
TEST_F(LiteralUtilTest, TestR2LinearLayout) {
// Test expected memory layout of R2 dim0-minor (column-major) literal.
auto mat_dim0minor = LiteralUtil::CreateR2WithLayout<int32>(
{{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_);
- EXPECT_EQ(mat_dim0minor->element_count(), 6);
- EXPECT_THAT(mat_dim0minor->data<int32>(), ElementsAre(1, 4, 2, 5, 3, 6));
+ EXPECT_EQ(mat_dim0minor.element_count(), 6);
+ EXPECT_THAT(mat_dim0minor.data<int32>(), ElementsAre(1, 4, 2, 5, 3, 6));
// Test expected memory layout when using Relayout to row major.
- auto relaid_mat_to_dim0major = mat_dim0minor->Relayout(layout_r2_dim0major_);
- EXPECT_THAT(relaid_mat_to_dim0major->data<int32>(),
+ auto relaid_mat_to_dim0major = mat_dim0minor.Relayout(layout_r2_dim0major_);
+ EXPECT_THAT(relaid_mat_to_dim0major.data<int32>(),
ElementsAre(1, 2, 3, 4, 5, 6));
// Test expected memory layout of R2 created with dim0-major (row-major).
auto mat_dim0major = LiteralUtil::CreateR2WithLayout<int32>(
{{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_);
- EXPECT_EQ(mat_dim0major->element_count(), 6);
- EXPECT_THAT(mat_dim0major->data<int32>(), ElementsAre(1, 2, 3, 4, 5, 6));
+ EXPECT_EQ(mat_dim0major.element_count(), 6);
+ EXPECT_THAT(mat_dim0major.data<int32>(), ElementsAre(1, 2, 3, 4, 5, 6));
// Test expected memory layout when using Relayout to column major.
- auto relaid_mat_to_dim0minor = mat_dim0major->Relayout(layout_r2_dim0minor_);
- EXPECT_THAT(relaid_mat_to_dim0minor->data<int32>(),
+ auto relaid_mat_to_dim0minor = mat_dim0major.Relayout(layout_r2_dim0minor_);
+ EXPECT_THAT(relaid_mat_to_dim0minor.data<int32>(),
ElementsAre(1, 4, 2, 5, 3, 6));
}
@@ -707,77 +703,77 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) {
auto lit_dim0minor = LiteralUtil::CreateR3FromArray3DWithLayout<int>(
arr3d, layout_r3_dim0minor_);
- EXPECT_EQ(lit_dim0minor->element_count(), 12);
+ EXPECT_EQ(lit_dim0minor.element_count(), 12);
std::vector<int> expected_dim0minor{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12};
- EXPECT_THAT(lit_dim0minor->data<int32>(),
+ EXPECT_THAT(lit_dim0minor.data<int32>(),
testing::ElementsAreArray(expected_dim0minor));
// Test expected memory layout when using Relayout to row major.
- auto relaid_lit_to_dim0major = lit_dim0minor->Relayout(layout_r3_dim0major_);
+ auto relaid_lit_to_dim0major = lit_dim0minor.Relayout(layout_r3_dim0major_);
std::vector<int> expected_dim0major{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
- EXPECT_THAT(relaid_lit_to_dim0major->data<int32>(),
+ EXPECT_THAT(relaid_lit_to_dim0major.data<int32>(),
testing::ElementsAreArray(expected_dim0major));
// Test expected memory layout of R3 created with dim0-major (row-major).
auto lit_dim0major = LiteralUtil::CreateR3FromArray3DWithLayout<int>(
arr3d, layout_r3_dim0major_);
- EXPECT_EQ(lit_dim0major->element_count(), 12);
- EXPECT_THAT(lit_dim0major->data<int32>(),
+ EXPECT_EQ(lit_dim0major.element_count(), 12);
+ EXPECT_THAT(lit_dim0major.data<int32>(),
testing::ElementsAreArray(expected_dim0major));
// Test expected memory layout when using Relayout to column major.
- auto relaid_lit_to_dim0minor = lit_dim0major->Relayout(layout_r3_dim0minor_);
- EXPECT_THAT(relaid_lit_to_dim0minor->data<int32>(),
+ auto relaid_lit_to_dim0minor = lit_dim0major.Relayout(layout_r3_dim0minor_);
+ EXPECT_THAT(relaid_lit_to_dim0minor.data<int32>(),
testing::ElementsAreArray(expected_dim0minor));
}
TEST_F(LiteralUtilTest, SliceR0S32) {
auto input = LiteralUtil::CreateR0<int32>(1);
- auto result = input->Slice({}, {});
- EXPECT_EQ(*input, *result);
+ auto result = input.Slice({}, {});
+ EXPECT_EQ(input, result);
}
TEST_F(LiteralUtilTest, SliceR1F32) {
auto input = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0, 5.0});
- auto result = input->Slice({3}, {4});
+ auto result = input.Slice({3}, {4});
auto expected = LiteralUtil::CreateR1<float>({4.0});
- EXPECT_EQ(*expected, *result);
+ EXPECT_EQ(expected, result);
}
TEST_F(LiteralUtilTest, SliceR2U32) {
auto input_3x4 = LiteralUtil::CreateR2<uint32>(
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
- auto result = input_3x4->Slice({0, 2}, {2, 4});
+ auto result = input_3x4.Slice({0, 2}, {2, 4});
auto expected = LiteralUtil::CreateR2<uint32>({{3, 4}, {7, 8}});
- EXPECT_EQ(*expected, *result);
+ EXPECT_EQ(expected, result);
}
TEST_F(LiteralUtilTest, SliceR3U32Full) {
auto input_2x3x2 = LiteralUtil::CreateR3<uint32>(
{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
- auto result = input_2x3x2->Slice({0, 0, 0}, {2, 3, 2});
- EXPECT_EQ(*input_2x3x2, *result);
+ auto result = input_2x3x2.Slice({0, 0, 0}, {2, 3, 2});
+ EXPECT_EQ(input_2x3x2, result);
}
TEST_F(LiteralUtilTest, PopulateR1S64) {
Literal output(ShapeUtil::MakeShape(S64, {1}));
output.PopulateR1<int64>({77});
auto expected = LiteralUtil::CreateR1<int64>({77});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateR1U64) {
Literal output(ShapeUtil::MakeShape(U64, {2}));
output.PopulateR1<uint64>({{77, 88}});
auto expected = LiteralUtil::CreateR1<uint64>({{77, 88}});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateR1C64) {
Literal output(ShapeUtil::MakeShape(C64, {1}));
output.PopulateR1<complex64>({{77, 88}});
auto expected = LiteralUtil::CreateR1<complex64>({{77, 88}});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateR2C64) {
@@ -785,7 +781,7 @@ TEST_F(LiteralUtilTest, PopulateR2C64) {
output.PopulateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
auto expected =
LiteralUtil::CreateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) {
@@ -793,7 +789,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) {
bfloat16 h(0.25f);
output.PopulateWithValue<bfloat16>(h);
auto expected = LiteralUtil::CreateR0<bfloat16>(h);
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) {
@@ -801,7 +797,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) {
bfloat16 h(0.5f);
output.PopulateWithValue<bfloat16>(h);
auto expected = LiteralUtil::CreateR1<bfloat16>({h, h, h});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) {
@@ -809,28 +805,28 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) {
bfloat16 h(2.0f);
output.PopulateWithValue<bfloat16>(h);
auto expected = LiteralUtil::CreateR2<bfloat16>({{h, h}, {h, h}});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR0F32) {
Literal output(ShapeUtil::MakeShape(F32, {}));
output.PopulateWithValue<float>(2.5f);
auto expected = LiteralUtil::CreateR0<float>(2.5f);
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR1S64) {
Literal output(ShapeUtil::MakeShape(S64, {3}));
output.PopulateWithValue<int64>(-7);
auto expected = LiteralUtil::CreateR1<int64>({-7, -7, -7});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR2U64) {
Literal output(ShapeUtil::MakeShape(U64, {2, 2}));
output.PopulateWithValue<uint64>(42);
auto expected = LiteralUtil::CreateR2<uint64>({{42, 42}, {42, 42}});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR2C64) {
@@ -838,7 +834,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2C64) {
output.PopulateWithValue<complex64>({4, 2});
auto expected =
LiteralUtil::CreateR2<complex64>({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR0F16) {
@@ -846,7 +842,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR0F16) {
half h(0.25f);
output.PopulateWithValue<half>(h);
auto expected = LiteralUtil::CreateR0<half>(h);
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR1F16) {
@@ -854,7 +850,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1F16) {
half h(0.5f);
output.PopulateWithValue<half>(h);
auto expected = LiteralUtil::CreateR1<half>({h, h, h});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR2F16) {
@@ -862,18 +858,18 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2F16) {
half h(2.0f);
output.PopulateWithValue<half>(h);
auto expected = LiteralUtil::CreateR2<half>({{h, h}, {h, h}});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, ReplicateR2U32) {
auto input = LiteralUtil::CreateR2<uint32>(
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
- auto output = input->Replicate<uint32>(3);
+ auto output = input.Replicate<uint32>(3);
auto expected = LiteralUtil::CreateR3<uint32>(
{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}});
- EXPECT_EQ(*output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, CopySliceFrom) {
@@ -889,17 +885,17 @@ TEST_F(LiteralUtilTest, CopySliceFrom) {
const int64 step[] = {1, 1, 1, 1};
uint32 seqnr = 0;
auto init_proc = [&](absl::Span<const int64> indexes) {
- source->Set(indexes, ++seqnr);
+ source.Set(indexes, ++seqnr);
return true;
};
- ShapeUtil::ForEachIndex(source->shape(), zero_base, dimensions, step,
+ ShapeUtil::ForEachIndex(source.shape(), zero_base, dimensions, step,
init_proc);
auto blank = Literal::CreateFromShape(shape);
const int64 src_base[] = {3, 1, 5, 7};
const int64 dest_base[] = {6, 4, 12, 2};
const int64 copy_size[] = {7, 8, 11, 9};
- TF_EXPECT_OK(blank->CopySliceFrom(*source, src_base, dest_base, copy_size));
+ TF_EXPECT_OK(blank.CopySliceFrom(source, src_base, dest_base, copy_size));
std::vector<int64> source_indexes(TF_ARRAYSIZE(dimensions), 0);
std::vector<int64> blank_indexes(TF_ARRAYSIZE(dimensions), 0);
@@ -911,12 +907,12 @@ TEST_F(LiteralUtilTest, CopySliceFrom) {
std::copy(indexes.begin(), indexes.end(), blank_indexes.begin());
std::transform(blank_indexes.begin(), blank_indexes.end(), dest_base,
blank_indexes.begin(), std::plus<int64>());
- auto bval = blank->Get<uint32>(blank_indexes);
- matched = (bval != 0 && bval == source->Get<uint32>(source_indexes));
+ auto bval = blank.Get<uint32>(blank_indexes);
+ matched = (bval != 0 && bval == source.Get<uint32>(source_indexes));
return matched;
};
- ShapeUtil::ForEachIndex(source->shape(), zero_base, copy_size, step,
+ ShapeUtil::ForEachIndex(source.shape(), zero_base, copy_size, step,
check_proc);
EXPECT_TRUE(matched);
}
@@ -925,14 +921,14 @@ TEST_F(LiteralUtilTest, CopySliceFrom) {
TEST_F(LiteralUtilTest, CopyFromScalars) {
auto zero = LiteralUtil::CreateR0<uint32>(0);
auto nine = LiteralUtil::CreateR0<uint32>(9);
- TF_EXPECT_OK(zero->CopyFrom(*nine));
- EXPECT_EQ(*zero, *nine);
+ TF_EXPECT_OK(zero.CopyFrom(nine));
+ EXPECT_EQ(zero, nine);
auto vect = LiteralUtil::CreateR1<uint32>({3, 4, 9, 12, 5, 17, 21});
- TF_EXPECT_OK(zero->CopySliceFrom(*vect, {5}, {}, {}));
- EXPECT_EQ(zero->Get<uint32>({}), 17);
- TF_EXPECT_OK(vect->CopySliceFrom(*zero, {}, {4}, {}));
- EXPECT_EQ(vect->Get<uint32>({4}), 17);
+ TF_EXPECT_OK(zero.CopySliceFrom(vect, {5}, {}, {}));
+ EXPECT_EQ(zero.Get<uint32>({}), 17);
+ TF_EXPECT_OK(vect.CopySliceFrom(zero, {}, {4}, {}));
+ EXPECT_EQ(vect.Get<uint32>({4}), 17);
}
TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) {
@@ -945,17 +941,17 @@ TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) {
const auto empty = Literal::CreateFromShape(empty_r1_shape);
auto nine = LiteralUtil::CreateR1<float>({9});
- TF_EXPECT_OK(nine->CopySliceFrom(*empty, {0}, {0}, {0}));
- EXPECT_EQ(*nine, *const_nine);
+ TF_EXPECT_OK(nine.CopySliceFrom(empty, {0}, {0}, {0}));
+ EXPECT_EQ(nine, const_nine);
}
{
// Copy 0 element to destination with zero elements.
- const auto empty = Literal::CreateFromShape(empty_r1_shape);
+ auto empty = Literal::CreateFromShape(empty_r1_shape);
auto nine = LiteralUtil::CreateR1<float>({9});
- TF_EXPECT_OK(empty->CopySliceFrom(*nine, {0}, {0}, {0}));
- EXPECT_EQ(*empty, *const_empty);
+ TF_EXPECT_OK(empty.CopySliceFrom(nine, {0}, {0}, {0}));
+ EXPECT_EQ(empty, const_empty);
}
}
@@ -969,74 +965,75 @@ TEST_F(LiteralUtilTest, CopyFromNilShape) {
TEST_F(LiteralUtilTest, CopyFromArrays) {
auto scalar_42 = LiteralUtil::CreateR0<float>(42.0);
auto scalar_123 = LiteralUtil::CreateR0<float>(123.0);
- EXPECT_NE(*scalar_42, *scalar_123);
- TF_ASSERT_OK(scalar_42->CopyFrom(*scalar_123, /*dest_shape_index=*/{},
- /*src_shape_index=*/{}));
- EXPECT_EQ(*scalar_42, *scalar_123);
- EXPECT_EQ(scalar_42->Get<float>({}), 123.0f);
+ EXPECT_NE(scalar_42, scalar_123);
+ TF_ASSERT_OK(scalar_42.CopyFrom(scalar_123, /*dest_shape_index=*/{},
+ /*src_shape_index=*/{}));
+ EXPECT_EQ(scalar_42, scalar_123);
+ EXPECT_EQ(scalar_42.Get<float>({}), 123.0f);
auto matrix_1234 = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto matrix_5678 = LiteralUtil::CreateR2<float>({{5.0, 6.0}, {7.0, 8.0}});
- EXPECT_NE(*matrix_1234, *matrix_5678);
- EXPECT_EQ(matrix_1234->Get<float>({0, 0}), 1.0f);
- TF_ASSERT_OK(matrix_1234->CopyFrom(*matrix_5678, /*dest_shape_index=*/{},
- /*src_shape_index=*/{}));
- EXPECT_EQ(*matrix_1234, *matrix_5678);
- EXPECT_EQ(matrix_1234->Get<float>({0, 0}), 5.0f);
+ EXPECT_NE(matrix_1234, matrix_5678);
+ EXPECT_EQ(matrix_1234.Get<float>({0, 0}), 1.0f);
+ TF_ASSERT_OK(matrix_1234.CopyFrom(matrix_5678, /*dest_shape_index=*/{},
+ /*src_shape_index=*/{}));
+ EXPECT_EQ(matrix_1234, matrix_5678);
+ EXPECT_EQ(matrix_1234.Get<float>({0, 0}), 5.0f);
}
TEST_F(LiteralUtilTest, CopyFromTuples) {
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
Literal nil_literal(ShapeUtil::MakeNil());
- auto nested_tuple = LiteralUtil::MakeTuple(
- {matrix.get(),
- LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<int32>(42).get(),
- LiteralUtil::CreateR1<double>({23.0, 44.0}).get(), &nil_literal})
- .get()});
+ Literal inner_elements[] = {LiteralUtil::CreateR0<int32>(42),
+ LiteralUtil::CreateR1<double>({23.0, 44.0})};
+ Literal inner_tuple = LiteralUtil::MakeTuple(
+ {&inner_elements[0], &inner_elements[1], &nil_literal});
+ Literal nested_tuple = LiteralUtil::MakeTuple({&matrix, &inner_tuple});
// Create a tuple the same shape as the inner tuple of nested_tuple but with
  // different values.
- auto tuple = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<int32>(-5).get(),
- LiteralUtil::CreateR1<double>({2.0, 4.0}).get(), &nil_literal});
+ Literal int32_minus5 = LiteralUtil::CreateR0<int32>(-5);
+ Literal double_2_4 = LiteralUtil::CreateR1<double>({2.0, 4.0});
+ Literal tuple =
+ LiteralUtil::MakeTuple({&int32_minus5, &double_2_4, &nil_literal});
- EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0}));
- EXPECT_EQ(nested_tuple->Get<int32>({}, {1, 0}), 42);
- EXPECT_EQ(nested_tuple->Get<double>({0}, {1, 1}), 23.0);
- EXPECT_EQ(nested_tuple->Get<double>({1}, {1, 1}), 44.0);
+ EXPECT_EQ(matrix, LiteralSlice(nested_tuple, {0}));
+ EXPECT_EQ(nested_tuple.Get<int32>({}, {1, 0}), 42);
+ EXPECT_EQ(nested_tuple.Get<double>({0}, {1, 1}), 23.0);
+ EXPECT_EQ(nested_tuple.Get<double>({1}, {1, 1}), 44.0);
// Overwrite the inner tuple element of nested_tuple with the contents of
// 'tuple'.
- TF_ASSERT_OK(nested_tuple->CopyFrom(*tuple, /*dest_shape_index=*/{1},
- /*src_shape_index=*/{}));
+ TF_ASSERT_OK(nested_tuple.CopyFrom(tuple, /*dest_shape_index=*/{1},
+ /*src_shape_index=*/{}));
// The matrix element should be unchanged.
- EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0}));
+ EXPECT_EQ(matrix, LiteralSlice(nested_tuple, {0}));
// The tuple element should have been copied from 'tuple'.
- EXPECT_EQ(nested_tuple->Get<int32>({}, {1, 0}), -5);
- EXPECT_EQ(nested_tuple->Get<double>({0}, {1, 1}), 2.0);
- EXPECT_EQ(nested_tuple->Get<double>({1}, {1, 1}), 4.0);
+ EXPECT_EQ(nested_tuple.Get<int32>({}, {1, 0}), -5);
+ EXPECT_EQ(nested_tuple.Get<double>({0}, {1, 1}), 2.0);
+ EXPECT_EQ(nested_tuple.Get<double>({1}, {1, 1}), 4.0);
}
TEST_F(LiteralUtilTest, CopyBetweenSameTuple) {
- auto tuple = LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int32>(-2).get(),
- LiteralUtil::CreateR0<int32>(4).get()});
+ Literal elements[] = {LiteralUtil::CreateR0<int32>(-2),
+ LiteralUtil::CreateR0<int32>(4)};
+ Literal tuple = LiteralUtil::MakeTuple({&elements[0], &elements[1]});
- EXPECT_EQ(tuple->Get<int32>({}, {0}), -2);
- EXPECT_EQ(tuple->Get<int32>({}, {1}), 4);
+ EXPECT_EQ(tuple.Get<int32>({}, {0}), -2);
+ EXPECT_EQ(tuple.Get<int32>({}, {1}), 4);
// Copy from one element to the other.
- TF_ASSERT_OK(tuple->CopyFrom(*tuple, /*dest_shape_index=*/{1},
- /*src_shape_index=*/{0}));
+ TF_ASSERT_OK(tuple.CopyFrom(tuple, /*dest_shape_index=*/{1},
+ /*src_shape_index=*/{0}));
- EXPECT_EQ(tuple->Get<int32>({}, {0}), -2);
- EXPECT_EQ(tuple->Get<int32>({}, {1}), -2);
+ EXPECT_EQ(tuple.Get<int32>({}, {0}), -2);
+ EXPECT_EQ(tuple.Get<int32>({}, {1}), -2);
}
TEST_F(LiteralUtilTest, CopyFromDifferentShapes) {
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto vector = LiteralUtil::CreateR1<float>({5.0, 7.0});
- Status status = matrix->CopyFrom(*vector);
+ Status status = matrix.CopyFrom(vector);
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.error_message(),
HasSubstr("Destination subshape incompatible"));
@@ -1046,9 +1043,8 @@ TEST_F(LiteralUtilTest, F16) {
// Verify that the internal data views are consistent and that they
// are in little endian format
  // TODO - modify if we make the data format machine endianness dependent
- auto m1 = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2}));
- Literal* l1 = m1.get();
- const char* d1 = reinterpret_cast<const char*>(l1->data<half>().data());
+ Literal m1 = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2}));
+ const char* d1 = reinterpret_cast<const char*>(m1.data<half>().data());
EXPECT_EQ(d1[0], 0);
EXPECT_EQ(d1[1], 0);
EXPECT_EQ(d1[2], 0);
@@ -1061,8 +1057,7 @@ TEST_F(LiteralUtilTest, F16) {
half h1(1.0f);
half h2(2.0f);
auto m2 = LiteralUtil::CreateR2<half>({{h1, h2}, {h2, h1}});
- Literal* l2 = m2.get();
- const char* d2 = reinterpret_cast<const char*>(l2->data<half>().data());
+ const char* d2 = reinterpret_cast<const char*>(m2.data<half>().data());
EXPECT_EQ(d2[0], 0);
EXPECT_EQ(d2[1], 0x3C);
EXPECT_EQ(d2[2], 0);
@@ -1091,25 +1086,25 @@ TEST_F(LiteralUtilTest, Populate) {
Shape shape = ShapeUtil::MakeShapeWithLayout(
primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions,
data.layout);
- auto literal = absl::make_unique<Literal>(shape);
+ Literal literal(shape);
auto generator = [&](absl::Span<const int64> indexes) -> uint32 {
      // Offsets from the linear index just to keep R0 literals from being
      // initialized with zero.
- return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(),
+ return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(),
indexes) +
17;
};
- TF_EXPECT_OK(literal->Populate<uint32>(generator));
+ TF_EXPECT_OK(literal.Populate<uint32>(generator));
std::vector<int64> zero_base(data.dimensions.size(), 0);
std::vector<int64> step(data.dimensions.size(), 1);
bool matched = true;
auto check_function = [&](absl::Span<const int64> indexes) {
- auto value = literal->Get<uint32>(indexes);
+ auto value = literal.Get<uint32>(indexes);
matched = matched && (value == generator(indexes));
return matched;
};
- ShapeUtil::ForEachIndex(literal->shape(), zero_base, data.dimensions, step,
+ ShapeUtil::ForEachIndex(literal.shape(), zero_base, data.dimensions, step,
check_function);
EXPECT_TRUE(matched);
}
@@ -1133,25 +1128,25 @@ TEST_F(LiteralUtilTest, PopulateParallel) {
Shape shape = ShapeUtil::MakeShapeWithLayout(
primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions,
data.layout);
- auto literal = absl::make_unique<Literal>(shape);
+ Literal literal(shape);
auto generator = [&](absl::Span<const int64> indexes) -> uint32 {
      // Offsets from the linear index just to keep R0 literals from being
      // initialized with zero.
- return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(),
+ return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(),
indexes) +
17;
};
- TF_EXPECT_OK(literal->PopulateParallel<uint32>(generator));
+ TF_EXPECT_OK(literal.PopulateParallel<uint32>(generator));
std::vector<int64> zero_base(data.dimensions.size(), 0);
std::vector<int64> step(data.dimensions.size(), 1);
bool matched = true;
auto check_function = [&](absl::Span<const int64> indexes) {
- auto value = literal->Get<uint32>(indexes);
+ auto value = literal.Get<uint32>(indexes);
matched = matched && (value == generator(indexes));
return matched;
};
- ShapeUtil::ForEachIndex(literal->shape(), zero_base, data.dimensions, step,
+ ShapeUtil::ForEachIndex(literal.shape(), zero_base, data.dimensions, step,
check_function);
EXPECT_TRUE(matched);
}
@@ -1170,10 +1165,9 @@ TEST_F(LiteralUtilTest, ConvertR4) {
{{26, 27, 28, 29}, {30, 31, 32, 33}},
}}, layout_r4_dim0major_);
// clang-format on
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> converted,
- original->Convert(U32));
+ TF_ASSERT_OK_AND_ASSIGN(Literal converted, original.Convert(U32));
- EXPECT_EQ(*expected, *converted);
+ EXPECT_EQ(expected, converted);
}
TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
@@ -1245,69 +1239,65 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
{{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}},
}}, layout_r4_dim0major_);
// clang-format on
- std::unique_ptr<Literal> conv;
+ Literal conv;
- conv = s8->Convert(U32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *u32);
+ conv = s8.Convert(U32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, u32);
- conv = s8->Convert(S32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *s32);
+ conv = s8.Convert(S32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, s32);
- conv = s8->Convert(U64).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *u64);
+ conv = s8.Convert(U64).ConsumeValueOrDie();
+ EXPECT_EQ(conv, u64);
- conv = s8->Convert(S64).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *s64);
+ conv = s8.Convert(S64).ConsumeValueOrDie();
+ EXPECT_EQ(conv, s64);
- conv = s8->Convert(PRED).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *pred);
+ conv = s8.Convert(PRED).ConsumeValueOrDie();
+ EXPECT_EQ(conv, pred);
- conv = bf16->Convert(S32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *s32);
+ conv = bf16.Convert(S32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, s32);
- conv = bf16->Convert(F32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *f32);
+ conv = bf16.Convert(F32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, f32);
- conv = pred->Convert(S32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *int32_pred);
+ conv = pred.Convert(S32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, int32_pred);
- conv = f32->Convert(S32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *s32);
+ conv = f32.Convert(S32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, s32);
- conv = f64->Convert(S32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *s32);
+ conv = f64.Convert(S32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, s32);
- conv = s32->Convert(F32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *f32);
+ conv = s32.Convert(F32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, f32);
- conv = f32->Convert(F16).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *f16);
+ conv = f32.Convert(F16).ConsumeValueOrDie();
+ EXPECT_EQ(conv, f16);
- conv = f64->Convert(F16).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *f16);
+ conv = f64.Convert(F16).ConsumeValueOrDie();
+ EXPECT_EQ(conv, f16);
- conv = s32->Convert(F16).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *f16);
+ conv = s32.Convert(F16).ConsumeValueOrDie();
+ EXPECT_EQ(conv, f16);
- conv = u32->Convert(F16).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *f16);
+ conv = u32.Convert(F16).ConsumeValueOrDie();
+ EXPECT_EQ(conv, f16);
- conv = s32->Convert(C64).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *c64);
+ conv = s32.Convert(C64).ConsumeValueOrDie();
+ EXPECT_EQ(conv, c64);
- conv = f16->Convert(C64).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *c64);
+ conv = f16.Convert(C64).ConsumeValueOrDie();
+ EXPECT_EQ(conv, c64);
- EXPECT_EQ(s32->Convert(TUPLE).status().code(),
- tensorflow::error::UNIMPLEMENTED);
- EXPECT_EQ(s32->Convert(S16).status().code(),
- tensorflow::error::UNIMPLEMENTED);
- EXPECT_EQ(s32->Convert(U16).status().code(),
- tensorflow::error::UNIMPLEMENTED);
- EXPECT_EQ(c64->Convert(F32).status().code(),
- tensorflow::error::UNIMPLEMENTED);
- EXPECT_EQ(c64->Convert(S32).status().code(),
+ EXPECT_EQ(s32.Convert(TUPLE).status().code(),
tensorflow::error::UNIMPLEMENTED);
+ EXPECT_EQ(s32.Convert(S16).status().code(), tensorflow::error::UNIMPLEMENTED);
+ EXPECT_EQ(s32.Convert(U16).status().code(), tensorflow::error::UNIMPLEMENTED);
+ EXPECT_EQ(c64.Convert(F32).status().code(), tensorflow::error::UNIMPLEMENTED);
+ EXPECT_EQ(c64.Convert(S32).status().code(), tensorflow::error::UNIMPLEMENTED);
}
TEST_F(LiteralUtilTest, BitcastConvert) {
@@ -1317,13 +1307,12 @@ TEST_F(LiteralUtilTest, BitcastConvert) {
tensorflow::bit_cast<uint32>(100.f), 0xbeef});
auto expected = LiteralUtil::CreateR1<float>(
{2.5f, -42.25f, 100.0f, tensorflow::bit_cast<float>(0xbeef)});
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> converted,
- original->BitcastConvert(F32));
+ TF_ASSERT_OK_AND_ASSIGN(Literal converted, original.BitcastConvert(F32));
}
TEST_F(LiteralUtilTest, BitcastConvertBetweenInvalidTypes) {
auto literal = LiteralUtil::CreateR0<uint32>(1234);
- Status status = literal->BitcastConvert(F64).status();
+ Status status = literal.BitcastConvert(F64).status();
EXPECT_NE(Status::OK(), status);
EXPECT_TRUE(
absl::StrContains(status.error_message(), "bit widths are different"));
@@ -1341,11 +1330,10 @@ TEST_F(LiteralUtilTest, CopyFromProto_Bool) {
p.add_preds((i % 2) == (len % 2));
}
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> literal,
- Literal::CreateFromProto(p));
- ASSERT_EQ(len, literal->data<bool>().size());
+ TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p));
+ ASSERT_EQ(len, literal.data<bool>().size());
int i = 0;
- for (bool value : literal->data<bool>()) {
+ for (bool value : literal.data<bool>()) {
EXPECT_EQ((i % 2) == (len % 2), value);
++i;
}
@@ -1358,11 +1346,10 @@ TEST_F(LiteralUtilTest, ToProto_f16) {
half h2(2.0f);
auto m = LiteralUtil::CreateR2<half>({{h1, h2}, {h2, h1}});
- Literal* l = m.get();
- EXPECT_EQ(4, ShapeUtil::ElementsIn(l->shape()));
- EXPECT_EQ(4, l->data<half>().size());
+ EXPECT_EQ(4, ShapeUtil::ElementsIn(m.shape()));
+ EXPECT_EQ(4, m.data<half>().size());
- LiteralProto p = l->ToProto();
+ LiteralProto p = m.ToProto();
EXPECT_EQ(4, ShapeUtil::ElementsIn(p.shape()));
EXPECT_EQ(8, p.f16s().size());
const char* d = p.f16s().data();
@@ -1389,9 +1376,8 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) {
LayoutUtil::SetToDefaultLayout(p.mutable_shape());
p.clear_f16s();
p.set_f16s(half_vals, 8);
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> literal,
- Literal::CreateFromProto(p));
- auto r = literal->data<half>();
+ TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p));
+ auto r = literal.data<half>();
ASSERT_EQ(4, r.size());
EXPECT_EQ(h1, r[0]);
EXPECT_EQ(h2, r[1]);
@@ -1402,43 +1388,41 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) {
TEST_F(LiteralUtilTest, LiteralSliceTest) {
auto scalar = LiteralUtil::CreateR0<float>(1.0);
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
- auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()});
+ auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
+ auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar});
Literal nil(ShapeUtil::MakeNil());
- EXPECT_EQ(LiteralSlice(*scalar, {}), *scalar);
- EXPECT_EQ(LiteralSlice(*matrix, {}), *matrix);
- EXPECT_EQ(LiteralSlice(*tuple, {}), *tuple);
- EXPECT_EQ(LiteralSlice(*nested_tuple, {}), *nested_tuple);
+ EXPECT_EQ(LiteralSlice(scalar, {}), scalar);
+ EXPECT_EQ(LiteralSlice(matrix, {}), matrix);
+ EXPECT_EQ(LiteralSlice(tuple, {}), tuple);
+ EXPECT_EQ(LiteralSlice(nested_tuple, {}), nested_tuple);
EXPECT_EQ(LiteralSlice(nil, {}), nil);
- EXPECT_EQ(LiteralSlice(*tuple, {0}), *scalar);
- EXPECT_EQ(LiteralSlice(*tuple, {1}), *matrix);
+ EXPECT_EQ(LiteralSlice(tuple, {0}), scalar);
+ EXPECT_EQ(LiteralSlice(tuple, {1}), matrix);
- EXPECT_EQ(LiteralSlice(*nested_tuple, {0}), *tuple);
- EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 0}), *scalar);
- EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 1}), *matrix);
- EXPECT_EQ(LiteralSlice(*nested_tuple, {1}), *scalar);
+ EXPECT_EQ(LiteralSlice(nested_tuple, {0}), tuple);
+ EXPECT_EQ(LiteralSlice(nested_tuple, {0, 0}), scalar);
+ EXPECT_EQ(LiteralSlice(nested_tuple, {0, 1}), matrix);
+ EXPECT_EQ(LiteralSlice(nested_tuple, {1}), scalar);
}
TEST_F(LiteralUtilTest, MutatingLiteralSlice) {
auto scalar = LiteralUtil::CreateR0<float>(1.0);
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
- auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()});
+ auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
+ auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar});
// Verify that changing the underlying data beneath the view changes the
// data of the view itself.
- const auto nested_tuple_view = LiteralSlice(*nested_tuple);
- EXPECT_EQ(
- nested_tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
- 1.0f);
+ const auto nested_tuple_view = LiteralSlice(nested_tuple);
+ EXPECT_EQ(nested_tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
+ 1.0f);
EXPECT_EQ(nested_tuple_view.Get<float>(/*multi_index=*/{},
/*shape_index=*/{0, 0}),
1.0f);
- nested_tuple->Set<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}, 555.0f);
- EXPECT_EQ(
- nested_tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
- 555.0f);
+ nested_tuple.Set<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}, 555.0f);
+ EXPECT_EQ(nested_tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
+ 555.0f);
EXPECT_EQ(nested_tuple_view.Get<float>(/*multi_index=*/{},
/*shape_index=*/{0, 0}),
555.0f);
@@ -1447,14 +1431,14 @@ TEST_F(LiteralUtilTest, MutatingLiteralSlice) {
TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) {
auto scalar = LiteralUtil::CreateR0<float>(1.0);
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
- auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()});
+ auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
+ auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar});
- const auto nested_tuple_view = LiteralSlice(*nested_tuple);
+ const auto nested_tuple_view = LiteralSlice(nested_tuple);
const auto tuple_view = LiteralSlice(nested_tuple_view, /*view_root=*/{0});
const auto matrix_view = LiteralSlice(tuple_view, /*view_root=*/{1});
EXPECT_EQ(matrix_view,
- *LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
}
TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtr) {
@@ -1497,9 +1481,8 @@ TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrs) {
}
TEST_F(LiteralUtilTest, LiteralMove) {
- std::unique_ptr<Literal> matrix =
- LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- Literal literal(std::move(*matrix));
+ Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ Literal literal(std::move(matrix));
EXPECT_TRUE(
ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape()));
@@ -1511,17 +1494,21 @@ TEST_F(LiteralUtilTest, LiteralMove) {
TEST_F(LiteralUtilTest, DecomposeTuple) {
Literal nil_literal(ShapeUtil::MakeNil());
- auto nested_tuple = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}}).get(),
- LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<int32>(42).get(),
- LiteralUtil::CreateR1<double>({23.0, 44.0}).get(), &nil_literal})
- .get(),
- &nil_literal});
-
- EXPECT_FALSE(ShapeUtil::IsNil(nested_tuple->shape()));
- std::vector<Literal> elements = nested_tuple->DecomposeTuple();
- EXPECT_TRUE(ShapeUtil::IsNil(nested_tuple->shape()));
+ Literal inner_elements[] = {
+ LiteralUtil::CreateR0<int32>(42),
+ LiteralUtil::CreateR1<double>({23.0, 44.0}),
+ };
+ Literal tuple_elements[] = {
+ LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}}),
+ LiteralUtil::MakeTuple(
+ {&inner_elements[0], &inner_elements[1], &nil_literal}),
+ };
+ Literal nested_tuple = LiteralUtil::MakeTuple(
+ {&tuple_elements[0], &tuple_elements[1], &nil_literal});
+
+ EXPECT_FALSE(ShapeUtil::IsNil(nested_tuple.shape()));
+ std::vector<Literal> elements = nested_tuple.DecomposeTuple();
+ EXPECT_TRUE(ShapeUtil::IsNil(nested_tuple.shape()));
ASSERT_EQ(elements.size(), 3);
@@ -1552,13 +1539,13 @@ TEST_F(LiteralUtilTest, DecomposeEmptyTuple) {
TEST_F(LiteralUtilTest, MoveIntoTuple) {
std::vector<Literal> elements;
- elements.push_back(std::move(*LiteralUtil::CreateR0<float>(1.0)));
- elements.push_back(std::move(*LiteralUtil::CreateR1<int32>({4, 8})));
- elements.push_back(std::move(*LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<int32>(42).get(),
- LiteralUtil::CreateR1<double>({23.0, 44.0}).get()})
-
- ));
+ elements.push_back(LiteralUtil::CreateR0<float>(1.0));
+ elements.push_back(LiteralUtil::CreateR1<int32>({4, 8}));
+ std::vector<Literal> inner_elements;
+ inner_elements.push_back(LiteralUtil::CreateR0<int32>(42));
+ inner_elements.push_back(LiteralUtil::CreateR1<double>({23.0, 44.0}));
+ elements.push_back(
+ LiteralUtil::MakeTuple({&inner_elements[0], &inner_elements[1]}));
Literal literal = Literal::MoveIntoTuple(absl::MakeSpan(elements));
ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape()));
@@ -1586,9 +1573,8 @@ TEST_F(LiteralUtilTest, LiteralMoveAssignment) {
Literal literal;
EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeNil(), literal.shape()));
- std::unique_ptr<Literal> matrix =
- LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- literal = std::move(*matrix);
+ Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ literal = std::move(matrix);
EXPECT_TRUE(
ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape()));
@@ -1599,9 +1585,8 @@ TEST_F(LiteralUtilTest, LiteralMoveAssignment) {
}
TEST_F(LiteralUtilTest, LiteralSliceCopy) {
- std::unique_ptr<Literal> matrix =
- LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- const auto matrix_view = LiteralSlice(*matrix);
+ Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ const auto matrix_view = LiteralSlice(matrix);
LiteralSlice matrix_view_copy(matrix_view);
EXPECT_EQ(matrix_view_copy.Get<float>({0, 0}), 1.0);
@@ -1611,45 +1596,43 @@ TEST_F(LiteralUtilTest, LiteralSliceCopy) {
}
TEST_F(LiteralUtilTest, GetSetTuple) {
- auto tuple = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(42.0).get(),
- LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get()});
- EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), 42.0);
- tuple->Set<float>(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0);
- EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0);
-
- EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}),
- 3.0);
- tuple->Set<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}, -4.0);
- EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}),
+ Literal elements[] = {
+ LiteralUtil::CreateR0<float>(42.0),
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
+ };
+ auto tuple = LiteralUtil::MakeTuple({&elements[0], &elements[1]});
+ EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), 42.0);
+ tuple.Set<float>(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0);
+ EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0);
+
+ EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}), 3.0);
+ tuple.Set<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}, -4.0);
+ EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}),
-4.0);
}
TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) {
// Literals constructed using CreateFromShape should be zero initialized.
- std::unique_ptr<Literal> scalar_f32 =
- Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {}));
- EXPECT_EQ(scalar_f32->Get<float>({}), 0.0);
- EXPECT_TRUE(scalar_f32->IsAll(0));
-
- std::unique_ptr<Literal> vector_s32 =
- Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {3}));
- EXPECT_EQ(vector_s32->Get<int32>({0}), 0);
- EXPECT_EQ(vector_s32->Get<int32>({1}), 0);
- EXPECT_EQ(vector_s32->Get<int32>({2}), 0);
- EXPECT_TRUE(vector_s32->IsAll(0));
-
- std::unique_ptr<Literal> tuple =
- Literal::CreateFromShape(ShapeUtil::MakeTupleShape(
- {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}),
- ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {})}));
-
- EXPECT_EQ(tuple->Get<double>({}, {0}), 0.0);
- EXPECT_EQ(tuple->Get<bool>({0}, {1}), false);
- EXPECT_EQ(tuple->Get<bool>({1}, {1}), false);
- EXPECT_EQ(tuple->Get<uint64>({0, 0}, {2}), 0);
- EXPECT_EQ(tuple->Get<uint64>({1, 0}, {2}), 0);
- EXPECT_EQ(tuple->Get<complex64>({}, {3}), complex64(0.0f, 0.0f));
+ Literal scalar_f32 = Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {}));
+ EXPECT_EQ(scalar_f32.Get<float>({}), 0.0);
+ EXPECT_TRUE(scalar_f32.IsAll(0));
+
+ Literal vector_s32 = Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {3}));
+ EXPECT_EQ(vector_s32.Get<int32>({0}), 0);
+ EXPECT_EQ(vector_s32.Get<int32>({1}), 0);
+ EXPECT_EQ(vector_s32.Get<int32>({2}), 0);
+ EXPECT_TRUE(vector_s32.IsAll(0));
+
+ Literal tuple = Literal::CreateFromShape(ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}),
+ ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {})}));
+
+ EXPECT_EQ(tuple.Get<double>({}, {0}), 0.0);
+ EXPECT_EQ(tuple.Get<bool>({0}, {1}), false);
+ EXPECT_EQ(tuple.Get<bool>({1}, {1}), false);
+ EXPECT_EQ(tuple.Get<uint64>({0, 0}, {2}), 0);
+ EXPECT_EQ(tuple.Get<uint64>({1, 0}, {2}), 0);
+ EXPECT_EQ(tuple.Get<complex64>({}, {3}), complex64(0.0f, 0.0f));
}
TEST_F(LiteralUtilTest, ProtoRoundTrip) {
@@ -1657,6 +1640,7 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) {
auto one_f32 = LiteralUtil::CreateR0<float>(1.0);
auto two_f32 = LiteralUtil::CreateR0<float>(2.0);
auto vector_int8 = LiteralUtil::CreateR1<int8>({-128, 0, 2, 4, 7, 56, 127});
+ auto vector_uint8 = LiteralUtil::CreateR1<uint8>({128, 0, 2, 56, 127, 255});
auto vector_c64 = LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
auto vector_bfloat16 = LiteralUtil::CreateR1<bfloat16>(
{bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}});
@@ -1665,25 +1649,27 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) {
auto matrix_pred =
LiteralUtil::CreateR2<bool>({{true, false, true}, {false, false, true}});
auto tuple = LiteralUtil::MakeTuple(
- {one_f32.get(), vector_half.get(), matrix_pred.get(), matrix_pred.get()});
+ {&one_f32, &vector_half, &matrix_pred, &matrix_pred});
Literal nil_literal(ShapeUtil::MakeNil());
- auto nested_tuple = LiteralUtil::MakeTuple(
- {tuple.get(), vector_bfloat16.get(), tuple.get(), &nil_literal});
+ auto nested_tuple =
+ LiteralUtil::MakeTuple({&tuple, &vector_bfloat16, &tuple, &nil_literal});
auto to_from_proto = [](const Literal& literal) -> Literal {
- return std::move(*Literal::CreateFromProto(literal.ToProto()).ValueOrDie());
+ return Literal::CreateFromProto(literal.ToProto()).ValueOrDie();
};
- EXPECT_EQ(*one_f32, to_from_proto(*one_f32));
- EXPECT_EQ(*vector_c64, to_from_proto(*vector_c64));
- EXPECT_EQ(*vector_bfloat16, to_from_proto(*vector_bfloat16));
- EXPECT_EQ(*matrix_pred, to_from_proto(*matrix_pred));
- EXPECT_EQ(*tuple, to_from_proto(*tuple));
- EXPECT_EQ(*nested_tuple, to_from_proto(*nested_tuple));
+ EXPECT_EQ(one_f32, to_from_proto(one_f32));
+ EXPECT_EQ(vector_int8, to_from_proto(vector_int8));
+ EXPECT_EQ(vector_uint8, to_from_proto(vector_uint8));
+ EXPECT_EQ(vector_c64, to_from_proto(vector_c64));
+ EXPECT_EQ(vector_bfloat16, to_from_proto(vector_bfloat16));
+ EXPECT_EQ(matrix_pred, to_from_proto(matrix_pred));
+ EXPECT_EQ(tuple, to_from_proto(tuple));
+ EXPECT_EQ(nested_tuple, to_from_proto(nested_tuple));
EXPECT_EQ(nil_literal, to_from_proto(nil_literal));
- EXPECT_NE(*one_f32, *two_f32);
- EXPECT_NE(*one_f32, to_from_proto(*two_f32));
+ EXPECT_NE(one_f32, two_f32);
+ EXPECT_NE(one_f32, to_from_proto(two_f32));
}
TEST_F(LiteralUtilTest, InvalidProtoNoValues) {
@@ -1802,11 +1788,11 @@ TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) {
TEST_F(LiteralUtilTest, SortSparseElements) {
auto literal = LiteralUtil::CreateSparse<float>({10, 10, 10},
SparseIndexArray(10, 3), {});
- literal->AppendSparseElement<float>({2, 3, 4}, 2.0);
- literal->AppendSparseElement<float>({3, 4, 5}, 3.0);
- literal->AppendSparseElement<float>({1, 2, 3}, 1.0);
- literal->SortSparseElements();
- EXPECT_EQ(literal->ToString(false),
+ literal.AppendSparseElement<float>({2, 3, 4}, 2.0);
+ literal.AppendSparseElement<float>({3, 4, 5}, 3.0);
+ literal.AppendSparseElement<float>({1, 2, 3}, 1.0);
+ literal.SortSparseElements();
+ EXPECT_EQ(literal.ToString(false),
"f32[10,10,10]{[1, 2, 3]: 1, [2, 3, 4]: 2, [3, 4, 5]: 3}");
}
@@ -1816,57 +1802,54 @@ TEST_F(LiteralUtilTest, GetSparseElementAsString) {
EXPECT_EQ(
LiteralUtil::CreateSparse<bool>(dimensions, indices, {true, false, true})
- ->GetSparseElementAsString(1),
+ .GetSparseElementAsString(1),
"false");
EXPECT_EQ(LiteralUtil::CreateSparse<int64>(dimensions, indices, {1, 2, 3})
- ->GetSparseElementAsString(1),
+ .GetSparseElementAsString(1),
absl::StrCat(int64{2}));
EXPECT_EQ(
LiteralUtil::CreateSparse<double>(dimensions, indices, {1.0, 2.0, 3.0})
- ->GetSparseElementAsString(1),
+ .GetSparseElementAsString(1),
absl::StrCat(double{2.0}));
EXPECT_EQ(LiteralUtil::CreateSparse<half>(dimensions, indices,
{half{1.0}, half{2.0}, half{3.0}})
- ->GetSparseElementAsString(1),
+ .GetSparseElementAsString(1),
absl::StrCat(static_cast<float>(half{2.0})));
EXPECT_EQ(LiteralUtil::CreateSparse<complex64>(
dimensions, indices,
std::vector<complex64>{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}})
- ->GetSparseElementAsString(1),
+ .GetSparseElementAsString(1),
absl::StrCat("(", float{3.0}, ", ", float{4.0}, ")"));
}
TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<int64>({1, 2});
+ Literal literal = LiteralUtil::CreateR1<int64>({1, 2});
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> broadcasted_literal,
- literal->Broadcast(
- /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
- /*dimensions=*/{0}));
- EXPECT_EQ(*broadcasted_literal,
- *LiteralUtil::CreateR2<int64>({{1, 1}, {2, 2}}));
+ Literal broadcasted_literal,
+ literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
+ /*dimensions=*/{0}));
+ EXPECT_EQ(broadcasted_literal,
+ LiteralUtil::CreateR2<int64>({{1, 1}, {2, 2}}));
}
TEST_F(LiteralUtilTest, BroadcastVectorToMatrix1) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<int64>({1, 2});
+ Literal literal = LiteralUtil::CreateR1<int64>({1, 2});
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> broadcasted_literal,
- literal->Broadcast(
- /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
- /*dimensions=*/{1}));
- EXPECT_EQ(*broadcasted_literal,
- *LiteralUtil::CreateR2<int64>({{1, 2}, {1, 2}}));
+ Literal broadcasted_literal,
+ literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
+ /*dimensions=*/{1}));
+ EXPECT_EQ(broadcasted_literal,
+ LiteralUtil::CreateR2<int64>({{1, 2}, {1, 2}}));
}
TEST_F(LiteralUtilTest, BroadcastScalarToMatrix) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<int32>(9);
+ Literal literal = LiteralUtil::CreateR0<int32>(9);
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> broadcasted_literal,
- literal->Broadcast(
- /*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}),
- /*dimensions=*/{}));
- EXPECT_EQ(*broadcasted_literal,
- *LiteralUtil::CreateR2<int32>({{9, 9}, {9, 9}}));
+ Literal broadcasted_literal,
+ literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}),
+ /*dimensions=*/{}));
+ EXPECT_EQ(broadcasted_literal,
+ LiteralUtil::CreateR2<int32>({{9, 9}, {9, 9}}));
}
} // namespace
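
Note: the test rewrites above all follow one pattern — the LiteralUtil factories now return Literal by value instead of std::unique_ptr<Literal>, so the .get()/*/-> indirection disappears and ownership is transferred with std::move. A minimal migration sketch, using only the factory and tuple calls that appear in this diff (not a complete test):

#include "tensorflow/compiler/xla/literal_util.h"

namespace xla {

Literal MigrationSketch() {
  // Before: std::unique_ptr<Literal> m = LiteralUtil::CreateR2<float>(...);
  //         Literal moved(std::move(*m));
  // After: the factory returns a value, so the move is direct.
  Literal m = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  Literal moved(std::move(m));  // `m` is left in a valid but moved-from state
  // MakeTuple still copies, and now takes pointers to caller-owned values.
  Literal scalar = LiteralUtil::CreateR0<float>(42.0);
  return LiteralUtil::MakeTuple({&scalar, &moved});
}

}  // namespace xla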
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index 613449cf10..0cb1ae35f4 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -45,7 +45,7 @@ using absl::StrCat;
// Return a literal with all arrays of type FromNativeT converted to type
// ToNativeT in the given literal.
template <typename FromNativeT, typename ToNativeT>
-std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
+Literal ConvertType(LiteralSlice literal) {
// First construct shape of the result.
Shape result_shape(literal.shape());
ShapeUtil::ForEachMutableSubshape(
@@ -56,7 +56,7 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
primitive_util::NativeToPrimitiveType<ToNativeT>());
}
});
- auto result = absl::make_unique<Literal>(result_shape);
+ Literal result(result_shape);
// Then copy over the data from 'literal' converting FromNativeT values to
// ToNativeT values as necessary.
@@ -67,14 +67,14 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
if (subshape.element_type() ==
primitive_util::NativeToPrimitiveType<FromNativeT>()) {
auto src = literal.data<FromNativeT>(shape_index);
- auto dest = result->data<ToNativeT>(shape_index);
+ auto dest = result.data<ToNativeT>(shape_index);
for (int64 i = 0; i < src.size(); ++i) {
dest[i] = static_cast<ToNativeT>(src[i]);
}
} else {
- TF_CHECK_OK(result->CopyFrom(literal,
- /*dest_shape_index=*/shape_index,
- /*src_shape_index=*/shape_index));
+ TF_CHECK_OK(result.CopyFrom(literal,
+ /*dest_shape_index=*/shape_index,
+ /*src_shape_index=*/shape_index));
}
}
});
@@ -83,53 +83,52 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
} // namespace
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromDimensions(
+/* static */ Literal LiteralUtil::CreateFromDimensions(
PrimitiveType primitive_type, absl::Span<const int64> dimensions) {
return Literal::CreateFromShape(
ShapeUtil::MakeShape(primitive_type, dimensions));
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::ConvertBF16ToF32(
+/* static */ Literal LiteralUtil::ConvertBF16ToF32(
const LiteralSlice& bf16_literal) {
return ConvertType<bfloat16, float>(bf16_literal);
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::ConvertF32ToBF16(
+/* static */ Literal LiteralUtil::ConvertF32ToBF16(
const LiteralSlice& f32_literal) {
return ConvertType<float, bfloat16>(f32_literal);
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateToken() {
- return absl::make_unique<Literal>(ShapeUtil::MakeTokenShape());
+/* static */ Literal LiteralUtil::CreateToken() {
+ return Literal(ShapeUtil::MakeTokenShape());
}
/* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
- return std::move(*LiteralUtil::CreateR0<uint8>(0));
+ return LiteralUtil::CreateR0<uint8>(0);
case U32:
- return std::move(*LiteralUtil::CreateR0<uint32>(0));
+ return LiteralUtil::CreateR0<uint32>(0);
case U64:
- return std::move(*LiteralUtil::CreateR0<uint64>(0));
+ return LiteralUtil::CreateR0<uint64>(0);
case S8:
- return std::move(*LiteralUtil::CreateR0<int8>(0));
+ return LiteralUtil::CreateR0<int8>(0);
case S32:
- return std::move(*LiteralUtil::CreateR0<int32>(0));
+ return LiteralUtil::CreateR0<int32>(0);
case S64:
- return std::move(*LiteralUtil::CreateR0<int64>(0));
+ return LiteralUtil::CreateR0<int64>(0);
case F16:
- return std::move(*LiteralUtil::CreateR0<half>(static_cast<half>(0.0f)));
+ return LiteralUtil::CreateR0<half>(static_cast<half>(0.0f));
case BF16:
- return std::move(
- *LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f)));
+ return LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f));
case F32:
- return std::move(*LiteralUtil::CreateR0<float>(0));
+ return LiteralUtil::CreateR0<float>(0);
case F64:
- return std::move(*LiteralUtil::CreateR0<double>(0));
+ return LiteralUtil::CreateR0<double>(0);
case C64:
- return std::move(*LiteralUtil::CreateR0<complex64>(0));
+ return LiteralUtil::CreateR0<complex64>(0);
case PRED:
- return std::move(*LiteralUtil::CreateR0<bool>(false));
+ return LiteralUtil::CreateR0<bool>(false);
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
@@ -145,30 +144,29 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
/* static */ Literal LiteralUtil::One(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
- return std::move(*LiteralUtil::CreateR0<uint8>(1));
+ return LiteralUtil::CreateR0<uint8>(1);
case U32:
- return std::move(*LiteralUtil::CreateR0<uint32>(1));
+ return LiteralUtil::CreateR0<uint32>(1);
case U64:
- return std::move(*LiteralUtil::CreateR0<uint64>(1));
+ return LiteralUtil::CreateR0<uint64>(1);
case S8:
- return std::move(*LiteralUtil::CreateR0<int8>(1));
+ return LiteralUtil::CreateR0<int8>(1);
case S32:
- return std::move(*LiteralUtil::CreateR0<int32>(1));
+ return LiteralUtil::CreateR0<int32>(1);
case S64:
- return std::move(*LiteralUtil::CreateR0<int64>(1));
+ return LiteralUtil::CreateR0<int64>(1);
case F16:
- return std::move(*LiteralUtil::CreateR0<half>(static_cast<half>(1.0f)));
+ return LiteralUtil::CreateR0<half>(static_cast<half>(1.0f));
case BF16:
- return std::move(
- *LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(1.0f)));
+ return LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(1.0f));
case F32:
- return std::move(*LiteralUtil::CreateR0<float>(1));
+ return LiteralUtil::CreateR0<float>(1);
case F64:
- return std::move(*LiteralUtil::CreateR0<double>(1));
+ return LiteralUtil::CreateR0<double>(1);
case C64:
- return std::move(*LiteralUtil::CreateR0<complex64>(1));
+ return LiteralUtil::CreateR0<complex64>(1);
case PRED:
- return std::move(*LiteralUtil::CreateR0<bool>(true));
+ return LiteralUtil::CreateR0<bool>(true);
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
@@ -184,42 +182,36 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
/* static */ Literal LiteralUtil::MinValue(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
- return std::move(
- *LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::min()));
+ return LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::min());
case U32:
- return std::move(
- *LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::min()));
+ return LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::min());
case U64:
- return std::move(
- *LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::min()));
+ return LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::min());
case S8:
- return std::move(
- *LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::min()));
+ return LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::min());
case S32:
- return std::move(
- *LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::min()));
+ return LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::min());
case S64:
- return std::move(
- *LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::min()));
+ return LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::min());
case F32:
- return std::move(*LiteralUtil::CreateR0<float>(
- -std::numeric_limits<float>::infinity()));
+ return LiteralUtil::CreateR0<float>(
+ -std::numeric_limits<float>::infinity());
case F64:
- return std::move(*LiteralUtil::CreateR0<double>(
- -std::numeric_limits<double>::infinity()));
+ return LiteralUtil::CreateR0<double>(
+ -std::numeric_limits<double>::infinity());
case C64:
LOG(FATAL) << "C64 element type has no minimum value";
case PRED:
- return std::move(*LiteralUtil::CreateR0<bool>(false));
+ return LiteralUtil::CreateR0<bool>(false);
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
case F16:
- return std::move(*LiteralUtil::CreateR0<half>(
- static_cast<half>(-std::numeric_limits<float>::infinity())));
+ return LiteralUtil::CreateR0<half>(
+ static_cast<half>(-std::numeric_limits<float>::infinity()));
case BF16:
- return std::move(*LiteralUtil::CreateR0<bfloat16>(
- static_cast<bfloat16>(-std::numeric_limits<float>::infinity())));
+ return LiteralUtil::CreateR0<bfloat16>(
+ static_cast<bfloat16>(-std::numeric_limits<float>::infinity()));
case TUPLE:
LOG(FATAL) << "tuple element type has no minimum value";
case OPAQUE:
@@ -232,40 +224,34 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
/* static */ Literal LiteralUtil::MaxValue(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
- return std::move(
- *LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::max()));
+ return LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::max());
case U32:
- return std::move(
- *LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::max()));
+ return LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::max());
case U64:
- return std::move(
- *LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::max()));
+ return LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::max());
case S8:
- return std::move(
- *LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::max()));
+ return LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::max());
case S32:
- return std::move(
- *LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::max()));
+ return LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::max());
case S64:
- return std::move(
- *LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::max()));
+ return LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::max());
case F32:
- return std::move(*LiteralUtil::CreateR0<float>(
- std::numeric_limits<float>::infinity()));
+ return LiteralUtil::CreateR0<float>(
+ std::numeric_limits<float>::infinity());
case F64:
- return std::move(*LiteralUtil::CreateR0<double>(
- std::numeric_limits<double>::infinity()));
+ return LiteralUtil::CreateR0<double>(
+ std::numeric_limits<double>::infinity());
case PRED:
- return std::move(*LiteralUtil::CreateR0<bool>(true));
+ return LiteralUtil::CreateR0<bool>(true);
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
case F16:
- return std::move(*LiteralUtil::CreateR0<half>(
- static_cast<half>(std::numeric_limits<float>::infinity())));
+ return LiteralUtil::CreateR0<half>(
+ static_cast<half>(std::numeric_limits<float>::infinity()));
case BF16:
- return std::move(*LiteralUtil::CreateR0<bfloat16>(
- static_cast<bfloat16>(std::numeric_limits<float>::infinity())));
+ return LiteralUtil::CreateR0<bfloat16>(
+ static_cast<bfloat16>(std::numeric_limits<float>::infinity()));
case TUPLE:
LOG(FATAL) << "tuple element type has no maximum value";
case OPAQUE:
@@ -275,31 +261,29 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
}
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1(
+/* static */ Literal LiteralUtil::CreateR1(
const tensorflow::core::Bitmap& values) {
- auto literal = absl::make_unique<Literal>(
+ Literal literal(
ShapeUtil::MakeShape(PRED, {static_cast<int64>(values.bits())}));
- literal->PopulateR1(values);
+ literal.PopulateR1(values);
return literal;
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1U8(
- absl::string_view value) {
- auto literal = absl::make_unique<Literal>(
- ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())}));
+/* static */ Literal LiteralUtil::CreateR1U8(absl::string_view value) {
+ Literal literal(ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())}));
for (int i = 0; i < value.size(); ++i) {
- literal->Set<uint8>({i}, value[i]);
+ literal.Set<uint8>({i}, value[i]);
}
return literal;
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2F32Linspace(
- float from, float to, int64 rows, int64 cols) {
+/* static */ Literal LiteralUtil::CreateR2F32Linspace(float from, float to,
+ int64 rows, int64 cols) {
auto value = MakeLinspaceArray2D(from, to, rows, cols);
return CreateR2FromArray2D(*value);
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::ReshapeSlice(
+/* static */ Literal LiteralUtil::ReshapeSlice(
absl::Span<const int64> new_dimensions,
absl::Span<const int64> minor_to_major, const LiteralSlice& literal) {
int64 new_num_elements = 1;
@@ -309,13 +293,13 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements);
CHECK_EQ(new_dimensions.size(), minor_to_major.size());
- auto new_literal = absl::make_unique<Literal>(
+ Literal new_literal(
ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions));
// Create a new shape with the given minor-to-major layout. This shape is used
// solely for converting linear address to multi-dimensional addresses when
// writing elements to the new literal.
- Shape shape_with_layout = new_literal->shape();
+ Shape shape_with_layout = new_literal.shape();
*shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
// Copy data into new literal, element-by-element.
@@ -326,40 +310,40 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i);
switch (literal.shape().element_type()) {
case PRED:
- new_literal->Set<bool>(to_multi_index,
- literal.Get<bool>(from_multi_index));
+ new_literal.Set<bool>(to_multi_index,
+ literal.Get<bool>(from_multi_index));
break;
case U8:
- new_literal->Set<uint8>(to_multi_index,
- literal.Get<uint8>(from_multi_index));
+ new_literal.Set<uint8>(to_multi_index,
+ literal.Get<uint8>(from_multi_index));
break;
case U32:
- new_literal->Set<uint32>(to_multi_index,
- literal.Get<uint32>(from_multi_index));
+ new_literal.Set<uint32>(to_multi_index,
+ literal.Get<uint32>(from_multi_index));
break;
case S32:
- new_literal->Set<int32>(to_multi_index,
- literal.Get<int32>(from_multi_index));
+ new_literal.Set<int32>(to_multi_index,
+ literal.Get<int32>(from_multi_index));
break;
case U64:
- new_literal->Set<uint64>(to_multi_index,
- literal.Get<uint64>(from_multi_index));
+ new_literal.Set<uint64>(to_multi_index,
+ literal.Get<uint64>(from_multi_index));
break;
case S64:
- new_literal->Set<int64>(to_multi_index,
- literal.Get<int64>(from_multi_index));
+ new_literal.Set<int64>(to_multi_index,
+ literal.Get<int64>(from_multi_index));
break;
case F32:
- new_literal->Set<float>(to_multi_index,
- literal.Get<float>(from_multi_index));
+ new_literal.Set<float>(to_multi_index,
+ literal.Get<float>(from_multi_index));
break;
case F64:
- new_literal->Set<double>(to_multi_index,
- literal.Get<double>(from_multi_index));
+ new_literal.Set<double>(to_multi_index,
+ literal.Get<double>(from_multi_index));
break;
case C64:
- new_literal->Set<complex64>(to_multi_index,
- literal.Get<complex64>(from_multi_index));
+ new_literal.Set<complex64>(to_multi_index,
+ literal.Get<complex64>(from_multi_index));
break;
default:
LOG(FATAL) << "Unhandled primitive element type: "
@@ -376,97 +360,82 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
CHECK_GT(ShapeUtil::ElementsIn(literal.shape()), 0);
switch (literal.shape().element_type()) {
case PRED:
- return std::move(
- *LiteralUtil::CreateR0<bool>(literal.GetFirstElement<bool>()));
+ return LiteralUtil::CreateR0<bool>(literal.GetFirstElement<bool>());
// 8 bit types.
case S8:
- return std::move(
- *LiteralUtil::CreateR0<int8>(literal.GetFirstElement<int8>()));
+ return LiteralUtil::CreateR0<int8>(literal.GetFirstElement<int8>());
case U8:
- return std::move(
- *LiteralUtil::CreateR0<uint8>(literal.GetFirstElement<uint8>()));
+ return LiteralUtil::CreateR0<uint8>(literal.GetFirstElement<uint8>());
// 16 bit types.
case BF16:
- return std::move(*LiteralUtil::CreateR0<bfloat16>(
- literal.GetFirstElement<bfloat16>()));
+ return LiteralUtil::CreateR0<bfloat16>(
+ literal.GetFirstElement<bfloat16>());
case F16:
- return std::move(
- *LiteralUtil::CreateR0<half>(literal.GetFirstElement<half>()));
+ return LiteralUtil::CreateR0<half>(literal.GetFirstElement<half>());
case S16:
- return std::move(
- *LiteralUtil::CreateR0<int16>(literal.GetFirstElement<int16>()));
+ return LiteralUtil::CreateR0<int16>(literal.GetFirstElement<int16>());
case U16:
- return std::move(
- *LiteralUtil::CreateR0<uint16>(literal.GetFirstElement<uint16>()));
+ return LiteralUtil::CreateR0<uint16>(literal.GetFirstElement<uint16>());
// 32 bit types.
case F32:
- return std::move(
- *LiteralUtil::CreateR0<float>(literal.GetFirstElement<float>()));
+ return LiteralUtil::CreateR0<float>(literal.GetFirstElement<float>());
case S32:
- return std::move(
- *LiteralUtil::CreateR0<int32>(literal.GetFirstElement<int32>()));
+ return LiteralUtil::CreateR0<int32>(literal.GetFirstElement<int32>());
case U32:
- return std::move(
- *LiteralUtil::CreateR0<uint32>(literal.GetFirstElement<uint32>()));
+ return LiteralUtil::CreateR0<uint32>(literal.GetFirstElement<uint32>());
// 64 bit types.
case C64:
- return std::move(*LiteralUtil::CreateR0<complex64>(
- literal.GetFirstElement<complex64>()));
+ return LiteralUtil::CreateR0<complex64>(
+ literal.GetFirstElement<complex64>());
case F64:
- return std::move(
- *LiteralUtil::CreateR0<double>(literal.GetFirstElement<double>()));
+ return LiteralUtil::CreateR0<double>(literal.GetFirstElement<double>());
case S64:
- return std::move(
- *LiteralUtil::CreateR0<int64>(literal.GetFirstElement<int64>()));
+ return LiteralUtil::CreateR0<int64>(literal.GetFirstElement<int64>());
case U64:
- return std::move(
- *LiteralUtil::CreateR0<uint64>(literal.GetFirstElement<uint64>()));
+ return LiteralUtil::CreateR0<uint64>(literal.GetFirstElement<uint64>());
default:
LOG(FATAL) << "Unhandled primitive type "
<< literal.shape().element_type();
}
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::MakeTuple(
+/* static */ Literal LiteralUtil::MakeTuple(
absl::Span<const Literal* const> elements) {
std::vector<Shape> element_shapes;
for (const auto* element : elements) {
element_shapes.push_back(element->shape());
}
- auto literal =
- absl::make_unique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
+ Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
for (int i = 0; i < elements.size(); ++i) {
- TF_CHECK_OK(literal->CopyFrom(*elements[i], /*dest_shape_index=*/{i}));
+ TF_CHECK_OK(literal.CopyFrom(*elements[i], /*dest_shape_index=*/{i}));
}
return literal;
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::MakeTupleFromSlices(
+/* static */ Literal LiteralUtil::MakeTupleFromSlices(
absl::Span<const LiteralSlice> elements) {
std::vector<Shape> element_shapes;
for (const auto& element : elements) {
element_shapes.push_back(element.shape());
}
- auto literal =
- absl::make_unique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
+ Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
for (int i = 0; i < elements.size(); ++i) {
- TF_CHECK_OK(literal->CopyFrom(elements[i], /*dest_shape_index=*/{i}));
+ TF_CHECK_OK(literal.CopyFrom(elements[i], /*dest_shape_index=*/{i}));
}
return literal;
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::MakeTupleOwned(
- std::vector<std::unique_ptr<Literal>> elements) {
+/* static */ Literal LiteralUtil::MakeTupleOwned(
+ std::vector<Literal> elements) {
std::vector<Shape> element_shapes;
element_shapes.reserve(elements.size());
for (const auto& element : elements) {
- element_shapes.push_back(element->shape());
+ element_shapes.push_back(element.shape());
}
- auto literal =
- absl::make_unique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
+ Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
for (int64 i = 0; i < elements.size(); ++i) {
TF_CHECK_OK(
- literal->MoveFrom(std::move(*elements[i]), /*dest_shape_index=*/{i}));
+ literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i}));
}
return literal;
}
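
Note: for the owned-tuple path, MakeTupleOwned now takes std::vector<Literal> by value and moves each element into the result. A short usage sketch under that signature (the BuildOwnedTuple name is illustrative):

#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/literal_util.h"

namespace xla {

Literal BuildOwnedTuple() {
  std::vector<Literal> elements;
  elements.push_back(LiteralUtil::CreateR0<int32>(42));
  elements.push_back(LiteralUtil::CreateR1<double>({23.0, 44.0}));
  // The vector's literals are consumed; each one is moved into the tuple.
  return LiteralUtil::MakeTupleOwned(std::move(elements));
}

}  // namespace xla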
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index 2d6084a67a..2b181621ed 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -69,36 +69,34 @@ class LiteralUtil {
// The variants not ending with WithLayout use the default XLA layout for the
// literal's linear representation in memory.
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR0(NativeT value);
+ static Literal CreateR0(NativeT value);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR1(absl::Span<const NativeT> values);
- static std::unique_ptr<Literal> CreateR1(
- const tensorflow::core::Bitmap& values);
+ static Literal CreateR1(absl::Span<const NativeT> values);
+ static Literal CreateR1(const tensorflow::core::Bitmap& values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR2(
+ static Literal CreateR2(
std::initializer_list<std::initializer_list<NativeT>> values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR2WithLayout(
+ static Literal CreateR2WithLayout(
std::initializer_list<std::initializer_list<NativeT>> values,
const Layout& layout);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR3(
- std::initializer_list<
- std::initializer_list<std::initializer_list<NativeT>>>
- values);
+ static Literal CreateR3(std::initializer_list<
+ std::initializer_list<std::initializer_list<NativeT>>>
+ values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR3WithLayout(
+ static Literal CreateR3WithLayout(
std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>
values,
const Layout& layout);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR4(
+ static Literal CreateR4(
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR4WithLayout(
+ static Literal CreateR4WithLayout(
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
values,
@@ -139,9 +137,10 @@ class LiteralUtil {
// [9, 10, 11]: 4.0
//
template <typename NativeT>
- static std::unique_ptr<Literal> CreateSparse(
- absl::Span<const int64> dimensions, SparseIndexArray indices,
- absl::Span<const NativeT> values, bool sort = true);
+ static Literal CreateSparse(absl::Span<const int64> dimensions,
+ SparseIndexArray indices,
+ absl::Span<const NativeT> values,
+ bool sort = true);
// Creates a scalar literal value zero of the given primitive type.
static Literal Zero(PrimitiveType primitive_type);
@@ -155,130 +154,120 @@ class LiteralUtil {
static Literal MaxValue(PrimitiveType primitive_type);
// Creates a literal of the given shape where each element is `value`.
template <typename NativeT>
- static std::unique_ptr<Literal> CreateFullWithDescendingLayout(
+ static Literal CreateFullWithDescendingLayout(
absl::Span<const int64> dimensions, NativeT value);
// Creates a new literal from an Array type. The variants not ending with
// WithLayout use the default XLA layout for the literal's linear
// representation in memory.
template <typename NativeT>
- static std::unique_ptr<Literal> CreateFromArray(const Array<NativeT>& values);
+ static Literal CreateFromArray(const Array<NativeT>& values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateFromArrayWithLayout(
- const Array<NativeT>& values, const Layout& layout);
+ static Literal CreateFromArrayWithLayout(const Array<NativeT>& values,
+ const Layout& layout);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR2FromArray2D(
- const Array2D<NativeT>& values);
+ static Literal CreateR2FromArray2D(const Array2D<NativeT>& values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR2FromArray2DWithLayout(
- const Array2D<NativeT>& values, const Layout& layout);
+ static Literal CreateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
+ const Layout& layout);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR3FromArray3D(
- const Array3D<NativeT>& values);
+ static Literal CreateR3FromArray3D(const Array3D<NativeT>& values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR3FromArray3DWithLayout(
- const Array3D<NativeT>& values, const Layout& layout);
+ static Literal CreateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
+ const Layout& layout);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR4FromArray4D(
- const Array4D<NativeT>& values);
+ static Literal CreateR4FromArray4D(const Array4D<NativeT>& values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR4FromArray4DWithLayout(
- const Array4D<NativeT>& values, const Layout& layout);
+ static Literal CreateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
+ const Layout& layout);
// Creates a new rank-1 literal of U8 values from a string.
- static std::unique_ptr<Literal> CreateR1U8(absl::string_view value);
+ static Literal CreateR1U8(absl::string_view value);
// Creates a linspace-populated literal with the given number of rows and
// columns.
- static std::unique_ptr<Literal> CreateR2F32Linspace(float from, float to,
- int64 rows, int64 cols);
+ static Literal CreateR2F32Linspace(float from, float to, int64 rows,
+ int64 cols);
// Creates a literal that projects the (x, y) dimensions given in values into
// the z dimension given by "projection".
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR3Projected(
+ static Literal CreateR3Projected(
std::initializer_list<std::initializer_list<NativeT>> values,
int64 projection);
// Creates a literal that projects the (x, y) dimensions given in values into
// the z and p dimensions given.
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR4Projected(
+ static Literal CreateR4Projected(
std::initializer_list<std::initializer_list<NativeT>> values,
int64 projection_p, int64 projection_z);
// Returns an identity matrix (rank 2) with the given row and column count.
template <typename NativeT>
- static std::unique_ptr<Literal> MakeIdentityR2(int64 size);
+ static Literal MakeIdentityR2(int64 size);
// Returns a tuple literal composed of given literals. Data is copied from the
// given elements into the returned literal.
- static std::unique_ptr<Literal> MakeTuple(
- absl::Span<const Literal* const> elements);
+ static Literal MakeTuple(absl::Span<const Literal* const> elements);
- static std::unique_ptr<Literal> MakeTupleFromSlices(
- absl::Span<const LiteralSlice> elements);
+ static Literal MakeTupleFromSlices(absl::Span<const LiteralSlice> elements);
// As above, but intended to be invoked with move semantics; i.e.
//
- // std::vector<std::unique_ptr<Literal>> elements = ...;
+ // std::vector<Literal> elements = ...;
// auto result = LiteralUtil::MakeTupleOwned(std::move(elements));
//
// This would have been declared as an overload, but there is ambiguity
// in invocation between the above signature and this one.
- static std::unique_ptr<Literal> MakeTupleOwned(
- std::vector<std::unique_ptr<Literal>> elements);
+ static Literal MakeTupleOwned(std::vector<Literal> elements);
- // This overload lets you pass a braced list of unique_ptr<Literal>s to
+ // This overload lets you pass a braced list of Literals to
// MakeTupleOwned:
//
// LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1(...), ...).
//
- // Simply relying on the MakeTupleOwned(std::vector<unique_ptr<Literal>>)
+ // Simply relying on the MakeTupleOwned(std::vector<Literal>)
// overload doesn't work because std::initializer_list's elements are always
// const.
//
- // The arguments to this function must all be unique_ptr<Literal>.
+ // The arguments to this function must all be Literal.
template <typename... Ts>
- static std::unique_ptr<Literal> MakeTupleOwned(
- std::unique_ptr<Ts>... elements) {
- std::array<std::unique_ptr<Literal>, sizeof...(Ts)> arr{
- std::move(elements)...};
- std::vector<std::unique_ptr<Literal>> v;
+ static Literal MakeTupleOwned(Ts... elements) {
+ std::array<Literal, sizeof...(Ts)> arr{std::move(elements)...};
+ std::vector<Literal> v;
v.insert(v.begin(), std::make_move_iterator(arr.begin()),
std::make_move_iterator(arr.end()));
return MakeTupleOwned(std::move(v));
}
// Create a constant token literal. Token types have no value.
- static std::unique_ptr<Literal> CreateToken();
+ static Literal CreateToken();
// Creates a new Literal object with its values having the primitive_type
// type, and with dimensions defined by the dimensions parameter.
// The content of the literal values is the default value of the primitive
// type of literal itself (0 for numeric types, and false for predicates).
- static std::unique_ptr<Literal> CreateFromDimensions(
- PrimitiveType primitive_type, absl::Span<const int64> dimensions);
+ static Literal CreateFromDimensions(PrimitiveType primitive_type,
+ absl::Span<const int64> dimensions);
// If the given literal's data type is bfloat16, converts it to a float
// literal; otherwise, returns a copy of it. If the literal is a tuple,
// recursively converts its elements.
- static std::unique_ptr<Literal> ConvertBF16ToF32(
- const LiteralSlice& bf16_literal);
+ static Literal ConvertBF16ToF32(const LiteralSlice& bf16_literal);
// If the given literal's data type is float, converts it to a bfloat16
// literal; otherwise, returns a copy of it. If the literal is a tuple,
// recursively converts its elements.
- static std::unique_ptr<Literal> ConvertF32ToBF16(
- const LiteralSlice& f32_literal);
+ static Literal ConvertF32ToBF16(const LiteralSlice& f32_literal);
// Creates a literal with a new shape with the given new dimensions using the
// data in the given input literal. For reshaping purposes the (flat) data
// buffer of the input literal is assumed to have the given minor_to_major
// layout order.
- static std::unique_ptr<Literal> ReshapeSlice(
- absl::Span<const int64> new_dimensions,
- absl::Span<const int64> minor_to_major, const LiteralSlice& literal);
+ static Literal ReshapeSlice(absl::Span<const int64> new_dimensions,
+ absl::Span<const int64> minor_to_major,
+ const LiteralSlice& literal);
// Creates a literal with the supplied shape, and uses the provided value
// generator to populate the literal's values.
@@ -286,7 +275,7 @@ class LiteralUtil {
template <
PrimitiveType type,
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
- static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
+ static StatusOr<Literal> CreateRandomLiteral(
const Shape& shape,
const std::function<T(absl::Span<const int64>)>& generator);
@@ -297,8 +286,8 @@ class LiteralUtil {
template <
PrimitiveType type, typename E,
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
- static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
- const Shape& shape, E* engine, T mean, T stddev);
+ static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, E* engine,
+ T mean, T stddev);
// Creates a literal with the supplied shape, and initializes the literal
// values using a normal distribution with given mean and stddev standard
@@ -307,8 +296,8 @@ class LiteralUtil {
template <
PrimitiveType type,
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
- static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
- const Shape& shape, T mean, T stddev);
+ static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, T mean,
+ T stddev);
//
// End of factory methods.
@@ -322,44 +311,43 @@ class LiteralUtil {
std::ostream& operator<<(std::ostream& out, const Literal& literal);
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR0(NativeT value) {
- auto literal = absl::make_unique<Literal>(ShapeUtil::MakeShape(
+/* static */ Literal LiteralUtil::CreateR0(NativeT value) {
+ Literal literal(ShapeUtil::MakeShape(
primitive_util::NativeToPrimitiveType<NativeT>(), {}));
- literal->Set({}, value);
+ literal.Set({}, value);
return literal;
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1(
- absl::Span<const NativeT> values) {
- auto literal = absl::make_unique<Literal>(
+/* static */ Literal LiteralUtil::CreateR1(absl::Span<const NativeT> values) {
+ Literal literal(
ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
{static_cast<int64>(values.size())}));
- literal->PopulateR1(values);
+ literal.PopulateR1(values);
return literal;
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2WithLayout(
+/* static */ Literal LiteralUtil::CreateR2WithLayout(
std::initializer_list<std::initializer_list<NativeT>> values,
const Layout& layout) {
- auto literal = absl::make_unique<Literal>(ShapeUtil::MakeShapeWithLayout(
+ Literal literal(ShapeUtil::MakeShapeWithLayout(
primitive_util::NativeToPrimitiveType<NativeT>(),
{static_cast<int64>(values.size()),
static_cast<int64>(values.begin()->size())},
AsInt64Slice(layout.minor_to_major())));
- literal->PopulateR2(values);
+ literal.PopulateR2(values);
return literal;
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2(
+/* static */ Literal LiteralUtil::CreateR2(
std::initializer_list<std::initializer_list<NativeT>> values) {
return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3WithLayout(
+/* static */ Literal LiteralUtil::CreateR3WithLayout(
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
values,
const Layout& layout) {
@@ -384,14 +372,14 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3(
+/* static */ Literal LiteralUtil::CreateR3(
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
values) {
return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3());
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4WithLayout(
+/* static */ Literal LiteralUtil::CreateR4WithLayout(
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
values,
@@ -422,23 +410,22 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateSparse(
+/* static */ Literal LiteralUtil::CreateSparse(
absl::Span<const int64> dimensions, SparseIndexArray indices,
absl::Span<const NativeT> values, bool sort) {
int64 num_elements = values.size();
int64 rank = dimensions.size();
CHECK_EQ(num_elements, indices.index_count());
CHECK_EQ(rank, indices.rank());
- auto literal =
- absl::make_unique<Literal>(ShapeUtil::MakeShapeWithSparseLayout(
- primitive_util::NativeToPrimitiveType<NativeT>(), dimensions,
- indices.max_indices()));
- literal->PopulateSparse(indices, values, sort);
+ Literal literal(ShapeUtil::MakeShapeWithSparseLayout(
+ primitive_util::NativeToPrimitiveType<NativeT>(), dimensions,
+ indices.max_indices()));
+ literal.PopulateSparse(indices, values, sort);
return literal;
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4(
+/* static */ Literal LiteralUtil::CreateR4(
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
values) {
@@ -446,50 +433,48 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromArrayWithLayout(
+/* static */ Literal LiteralUtil::CreateFromArrayWithLayout(
const Array<NativeT>& values, const Layout& layout) {
- auto literal = absl::make_unique<Literal>(ShapeUtil::MakeShapeWithLayout(
+ Literal literal(ShapeUtil::MakeShapeWithLayout(
primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(),
AsInt64Slice(layout.minor_to_major())));
- literal->PopulateFromArray(values);
+ literal.PopulateFromArray(values);
return literal;
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromArray(
+/* static */ Literal LiteralUtil::CreateFromArray(
const Array<NativeT>& values) {
return CreateFromArrayWithLayout(
values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal>
-LiteralUtil::CreateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
- const Layout& layout) {
+/* static */ Literal LiteralUtil::CreateR2FromArray2DWithLayout(
+ const Array2D<NativeT>& values, const Layout& layout) {
return CreateFromArrayWithLayout(values, layout);
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2FromArray2D(
+/* static */ Literal LiteralUtil::CreateR2FromArray2D(
const Array2D<NativeT>& values) {
return CreateFromArray(values);
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal>
-LiteralUtil::CreateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
- const Layout& layout) {
+/* static */ Literal LiteralUtil::CreateR3FromArray3DWithLayout(
+ const Array3D<NativeT>& values, const Layout& layout) {
return CreateFromArrayWithLayout(values, layout);
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3FromArray3D(
+/* static */ Literal LiteralUtil::CreateR3FromArray3D(
const Array3D<NativeT>& values) {
return CreateFromArray(values);
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3Projected(
+/* static */ Literal LiteralUtil::CreateR3Projected(
std::initializer_list<std::initializer_list<NativeT>> values,
int64 projection) {
int64 dim0_size = projection;
@@ -514,7 +499,7 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4Projected(
+/* static */ Literal LiteralUtil::CreateR4Projected(
std::initializer_list<std::initializer_list<NativeT>> values,
int64 projection_p, int64 projection_z) {
int64 dim0_size = projection_p;
@@ -542,21 +527,20 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4FromArray4D(
+/* static */ Literal LiteralUtil::CreateR4FromArray4D(
const Array4D<NativeT>& values) {
return CreateFromArray(values);
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal>
-LiteralUtil::CreateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
- const Layout& layout) {
+/* static */ Literal LiteralUtil::CreateR4FromArray4DWithLayout(
+ const Array4D<NativeT>& values, const Layout& layout) {
return CreateFromArrayWithLayout(values, layout);
}
// Returns an identity matrix (rank 2) with the given row and column count.
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::MakeIdentityR2(int64 size) {
+/* static */ Literal LiteralUtil::MakeIdentityR2(int64 size) {
Array2D<NativeT> array(size, size, 0);
for (int64 i = 0; i < size; ++i) {
array(i, i) = 1;
@@ -565,33 +549,29 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal>
-LiteralUtil::CreateFullWithDescendingLayout(absl::Span<const int64> dimensions,
- NativeT value) {
- auto literal =
- absl::make_unique<Literal>(ShapeUtil::MakeShapeWithDescendingLayout(
- primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
- literal->PopulateWithValue(value);
+/* static */ Literal LiteralUtil::CreateFullWithDescendingLayout(
+ absl::Span<const int64> dimensions, NativeT value) {
+ Literal literal(ShapeUtil::MakeShapeWithDescendingLayout(
+ primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
+ literal.PopulateWithValue(value);
return literal;
}
template <PrimitiveType type, typename T>
-/* static */ StatusOr<std::unique_ptr<Literal>>
-LiteralUtil::CreateRandomLiteral(
+/* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
const Shape& shape,
const std::function<T(absl::Span<const int64>)>& generator) {
using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
TF_RET_CHECK(shape.element_type() == type);
- auto literal = absl::make_unique<Literal>(shape);
- TF_RETURN_IF_ERROR(literal.get()->Populate<NativeT>(
+ Literal literal(shape);
+ TF_RETURN_IF_ERROR(literal.Populate<NativeT>(
[&](absl::Span<const int64> indexes) { return generator(indexes); }));
return std::move(literal);
}
template <PrimitiveType type, typename E, typename T>
-/* static */ StatusOr<std::unique_ptr<Literal>>
-LiteralUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean,
- T stddev) {
+/* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
+ const Shape& shape, E* engine, T mean, T stddev) {
using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
std::normal_distribution<NativeT> generator(mean, stddev);
return CreateRandomLiteral<type, NativeT>(
@@ -600,8 +580,8 @@ LiteralUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean,
}
template <PrimitiveType type, typename T>
-/* static */ StatusOr<std::unique_ptr<Literal>>
-LiteralUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) {
+/* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
+ const Shape& shape, T mean, T stddev) {
std::minstd_rand0 engine;
return CreateRandomLiteral<type>(shape, &engine, mean, stddev);
}
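
Note: CreateRandomLiteral now yields StatusOr<Literal> rather than StatusOr<std::unique_ptr<Literal>>. A hedged caller sketch, reusing the ValueOrDie pattern from the proto round-trip test above (it check-fails if generation fails; the shape and name are illustrative):

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

Literal MakeRandomF32Matrix() {
  // The StatusOr wraps the Literal itself, not a pointer to it.
  return LiteralUtil::CreateRandomLiteral<F32>(
             ShapeUtil::MakeShape(F32, {4, 4}), /*mean=*/0.0f, /*stddev=*/1.0f)
      .ValueOrDie();
}

}  // namespace xla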
diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc
index bddb664149..0f86f9f35e 100644
--- a/tensorflow/compiler/xla/packed_literal_reader.cc
+++ b/tensorflow/compiler/xla/packed_literal_reader.cc
@@ -28,7 +28,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
@@ -40,8 +39,8 @@ PackedLiteralReader::PackedLiteralReader(tensorflow::RandomAccessFile* file)
PackedLiteralReader::~PackedLiteralReader() { delete file_; }
-StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read(
- const Shape& shape, const Layout* layout) {
+StatusOr<Literal> PackedLiteralReader::Read(const Shape& shape,
+ const Layout* layout) {
VLOG(3) << "reading shape from file: " << ShapeUtil::HumanString(shape)
<< " layout: "
<< (layout == nullptr ? "<none>" : layout->ShortDebugString());
@@ -58,14 +57,14 @@ StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read(
PrimitiveType_Name(shape.element_type()));
}
- auto result = absl::make_unique<Literal>(literal_shape);
- result->PopulateWithValue(std::numeric_limits<float>::quiet_NaN());
+ Literal result(literal_shape);
+ result.PopulateWithValue(std::numeric_limits<float>::quiet_NaN());
int64 elements = ShapeUtil::ElementsIn(shape);
- absl::Span<const float> field = result->data<float>();
+ absl::Span<const float> field = result.data<float>();
char* data = absl::bit_cast<char*>(field.data());
uint64 bytes = elements * sizeof(float);
- tensorflow::StringPiece sp;
+ absl::string_view sp;
auto s = file_->Read(offset_, bytes, &sp, data);
offset_ += sp.size();
if (!s.ok()) {
@@ -86,7 +85,7 @@ bool PackedLiteralReader::IsExhausted() const {
// Try to read a single byte from offset_. If we can't, we've
// exhausted the data.
char single_byte[1];
- tensorflow::StringPiece sp;
+ absl::string_view sp;
auto s = file_->Read(offset_, sizeof(single_byte), &sp, single_byte);
return !s.ok();
}
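
Note: PackedLiteralReader::Read makes the same switch, returning the Literal inside the StatusOr directly. A sketch of reading one packed F32 array, assuming the caller has already opened a tensorflow::RandomAccessFile (the shape and element count are illustrative):

#include "tensorflow/compiler/xla/packed_literal_reader.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

StatusOr<Literal> ReadPackedF32(tensorflow::RandomAccessFile* file) {
  // The reader takes ownership of `file` and deletes it in its destructor.
  PackedLiteralReader reader(file);
  return reader.Read(ShapeUtil::MakeShape(F32, {1024}));
}

}  // namespace xla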
diff --git a/tensorflow/compiler/xla/packed_literal_reader.h b/tensorflow/compiler/xla/packed_literal_reader.h
index 98dccaa9a2..d6d2ff1521 100644
--- a/tensorflow/compiler/xla/packed_literal_reader.h
+++ b/tensorflow/compiler/xla/packed_literal_reader.h
@@ -41,8 +41,7 @@ class PackedLiteralReader {
//
// Layout is optional. If it is not provided, no layout is set on the literal
// that is produced.
- StatusOr<std::unique_ptr<Literal>> Read(const Shape& shape,
- const Layout* layout = nullptr);
+ StatusOr<Literal> Read(const Shape& shape, const Layout* layout = nullptr);
// Returns whether the input file has been fully exhausted; i.e. all available
// packed literals have been read and we're at the end of the file.
diff --git a/tensorflow/compiler/xla/protobuf_util.cc b/tensorflow/compiler/xla/protobuf_util.cc
index 787725e884..b507a2ef79 100644
--- a/tensorflow/compiler/xla/protobuf_util.cc
+++ b/tensorflow/compiler/xla/protobuf_util.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
namespace xla {
@@ -49,16 +50,40 @@ string SanitizeFilename(const string& file_name) {
return safe_file_name;
}
+std::pair<tensorflow::mutex*, std::vector<std::function<string(string)>>*>
+GetDirectoryExpanders() {
+ static auto* mutex = new tensorflow::mutex;
+ static auto* singleton = new std::vector<std::function<string(string)>>;
+ return {mutex, singleton};
+}
+
+// Runs all the directory expanders over x and returns the result.
+string Expand(string x) {
+ auto pair = GetDirectoryExpanders();
+ tensorflow::mutex_lock lock(*pair.first);
+ for (const auto& f : *pair.second) {
+ x = f(x);
+ }
+ return x;
+}
+
} // namespace
Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message,
const string& directory, const string& file_name) {
tensorflow::Env* env = tensorflow::Env::Default();
- TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory));
+ string expanded_dir = Expand(directory);
+ TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(expanded_dir));
string safe_file_name = SanitizeFileName(file_name) + ".pb";
- const string path = tensorflow::io::JoinPath(directory, safe_file_name);
+ const string path = tensorflow::io::JoinPath(expanded_dir, safe_file_name);
return tensorflow::WriteBinaryProto(env, path, message);
}
+void RegisterDirectoryExpander(const std::function<string(string)>& expander) {
+ auto pair = GetDirectoryExpanders();
+ tensorflow::mutex_lock lock(*pair.first);
+ pair.second->push_back(expander);
+}
+
} // namespace protobuf_util
} // namespace xla
diff --git a/tensorflow/compiler/xla/protobuf_util.h b/tensorflow/compiler/xla/protobuf_util.h
index 3667621367..f22fc8b849 100644
--- a/tensorflow/compiler/xla/protobuf_util.h
+++ b/tensorflow/compiler/xla/protobuf_util.h
@@ -39,6 +39,10 @@ extern bool ProtobufEquals(const tensorflow::protobuf::Message& m1,
Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message,
const string& directory, const string& file_name);
+// Registers a function that may either expand a dirpath or forward the original
+// dirpath along as-is.
+void RegisterDirectoryExpander(const std::function<string(string)>& expander);
+
} // namespace protobuf_util
} // namespace xla
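
Note: the new expander hook lets callers rewrite a dump directory before DumpProtoToDirectory creates it; every registered function runs, in order, over the requested path. A minimal sketch, assuming a hypothetical remapping (the /tmp/xla_dumps fallback is illustrative):

#include <string>

#include "tensorflow/compiler/xla/protobuf_util.h"

void InstallDumpDirExpander() {
  // Expanders may rewrite the directory or forward it along unchanged.
  xla::protobuf_util::RegisterDirectoryExpander([](std::string dir) {
    return dir.empty() ? std::string("/tmp/xla_dumps") : dir;
  });
}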
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
index cd6e20b693..9da5dc0d2d 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.cc
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -81,8 +81,8 @@ Status TransferToInfeedLocalReplica(const Literal& literal,
return client->TransferToInfeedLocal(literal, device_ordinal);
}
-StatusOr<std::unique_ptr<Literal>> TransferFromOutfeedLocalReplica(
- const Shape& shape, int replica_number) {
+StatusOr<Literal> TransferFromOutfeedLocalReplica(const Shape& shape,
+ int replica_number) {
VLOG(1) << "Outfeeding literal from replica number: " << replica_number
<< " shape: " << shape;
LocalClient* client = GetOrCreateLocalClient();
@@ -141,9 +141,8 @@ StatusOr<LocalShapedBuffer*> LocalShapedBuffer::FromLiteral(
LocalClient* client = GetOrCreateLocalClient();
StatusOr<ScopedShapedBuffer> buf = [&] {
if (shape_with_layout) {
- std::unique_ptr<Literal> relaid =
- argument.Relayout(shape_with_layout.value());
- return ToBuffer(client, /*device_ordinal=*/0, *relaid);
+ Literal relaid = argument.Relayout(shape_with_layout.value());
+ return ToBuffer(client, /*device_ordinal=*/0, relaid);
}
return ToBuffer(client, /*device_ordinal=*/0, argument);
}();
@@ -151,7 +150,7 @@ StatusOr<LocalShapedBuffer*> LocalShapedBuffer::FromLiteral(
return new LocalShapedBuffer(std::move(buf).ValueOrDie());
}
-StatusOr<std::unique_ptr<Literal>> LocalShapedBuffer::ToLiteral() const {
+StatusOr<Literal> LocalShapedBuffer::ToLiteral() const {
LocalClient* client = GetOrCreateLocalClient();
return client->ShapedBufferToLiteral(*shaped_buffer());
}
@@ -160,7 +159,7 @@ CompiledLocalComputation::CompiledLocalComputation(
std::unique_ptr<LocalExecutable> executable)
: executable_(std::move(executable)) {}
-StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
+StatusOr<Literal> CompiledLocalComputation::Execute(
const std::vector<Literal>& arguments,
const std::vector<absl::optional<Shape>>& shapes_with_layout) {
LocalClient* client = GetOrCreateLocalClient();
@@ -169,7 +168,7 @@ StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
// Each replica populates a StatusOr result, but only replica zero actually
// retrieves its literal value.
- std::vector<StatusOr<std::unique_ptr<Literal>>> results(GetReplicaCount());
+ std::vector<StatusOr<Literal>> results(GetReplicaCount());
{
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "xlarun",
GetReplicaCount());
@@ -198,9 +197,8 @@ StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
StatusOr<ScopedShapedBuffer> pushed;
if (shape_with_layout) {
- std::unique_ptr<Literal> relaid =
- argument.Relayout(shape_with_layout.value());
- pushed = ToBuffer(client, device_ordinal, *relaid);
+ Literal relaid = argument.Relayout(shape_with_layout.value());
+ pushed = ToBuffer(client, device_ordinal, relaid);
} else {
pushed = ToBuffer(client, device_ordinal, argument);
}
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h
index 78b3c598b9..1d5dfe5911 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.h
+++ b/tensorflow/compiler/xla/python/local_computation_builder.h
@@ -51,8 +51,8 @@ Status TransferToInfeedLocalReplica(const Literal& literal, int replica_number);
// Transfers a literal of the given shape from the outfeed of the given replica.
//
// The replica number is resolved to an appropriate device ordinal.
-StatusOr<std::unique_ptr<Literal> > TransferFromOutfeedLocalReplica(
- const Shape& shape, int replica_number);
+StatusOr<Literal> TransferFromOutfeedLocalReplica(const Shape& shape,
+ int replica_number);
// Wraps a ScopedShapedBuffer produced by copying a literal "to
// device," i.e. copying a literal to a scoped buffer via the local
@@ -65,7 +65,7 @@ class LocalShapedBuffer {
LocalShapedBuffer(ScopedShapedBuffer shaped_buffer);
const ScopedShapedBuffer* shaped_buffer() const;
- StatusOr<std::unique_ptr<Literal> > ToLiteral() const;
+ StatusOr<Literal> ToLiteral() const;
// Transfers ownership of the encapsulated ShapedBuffer to the caller,
// analogous to std::unique_ptr::release().
@@ -117,7 +117,7 @@ class CompiledLocalComputation {
// with optionally-specified argument layouts. The literals will be
// re-laid out according to the corresponding elements of
// shapes_with_layout.
- StatusOr<std::unique_ptr<Literal> > Execute(
+ StatusOr<Literal> Execute(
const std::vector<Literal>& arguments,
const std::vector<absl::optional<Shape> >& shapes_with_layout);
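A caller-side sketch of the value-semantics API above; the helper name, the empty per-argument layouts, and the xla::swig wrapper namespace are assumptions taken from the surrounding code, not part of this change.
#include <vector>
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/python/local_computation_builder.h"
#include "tensorflow/compiler/xla/statusor.h"
// Executes a compiled computation and moves the result Literal out by value;
// with StatusOr<Literal> there is no unique_ptr indirection at the call site.
xla::StatusOr<xla::Literal> RunAndTakeResult(
    xla::swig::CompiledLocalComputation* computation,
    const std::vector<xla::Literal>& arguments) {
  // One empty optional per argument: keep each argument's existing layout.
  std::vector<absl::optional<xla::Shape>> layouts(arguments.size());
  xla::StatusOr<xla::Literal> result_or =
      computation->Execute(arguments, layouts);
  if (!result_or.ok()) {
    return result_or.status();
  }
  return result_or.ConsumeValueOrDie();
}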
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i
index 450d3fe5af..521490e76c 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.i
+++ b/tensorflow/compiler/xla/python/local_computation_builder.i
@@ -216,9 +216,9 @@ tensorflow::ImportNumpy();
}
-%typemap(out) StatusOr< std::unique_ptr<Literal> > {
+%typemap(out) StatusOr<Literal> {
if ($1.ok()) {
- std::unique_ptr<Literal> value = $1.ConsumeValueOrDie();
- $result = numpy::PyObjectFromXlaLiteral(*value);
+ Literal value = $1.ConsumeValueOrDie();
+ $result = numpy::PyObjectFromXlaLiteral(value);
} else {
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
@@ -346,25 +346,25 @@ tensorflow::ImportNumpy();
// Literal
-%typemap(in) const Literal& (StatusOr< std::unique_ptr<Literal> > literal_status) {
+%typemap(in) const Literal& (StatusOr<Literal> literal_status) {
literal_status = numpy::XlaLiteralFromPyObject($input);
if (!literal_status.ok()) {
PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str());
SWIG_fail;
}
- $1 = literal_status.ValueOrDie().get();
+ $1 = &literal_status.ValueOrDie();
}
-%typemap(out) std::unique_ptr<Literal> {
+%typemap(out) Literal {
$result = numpy::PyObjectFromXlaLiteral(*$1);
}
-%typemap(out) StatusOr< std::unique_ptr<Literal> > {
+%typemap(out) StatusOr<Literal> {
if (!$1.ok()) {
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
SWIG_fail;
}
- $result = numpy::PyObjectFromXlaLiteral(*$1.ValueOrDie());
+ $result = numpy::PyObjectFromXlaLiteral($1.ValueOrDie());
}
%typemap(in) const std::vector<Literal>& (std::vector<Literal> temps) {
@@ -375,13 +375,13 @@ tensorflow::ImportNumpy();
const int size = PySequence_Size($input);
for (int i = 0; i < size; ++i) {
PyObject* o = PySequence_GetItem($input, i);
- StatusOr< std::unique_ptr<Literal> > literal_status = numpy::XlaLiteralFromPyObject(o);
+ StatusOr<Literal> literal_status = numpy::XlaLiteralFromPyObject(o);
if (!literal_status.ok()) {
PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str());
Py_DECREF(o);
SWIG_fail;
}
- temps.push_back(std::move(*literal_status.ConsumeValueOrDie()));
+ temps.push_back(literal_status.ConsumeValueOrDie());
Py_DECREF(o);
}
$1 = &temps;
diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc
index fc6511bef5..b0aa024c74 100644
--- a/tensorflow/compiler/xla/python/numpy_bridge.cc
+++ b/tensorflow/compiler/xla/python/numpy_bridge.cc
@@ -368,10 +368,10 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) {
}
}
-StatusOr<std::unique_ptr<Literal>> XlaLiteralFromPyObject(PyObject* o) {
+StatusOr<Literal> XlaLiteralFromPyObject(PyObject* o) {
if (PyTuple_Check(o)) {
int num_elements = PyTuple_Size(o);
- std::vector<std::unique_ptr<Literal>> elements;
+ std::vector<Literal> elements;
elements.reserve(num_elements);
for (int i = 0; i < num_elements; i++) {
PyObject* element = PyTuple_GetItem(o, i);
@@ -389,8 +389,7 @@ StatusOr<std::unique_ptr<Literal>> XlaLiteralFromPyObject(PyObject* o) {
int np_type = PyArray_TYPE(py_array);
auto literal = LiteralUtil::CreateFromDimensions(
NumpyTypeToPrimitiveType(np_type), dimensions);
- TF_RETURN_IF_ERROR(
- CopyNumpyArrayToLiteral(np_type, py_array, literal.get()));
+ TF_RETURN_IF_ERROR(CopyNumpyArrayToLiteral(np_type, py_array, &literal));
return std::move(literal);
} else {
return InvalidArgument(
diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h
index 8cae175185..40ff2d9ad2 100644
--- a/tensorflow/compiler/xla/python/numpy_bridge.h
+++ b/tensorflow/compiler/xla/python/numpy_bridge.h
@@ -82,7 +82,7 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal);
// To avoid transferring ownership of the data buffers that underlie
// PyArrays and XLA literals, this function makes deep copies of all
// array data.
-StatusOr<std::unique_ptr<Literal> > XlaLiteralFromPyObject(PyObject* o);
+StatusOr<Literal> XlaLiteralFromPyObject(PyObject* o);
// The following functions copy array data from the buffers underlying Numpy
// ndarrays into those underlying XLA literals, and vice versa.
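A hypothetical round trip through the bridge with the new value-returning signature; the function name is illustrative and the xla::swig::numpy namespace is assumed from the surrounding wrapper code.
#include <Python.h>
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/python/numpy_bridge.h"
#include "tensorflow/compiler/xla/statusor.h"
// Converts a Python object to a Literal (a deep copy, as documented above) and
// back; the returned ndarray owns fresh buffers rather than aliasing the
// Literal's storage.
xla::StatusOr<PyObject*> RoundTripThroughLiteral(PyObject* o) {
  xla::StatusOr<xla::Literal> literal_or =
      xla::swig::numpy::XlaLiteralFromPyObject(o);
  if (!literal_or.ok()) {
    return literal_or.status();
  }
  return xla::swig::numpy::PyObjectFromXlaLiteral(literal_or.ValueOrDie());
}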
diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc
index 9f1afa2671..ceb5e74db7 100644
--- a/tensorflow/compiler/xla/reference_util.cc
+++ b/tensorflow/compiler/xla/reference_util.cc
@@ -186,11 +186,10 @@ ReferenceUtil::SeparableConvArray4D(const Array4D<float>& input,
/* static */ std::unique_ptr<std::vector<float>>
ReferenceUtil::ReduceWindow1DGeneric(
- const absl::Span<const float>& operand, float init,
+ absl::Span<const float> operand, float init,
const std::function<float(float, float)>& reduce_func,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride,
- const absl::Span<const std::pair<int64, int64>>& padding) {
+ absl::Span<const int64> window, absl::Span<const int64> stride,
+ absl::Span<const std::pair<int64, int64>> padding) {
std::vector<int64> dim_lengths{static_cast<int64>(operand.size())};
std::vector<int64> window_counts(window.size(), 0);
std::vector<int64> pad_low(window.size(), 0);
@@ -218,10 +217,9 @@ ReferenceUtil::ReduceWindow1DGeneric(
}
/* static */ std::unique_ptr<std::vector<float>>
-ReferenceUtil::ReduceWindow1DAdd(const absl::Span<const float>& operand,
- float init,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride,
+ReferenceUtil::ReduceWindow1DAdd(absl::Span<const float> operand, float init,
+ absl::Span<const int64> window,
+ absl::Span<const int64> stride,
Padding padding) {
const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
std::vector<int64> dim_lengths{static_cast<int64>(operand.size())};
@@ -234,9 +232,8 @@ ReferenceUtil::ReduceWindow1DAdd(const absl::Span<const float>& operand,
ReferenceUtil::ReduceWindow2DGeneric(
const Array2D<float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride,
- const absl::Span<const std::pair<int64, int64>>& padding) {
+ absl::Span<const int64> window, absl::Span<const int64> stride,
+ absl::Span<const std::pair<int64, int64>> padding) {
std::vector<int64> dim_lengths{operand.height(), operand.width()};
std::vector<int64> window_counts(window.size(), 0);
@@ -273,9 +270,8 @@ ReferenceUtil::ReduceWindow2DGeneric(
}
/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::ReduceWindow2DAdd(
- const Array2D<float>& operand, float init,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride, Padding padding) {
+ const Array2D<float>& operand, float init, absl::Span<const int64> window,
+ absl::Span<const int64> stride, Padding padding) {
const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
std::vector<int64> dim_lengths{operand.height(), operand.width()};
return ReduceWindow2DGeneric(
@@ -284,9 +280,8 @@ ReferenceUtil::ReduceWindow2DGeneric(
}
/* static */ std::unique_ptr<Array3D<float>> ReferenceUtil::ReduceWindow3DAdd(
- const Array3D<float>& operand, float init,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride, Padding padding) {
+ const Array3D<float>& operand, float init, absl::Span<const int64> window,
+ absl::Span<const int64> stride, Padding padding) {
std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3()};
auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
@@ -332,8 +327,8 @@ ReferenceUtil::ReduceWindow2DGeneric(
ReferenceUtil::ReduceWindow4DGeneric(
const Array4D<float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride, Padding padding) {
+ absl::Span<const int64> window, absl::Span<const int64> stride,
+ Padding padding) {
std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
operand.n4()};
return ReduceWindow4DGeneric(
@@ -345,9 +340,8 @@ ReferenceUtil::ReduceWindow4DGeneric(
ReferenceUtil::ReduceWindow4DGeneric(
const Array4D<float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride,
- const absl::Span<const std::pair<int64, int64>>& padding) {
+ absl::Span<const int64> window, absl::Span<const int64> stride,
+ absl::Span<const std::pair<int64, int64>> padding) {
std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
operand.n4()};
@@ -399,9 +393,8 @@ ReferenceUtil::ReduceWindow4DGeneric(
}
/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd(
- const Array4D<float>& operand, float init,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride, Padding padding) {
+ const Array4D<float>& operand, float init, absl::Span<const int64> window,
+ absl::Span<const int64> stride, Padding padding) {
const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride,
padding);
@@ -425,8 +418,8 @@ ReferenceUtil::ReduceWindow4DGeneric(
ReferenceUtil::SelectAndScatter4DGePlus(const Array4D<float>& operand,
const Array4D<float>& source,
float init,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride,
+ absl::Span<const int64> window,
+ absl::Span<const int64> stride,
bool same_padding) {
Padding padding = same_padding ? Padding::kSame : Padding::kValid;
auto result = absl::make_unique<Array4D<float>>(operand.n1(), operand.n2(),
@@ -529,13 +522,13 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
}
ordered_input_dimensions[0] =
- lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(0));
+ lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(0));
ordered_input_dimensions[1] =
- lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(1));
+ lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(1));
ordered_kernel_dimensions[0] =
- rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0));
+ rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0));
ordered_kernel_dimensions[1] =
- rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1));
+ rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1));
std::vector<std::pair<int64, int64>> paddings =
MakePadding(ordered_input_dimensions, ordered_kernel_dimensions,
@@ -546,7 +539,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
WindowDimension dim;
dim.set_size(
- rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0)));
+ rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0)));
dim.set_stride(kernel_stride.first);
dim.set_padding_low(paddings[0].first);
dim.set_padding_high(paddings[0].second);
@@ -556,7 +549,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
WindowDimension dim2;
dim2.set_size(
- rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1)));
+ rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1)));
dim2.set_stride(kernel_stride.second);
dim2.set_padding_low(paddings[1].first);
dim2.set_padding_high(paddings[1].second);
@@ -565,7 +558,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
*window.add_dimensions() = dim2;
const Shape& shape = ShapeInference::InferConvolveShape(
- lhs_literal->shape(), rhs_literal->shape(),
+ lhs_literal.shape(), rhs_literal.shape(),
/*feature_group_count=*/1, window, dnums)
.ConsumeValueOrDie();
@@ -585,18 +578,18 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
auto computation = module.AddEntryComputation(b.Build());
HloEvaluator evaluator;
- std::unique_ptr<Literal> result_literal =
+ Literal result_literal =
evaluator.Evaluate<const Literal*>(*computation, {}).ConsumeValueOrDie();
- CHECK_EQ(ShapeUtil::Rank(result_literal->shape()), 4);
+ CHECK_EQ(ShapeUtil::Rank(result_literal.shape()), 4);
auto result =
- absl::make_unique<Array4D<float>>(result_literal->shape().dimensions(0),
- result_literal->shape().dimensions(1),
- result_literal->shape().dimensions(2),
- result_literal->shape().dimensions(3));
+ absl::make_unique<Array4D<float>>(result_literal.shape().dimensions(0),
+ result_literal.shape().dimensions(1),
+ result_literal.shape().dimensions(2),
+ result_literal.shape().dimensions(3));
result->Each([&](absl::Span<const int64> indices, float* value) {
- *value = result_literal->Get<float>(indices);
+ *value = result_literal.Get<float>(indices);
});
return result;
diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h
index 9ce098029d..8654fbb9b5 100644
--- a/tensorflow/compiler/xla/reference_util.h
+++ b/tensorflow/compiler/xla/reference_util.h
@@ -177,47 +177,41 @@ class ReferenceUtil {
// Windowed reductions with Add as the function to apply.
static std::unique_ptr<std::vector<float>> ReduceWindow1DAdd(
- const absl::Span<const float>& operand, float init,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride, Padding padding);
+ absl::Span<const float> operand, float init,
+ absl::Span<const int64> window, absl::Span<const int64> stride,
+ Padding padding);
static std::unique_ptr<Array2D<float>> ReduceWindow2DAdd(
- const Array2D<float>& operand, float init,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride, Padding padding);
+ const Array2D<float>& operand, float init, absl::Span<const int64> window,
+ absl::Span<const int64> stride, Padding padding);
static std::unique_ptr<Array3D<float>> ReduceWindow3DAdd(
- const Array3D<float>& operand, float init,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride, Padding padding);
+ const Array3D<float>& operand, float init, absl::Span<const int64> window,
+ absl::Span<const int64> stride, Padding padding);
static std::unique_ptr<Array4D<float>> ReduceWindow4DAdd(
- const Array4D<float>& operand, float init,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride, Padding padding);
+ const Array4D<float>& operand, float init, absl::Span<const int64> window,
+ absl::Span<const int64> stride, Padding padding);
// Windowed reductions with a generic reduce function.
static std::unique_ptr<std::vector<float>> ReduceWindow1DGeneric(
- const absl::Span<const float>& operand, float init,
+ absl::Span<const float> operand, float init,
const std::function<float(float, float)>& reduce_func,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride,
- const absl::Span<const std::pair<int64, int64>>& padding);
+ absl::Span<const int64> window, absl::Span<const int64> stride,
+ absl::Span<const std::pair<int64, int64>> padding);
static std::unique_ptr<Array2D<float>> ReduceWindow2DGeneric(
const Array2D<float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride,
- const absl::Span<const std::pair<int64, int64>>& padding);
+ absl::Span<const int64> window, absl::Span<const int64> stride,
+ absl::Span<const std::pair<int64, int64>> padding);
static std::unique_ptr<Array4D<float>> ReduceWindow4DGeneric(
const Array4D<float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride, Padding padding);
+ absl::Span<const int64> window, absl::Span<const int64> stride,
+ Padding padding);
// With arbitrary padding.
static std::unique_ptr<Array4D<float>> ReduceWindow4DGeneric(
const Array4D<float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride,
- const absl::Span<const std::pair<int64, int64>>& padding);
+ absl::Span<const int64> window, absl::Span<const int64> stride,
+ absl::Span<const std::pair<int64, int64>> padding);
// Batch normalize data.
static std::unique_ptr<Array4D<float>> BatchNorm4D(
@@ -230,8 +224,8 @@ class ReferenceUtil {
// TODO(b/74533103) Switch tests to evaluator and remove this implementation.
static std::unique_ptr<Array4D<float>> SelectAndScatter4DGePlus(
const Array4D<float>& operand, const Array4D<float>& source, float init,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride, bool same_padding);
+ absl::Span<const int64> window, absl::Span<const int64> stride,
+ bool same_padding);
// Concatenates the lhs and rhs arrays along the concatenate_dimension.
// E.g. if concatenate_dimension is 0, the "n1"/height dimension is
@@ -332,8 +326,8 @@ class ReferenceUtil {
// Slices with index clamping
template <typename T>
- static std::vector<T> ClampSlice1D(const absl::Span<const T>& input,
- int64 start, int64 size) {
+ static std::vector<T> ClampSlice1D(absl::Span<const T> input, int64 start,
+ int64 size) {
start = std::min<int64>(std::max<int64>(0, start), input.size() - size);
std::vector<T> result;
for (int64 i = 0; i < size; ++i) {
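The signature change above reflects that absl::Span is a (pointer, length) view, cheap to copy, so it is taken by value like string_view rather than by const reference. A small self-contained sketch with an illustrative helper:
#include <cstdint>
#include <vector>
#include "absl/types/span.h"
// Span parameters are passed by value; callers can pass a std::vector or a
// braced list and the view binds implicitly.
int64_t CountNonUnitStrides(absl::Span<const int64_t> stride) {
  int64_t count = 0;
  for (int64_t s : stride) {
    if (s != 1) ++count;
  }
  return count;
}
// Both call forms bind to absl::Span<const int64_t> without an explicit cast:
//   std::vector<int64_t> strides = {1, 2, 1};
//   CountNonUnitStrides(strides);    // from a container
//   CountNonUnitStrides({1, 2, 1});  // from an initializer list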
diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc
index 3ec0192148..a1b0f4045f 100644
--- a/tensorflow/compiler/xla/reference_util_test.cc
+++ b/tensorflow/compiler/xla/reference_util_test.cc
@@ -55,7 +55,7 @@ TEST_F(ReferenceUtilTest, TransposeArray2D) {
auto result = ReferenceUtil::TransposeArray2D(*matrix_);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{1.f, 4.f}, {2.f, 5.f}, {3.f, 6.f}},
- *actual_literal, ErrorSpec(0.0001));
+ actual_literal, ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, MatmulArray2D) {
@@ -67,14 +67,14 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) {
auto result = ReferenceUtil::MatmulArray2D(*matrix_, rhs);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{58.f, 64.f}, {139.f, 154.f}},
- *actual_literal, ErrorSpec(0.0001));
+ actual_literal, ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, ReduceToColArray2D) {
auto add = [](float lhs, float rhs) { return lhs + rhs; };
auto result = ReferenceUtil::ReduceToColArray2D(*matrix_, 0.0f, add);
auto actual_literal = LiteralUtil::CreateR1<float>(*result);
- LiteralTestUtil::ExpectR1Near<float>({6.f, 15.f}, *actual_literal,
+ LiteralTestUtil::ExpectR1Near<float>({6.f, 15.f}, actual_literal,
ErrorSpec(0.0001));
}
@@ -82,7 +82,7 @@ TEST_F(ReferenceUtilTest, ReduceToRowArray2D) {
auto add = [](float lhs, float rhs) { return lhs + rhs; };
auto result = ReferenceUtil::ReduceToRowArray2D(*matrix_, 0.0f, add);
auto actual_literal = LiteralUtil::CreateR1<float>(*result);
- LiteralTestUtil::ExpectR1Near<float>({5.f, 7.f, 9.f}, *actual_literal,
+ LiteralTestUtil::ExpectR1Near<float>({5.f, 7.f, 9.f}, actual_literal,
ErrorSpec(0.0001));
}
@@ -90,14 +90,14 @@ TEST_F(ReferenceUtilTest, Reduce4Dto1DZeroSizedArray) {
auto result = LiteralUtil::CreateR1<float>(ReferenceUtil::Reduce4DTo1D(
Array4D<float>(1, 0, 1, 1), /*init=*/0, /*dims=*/{0, 1, 2},
[](float a, float b) { return a + b; }));
- LiteralTestUtil::ExpectR1Equal<float>({0}, *result);
+ LiteralTestUtil::ExpectR1Equal<float>({0}, result);
}
TEST_F(ReferenceUtilTest, MapArray2D) {
auto identity = [](float value) { return log(exp(value)); };
auto result = ReferenceUtil::MapArray2D(*matrix_, identity);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
- LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *actual_literal,
+ LiteralTestUtil::ExpectR2NearArray2D(*matrix_, actual_literal,
ErrorSpec(0.0001));
}
@@ -108,7 +108,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray2D) {
auto result = ReferenceUtil::MapWithIndexArray2D(*matrix_, add_index);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f, 5.f}, {5.f, 7.f, 9.f}},
- *actual_literal, ErrorSpec(0.0001));
+ actual_literal, ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, MapArray4D) {
@@ -121,7 +121,7 @@ TEST_F(ReferenceUtilTest, MapArray4D) {
Array4D<float> expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5);
expected.FillWithMultiples(2.0f);
- LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal,
+ LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal,
ErrorSpec(0.0001));
}
@@ -138,7 +138,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) {
Array4D<float> expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5);
expected.Fill(0.0f);
- LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal,
+ LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal,
ErrorSpec(0.0001));
}
@@ -146,16 +146,16 @@ TEST_F(ReferenceUtilTest, SliceArray2D) {
auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 2}}, {{1, 1}});
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
- LiteralTestUtil::ExpectR2Near<float>({{1.f, 2.f}, {4.f, 5.f}},
- *actual_literal, ErrorSpec(0.0001));
+ LiteralTestUtil::ExpectR2Near<float>({{1.f, 2.f}, {4.f, 5.f}}, actual_literal,
+ ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, SliceStridedArray2D) {
auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 3}}, {{1, 2}});
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
- LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f}, {4.f, 6.f}},
- *actual_literal, ErrorSpec(0.0001));
+ LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f}, {4.f, 6.f}}, actual_literal,
+ ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, SliceArray3D) {
@@ -167,7 +167,7 @@ TEST_F(ReferenceUtilTest, SliceArray3D) {
auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result);
LiteralTestUtil::ExpectR3Near<float>(
- {{{0.f, 1.f}, {4.f, 5.f}}, {{12.f, 13.f}, {16.f, 17.f}}}, *actual_literal,
+ {{{0.f, 1.f}, {4.f, 5.f}}, {{12.f, 13.f}, {16.f, 17.f}}}, actual_literal,
ErrorSpec(0.0001));
}
@@ -180,8 +180,8 @@ TEST_F(ReferenceUtilTest, SliceStridedArray3D) {
auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result);
LiteralTestUtil::ExpectR3Near<float>(
- {{{0.f, 2.f}, {8.f, 10.f}}, {{12.f, 14.f}, {20.f, 22.f}}},
- *actual_literal, ErrorSpec(0.0001));
+ {{{0.f, 2.f}, {8.f, 10.f}}, {{12.f, 14.f}, {20.f, 22.f}}}, actual_literal,
+ ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, SliceArray4D) {
@@ -194,7 +194,7 @@ TEST_F(ReferenceUtilTest, SliceArray4D) {
LiteralTestUtil::ExpectR4Near<float>(
{{{{60.f, 61.f}, {65.f, 66.f}}, {{80.f, 81.f}, {85.f, 86.f}}}},
- *actual_literal, ErrorSpec(0.0001));
+ actual_literal, ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, SliceStridedArray4D) {
@@ -208,7 +208,7 @@ TEST_F(ReferenceUtilTest, SliceStridedArray4D) {
LiteralTestUtil::ExpectR4Near<float>(
{{{{60.f, 62.f, 64.f}, {70.f, 72.f, 74.f}},
{{100.f, 102.f, 104.f}, {110.f, 112.f, 114.f}}}},
- *actual_literal, ErrorSpec(0.0001));
+ actual_literal, ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, ConvArray3DWithSamePadding) {
@@ -220,7 +220,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithSamePadding) {
auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual);
- LiteralTestUtil::ExpectR3NearArray3D<float>(expected, *actual_literal,
+ LiteralTestUtil::ExpectR3NearArray3D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}
@@ -233,7 +233,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithValidPadding) {
auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual);
- LiteralTestUtil::ExpectR3NearArray3D<float>(expected, *actual_literal,
+ LiteralTestUtil::ExpectR3NearArray3D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}
@@ -268,7 +268,7 @@ TEST_F(ReferenceUtilTest, ConvWithSamePadding) {
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
- LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
+ LiteralTestUtil::ExpectR4NearArray4D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}
@@ -302,7 +302,7 @@ TEST_F(ReferenceUtilTest, ConvWithValidPadding) {
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
- LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
+ LiteralTestUtil::ExpectR4NearArray4D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}
@@ -358,7 +358,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithSamePadding) {
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
- LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
+ LiteralTestUtil::ExpectR4NearArray4D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}
@@ -411,7 +411,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) {
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
- LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
+ LiteralTestUtil::ExpectR4NearArray4D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}
@@ -424,7 +424,7 @@ TEST_F(ReferenceUtilTest, ApplyElementwise2D) {
[](float x, float y, float z) { return 100 * x + 10 * y + z; }, a, b, c);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*actual);
LiteralTestUtil::ExpectR2Near({{300.f, 600.f}, {900.f, 1200.f}},
- *actual_literal, ErrorSpec(0.0001));
+ actual_literal, ErrorSpec(0.0001));
}
} // namespace
diff --git a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc
index 43fd8fe1bd..84fe5b17d1 100644
--- a/tensorflow/compiler/xla/rpc/grpc_client_test.cc
+++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc
@@ -95,12 +95,11 @@ TEST_F(GRPCClientTestBase, AxpyTenValues) {
std::vector<float> expected = {
1.85840735, -1.85840735, 2.28318531, -2.28318531, -6.42477796,
6.42477796, 10.56637061, -10.56637061, -14.70796327, 14.70796327};
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateR1<float>(expected);
+ Literal expected_literal = LiteralUtil::CreateR1<float>(expected);
TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build());
TF_ASSERT_OK_AND_ASSIGN(auto result_literal, client_->ExecuteAndTransfer(
computation, {}, nullptr));
- EXPECT_TRUE(LiteralTestUtil::Near(*expected_literal, *result_literal,
+ EXPECT_TRUE(LiteralTestUtil::Near(expected_literal, result_literal,
ErrorSpec(0.0001)));
}
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index e784663ff6..fb80c78f68 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -87,6 +87,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
],
@@ -123,6 +124,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
],
@@ -352,6 +354,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
@@ -402,6 +405,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
@@ -498,6 +502,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
@@ -546,6 +551,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@@ -568,6 +574,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
@@ -1012,8 +1019,8 @@ cc_library(
":buffer_value_containers",
":heap_simulator",
":hlo",
+ ":hlo_memory_scheduler",
":hlo_proto",
- ":hlo_scheduling",
":logical_buffer",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:shape_util",
@@ -1041,8 +1048,8 @@ tf_cc_test(
":cpu_plugin",
":flatten_call_graph",
":hlo",
+ ":hlo_memory_scheduler",
":hlo_ordering",
- ":hlo_scheduling",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@@ -1088,8 +1095,8 @@ tf_cc_test(
deps = [
":hlo",
":hlo_dataflow_analysis",
+ ":hlo_memory_scheduler",
":hlo_ordering",
- ":hlo_scheduling",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -1131,6 +1138,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@@ -1139,6 +1147,37 @@ tf_cc_test(
)
cc_library(
+ name = "hlo_module_group",
+ srcs = ["hlo_module_group.cc"],
+ hdrs = ["hlo_module_group.h"],
+ deps = [
+ ":hlo",
+ ":hlo_proto",
+ "//tensorflow/compiler/xla:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+tf_cc_test(
+ name = "hlo_module_group_test",
+ srcs = ["hlo_module_group_test.cc"],
+ deps = [
+ ":hlo",
+ ":hlo_matchers",
+ ":hlo_module_group",
+ ":hlo_parser",
+ ":hlo_proto",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+cc_library(
name = "hlo_module_group_metadata",
srcs = ["hlo_module_group_metadata.cc"],
hdrs = ["hlo_module_group_metadata.h"],
@@ -1185,9 +1224,9 @@ tf_cc_test(
":heap_simulator",
":hlo",
":hlo_dce",
+ ":hlo_memory_scheduler",
":hlo_ordering",
":hlo_parser",
- ":hlo_scheduling",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -1199,13 +1238,14 @@ tf_cc_test(
)
cc_library(
- name = "hlo_scheduling",
- srcs = ["hlo_scheduling.cc"],
- hdrs = ["hlo_scheduling.h"],
+ name = "hlo_memory_scheduler",
+ srcs = ["hlo_memory_scheduler.cc"],
+ hdrs = ["hlo_memory_scheduler.h"],
deps = [
":heap_simulator",
":hlo",
":hlo_ordering",
+ ":hlo_pass",
":logical_buffer",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:shape_util",
@@ -1219,15 +1259,15 @@ cc_library(
)
tf_cc_test(
- name = "hlo_scheduling_test",
- srcs = ["hlo_scheduling_test.cc"],
+ name = "hlo_memory_scheduler_test",
+ srcs = ["hlo_memory_scheduler_test.cc"],
deps = [
":heap_simulator",
":hlo",
":hlo_dce",
+ ":hlo_memory_scheduler",
":hlo_ordering",
":hlo_parser",
- ":hlo_scheduling",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -1259,6 +1299,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/memory",
],
)
@@ -1392,6 +1433,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
@@ -1708,6 +1750,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/core:test",
],
)
@@ -1777,6 +1820,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"@com_google_absl//absl/memory",
@@ -1953,6 +1997,7 @@ tf_cc_test(
deps = [
":hlo",
":hlo_matchers",
+ ":hlo_memory_scheduler",
":hlo_parser",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
@@ -2236,6 +2281,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@@ -2314,6 +2360,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/core:test",
],
)
@@ -2394,12 +2441,11 @@ cc_library(
":buffer_liveness",
":buffer_value",
":call_graph",
- ":copy_insertion",
":flatten_call_graph",
":hlo",
":hlo_dce",
+ ":hlo_memory_scheduler",
":hlo_ordering",
- ":hlo_scheduling",
":logical_buffer",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:shape_util",
@@ -2428,6 +2474,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
@@ -2494,6 +2541,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@@ -2611,6 +2659,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
@@ -2888,6 +2937,7 @@ tf_cc_test(
deps = [
":hlo_tfgraph_builder",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:protos_all_cc",
],
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 3d18fe3be2..5458159d14 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -205,7 +205,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) {
HloInstruction* zero =
computation_->AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::Zero(hlo->shape().element_type()).CloneToUnique()));
+ LiteralUtil::Zero(hlo->shape().element_type()).Clone()));
HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation();
Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape());
return computation_->AddInstruction(HloInstruction::CreateReduce(
@@ -296,6 +296,14 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
return scalar_add_computation_;
}
+ // Tries to fold a kPad in the input or filter into the convolution
+ // instruction's window.
+ StatusOr<bool> FoldConvInputPad(HloInstruction* convolution);
+ StatusOr<bool> FoldConvFilterPad(HloInstruction* convolution);
+
+ // Tries to use a kDot in place of the given convolution.
+ StatusOr<bool> SimplifyConvToDot(HloInstruction* convolution);
+
// Current HloComputation instance the AlgebraicSimplifierVisitor is
// traversing.
HloComputation* computation_;
@@ -312,7 +320,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
// Disable dot strength reduction on platforms where it causes a slowdown.
bool enable_dot_strength_reduction_;
- // Disable convolution simplification on platforms where it causes a slowdown.
+ // Disable convolution -> dot simplification on platforms where it causes a
+ // slowdown.
bool enable_conv_simplification_;
// Cached computation for adding two scalar F32.
@@ -527,7 +536,7 @@ static HloInstruction* BuildTupleConstant(HloComputation* computation,
return computation->AddInstruction(HloInstruction::CreateTuple(elems));
} else {
return computation->AddInstruction(
- HloInstruction::CreateConstant(literal.CloneToUnique()));
+ HloInstruction::CreateConstant(literal.Clone()));
}
}
@@ -546,7 +555,7 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
// If a literal is all the same element replace it with a scalar broadcast.
if (ShapeUtil::ElementsIn(constant->shape()) > 1 &&
constant->literal().IsAllFirst()) {
- std::unique_ptr<Literal> unique_scalar = absl::make_unique<Literal>(
+ Literal unique_scalar(
LiteralUtil::GetFirstScalarLiteral(constant->literal()));
HloInstruction* scalar = computation_->AddInstruction(
HloInstruction::CreateConstant(std::move(unique_scalar)));
@@ -676,7 +685,7 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
return Status::OK();
}
auto inverse = computation_->AddInstruction(
- HloInstruction::CreateConstant((new_literal.CloneToUnique())));
+ HloInstruction::CreateConstant((new_literal.Clone())));
TF_ASSIGN_OR_RETURN(auto new_divide,
MakeBinaryHlo(HloOpcode::kMultiply, a, inverse));
return ReplaceInstruction(divide, new_divide);
@@ -1469,7 +1478,7 @@ Status AlgebraicSimplifierVisitor::HandleIota(HloInstruction* instruction) {
auto* iota = Cast<HloIotaInstruction>(instruction);
if (iota->shape().dimensions(iota->iota_dimension()) <= 1) {
auto zero = computation_->AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::Zero(iota->shape().element_type()).CloneToUnique()));
+ LiteralUtil::Zero(iota->shape().element_type()).Clone()));
return ReplaceWithNewInstruction(
iota, HloInstruction::CreateBroadcast(iota->shape(), zero, {}));
}
@@ -1572,7 +1581,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs))));
if (IsAll(rhs, 0)) {
auto one = HloInstruction::CreateConstant(
- LiteralUtil::One(power->shape().element_type()).CloneToUnique());
+ LiteralUtil::One(power->shape().element_type()).Clone());
std::unique_ptr<HloInstruction> ones;
if (ShapeUtil::IsScalar(power->shape())) {
ones = std::move(one);
@@ -1607,7 +1616,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
if (IsAll(rhs, -1)) {
auto* one = computation_->AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::One(rhs->shape().element_type()).CloneToUnique()));
+ LiteralUtil::One(rhs->shape().element_type()).Clone()));
// Explicitly broadcast scalar 1 to the output shape, to avoid implicit
// broadcast in divide HLO as we are trying to eliminate implicit
@@ -2057,12 +2066,12 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
if (pad_literal == reduce_init_literal) {
return true;
}
- auto converted_pad_literal = pad_literal.ConvertToShape(
- reduce_init_value->shape(), /*round_f32_to_bf16=*/true);
+ auto converted_pad_literal =
+ pad_literal.ConvertToShape(reduce_init_value->shape());
if (!converted_pad_literal.ok()) {
return false;
}
- return *converted_pad_literal.ValueOrDie() == reduce_init_literal;
+ return converted_pad_literal.ValueOrDie() == reduce_init_literal;
};
// The pad value is usually a constant, so we handle that case and do not
// try to get more fancy about proving equivalence in cases beyond that.
@@ -2212,170 +2221,155 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) {
return Status::OK();
}
-Status AlgebraicSimplifierVisitor::HandleConvolution(
+StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvInputPad(
HloInstruction* convolution) {
- auto lhs = convolution->mutable_operand(0);
- auto rhs = convolution->mutable_operand(1);
- if (ShapeUtil::IsZeroElementArray(lhs->shape()) ||
- ShapeUtil::IsZeroElementArray(rhs->shape())) {
- return ReplaceWithNewInstruction(
- convolution,
- HloInstruction::CreateBroadcast(
- convolution->shape(),
- computation_->AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::Zero(convolution->shape().element_type())
- .CloneToUnique())),
- {}));
- }
-
+ auto* lhs = convolution->mutable_operand(0);
+ auto* rhs = convolution->mutable_operand(1);
const auto& window = convolution->window();
const ConvolutionDimensionNumbers& dnums =
convolution->convolution_dimension_numbers();
- // Try to merge padding/dilation of the input with the convolution's window.
- TF_ASSIGN_OR_RETURN(bool folded_input_pad, [&]() -> StatusOr<bool> {
- if (lhs->opcode() != HloOpcode::kPad) {
+ if (lhs->opcode() != HloOpcode::kPad) {
+ return false;
+ }
+
+ // Convolution's padding is always zero, so bail if the kPad is adding
+ // something other than zero.
+ if (!IsAll(lhs->operand(1), 0)) {
+ return false;
+ }
+
+ const auto& padding = lhs->padding_config();
+
+ // Can't pad batch or feature dims.
+ for (int64 dim :
+ {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) {
+ const auto& p = padding.dimensions(dim);
+ if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
+ p.interior_padding() != 0) {
return false;
}
+ }
- // Convolution's padding is always zero, so bail if the kPad is adding
- // something other than zero.
- if (!IsAll(lhs->operand(1), 0)) {
+ // Compute the window which is the result of merging the kPad and the
+ // convolution's existing window.
+ Window new_window = window;
+ for (int64 dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) {
+ auto& w = *new_window.mutable_dimensions(dim);
+ const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim));
+ // Edge padding composes with itself in the straightforward way, but
+ // composing interior padding is nontrivial, and we cowardly refuse to
+ // think about it. If we see interior padding in either the kPad or conv,
+ // bail if there's any sort of padding in the other.
+ if (p.interior_padding() != 0 &&
+ (w.padding_low() != 0 || w.padding_high() != 0 ||
+ w.base_dilation() != 1)) {
+ return false;
+ }
+ if (w.base_dilation() != 1 &&
+ (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
+ p.interior_padding() != 0)) {
return false;
}
- const auto& padding = lhs->padding_config();
-
- // Can't pad batch or feature dims.
- for (int64 dim :
- {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) {
- const auto& p = padding.dimensions(dim);
- if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
- p.interior_padding() != 0) {
- return false;
- }
+ w.set_padding_low(w.padding_low() + p.edge_padding_low());
+ w.set_padding_high(w.padding_high() + p.edge_padding_high());
+ if (p.interior_padding() != 0) {
+ CHECK_EQ(w.base_dilation(), 1);
+ w.set_base_dilation(1 + p.interior_padding());
}
+ }
- // Compute the window which is the result of merging the kPad and the
- // convolution's existing window.
- Window new_window = window;
- for (int64 dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) {
- auto& w = *new_window.mutable_dimensions(dim);
- const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim));
- // Edge padding composes with itself in the straightforward way, but
- // composing interior padding is nontrivial, and we cowardly refuse to
- // think about it. If we see interior padding in either the kPad or conv,
- // bail if there's any sort of padding in the other.
- if (p.interior_padding() != 0 &&
- (w.padding_low() != 0 || w.padding_high() != 0 ||
- w.base_dilation() != 1)) {
- return false;
- }
- if (w.base_dilation() != 1 &&
- (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
- p.interior_padding() != 0)) {
- return false;
- }
+ auto new_conv = convolution->CloneWithNewOperands(
+ convolution->shape(), {lhs->mutable_operand(0), rhs});
+ new_conv->set_window(new_window);
+ TF_RETURN_IF_ERROR(
+ ReplaceWithNewInstruction(convolution, std::move(new_conv)));
+ return true;
+}
- w.set_padding_low(w.padding_low() + p.edge_padding_low());
- w.set_padding_high(w.padding_high() + p.edge_padding_high());
- if (p.interior_padding() != 0) {
- CHECK_EQ(w.base_dilation(), 1);
- w.set_base_dilation(1 + p.interior_padding());
- }
- }
+StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvFilterPad(
+ HloInstruction* convolution) {
+ auto* lhs = convolution->mutable_operand(0);
+ auto* rhs = convolution->mutable_operand(1);
+ const ConvolutionDimensionNumbers& dnums =
+ convolution->convolution_dimension_numbers();
- auto new_conv = convolution->CloneWithNewOperands(
- convolution->shape(), {lhs->mutable_operand(0), rhs});
- new_conv->set_window(new_window);
- TF_RETURN_IF_ERROR(
- ReplaceWithNewInstruction(convolution, std::move(new_conv)));
- return true;
- }());
+ if (rhs->opcode() != HloOpcode::kPad) {
+ return false;
+ }
- if (folded_input_pad) {
- return Status::OK();
+ // Convolution's padding is always zero, so bail if the kPad is adding
+ // something other than zero.
+ if (!IsAll(rhs->operand(1), 0)) {
+ return false;
}
- // Try to merge dilation of the filter with the convolution's window.
- TF_ASSIGN_OR_RETURN(bool folded_filter_pad, [&]() -> StatusOr<bool> {
- if (rhs->opcode() != HloOpcode::kPad) {
- return false;
- }
+ const auto& padding = rhs->padding_config();
- // Convolution's padding is always zero, so bail if the kPad is adding
- // something other than zero.
- if (!IsAll(rhs->operand(1), 0)) {
+ // Can't pad or dilate feature dims.
+ for (int64 dim : {dnums.kernel_input_feature_dimension(),
+ dnums.kernel_output_feature_dimension()}) {
+ const auto& p = padding.dimensions(dim);
+ if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
+ p.interior_padding() != 0) {
return false;
}
+ }
- const auto& padding = rhs->padding_config();
+ // Compute the window which is the result of merging the kPad and the
+ // convolution's existing window.
+ Window new_window = convolution->window();
+ for (int64 dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) {
+ auto& w = *new_window.mutable_dimensions(dim);
+ const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim));
- // Can't pad or dilate feature dims.
- for (int64 dim : {dnums.kernel_input_feature_dimension(),
- dnums.kernel_output_feature_dimension()}) {
- const auto& p = padding.dimensions(dim);
- if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
- p.interior_padding() != 0) {
- return false;
- }
+ // We can only do this transformation if p adds dilation to the filter --
+ // edge padding on the filter is not supported in conv.
+ if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) {
+ return false;
}
- // Compute the window which is the result of merging the kPad and the
- // convolution's existing window.
- Window new_window = convolution->window();
- for (int64 dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) {
- auto& w = *new_window.mutable_dimensions(dim);
- const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim));
-
- // We can only do this transformation if p adds dilation to the filter --
- // edge padding on the filter is not supported in conv.
- if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) {
- return false;
- }
-
- // Nothing to do if the kPad for this dim is entirely a nop.
- if (p.interior_padding() == 0) {
- continue;
- }
+ // Nothing to do if the kPad for this dim is entirely a nop.
+ if (p.interior_padding() == 0) {
+ continue;
+ }
- // We cowardly refuse to think about how dilation composes with itself;
- // bail if both the kPad and conv have dilation on this dimension.
- if (w.window_dilation() > 1) {
- return false;
- }
- CHECK_EQ(w.window_dilation(), 1);
- w.set_window_dilation(1 + p.interior_padding());
- w.set_size(rhs->operand(0)->shape().dimensions(
- dnums.kernel_spatial_dimensions(dim)));
+ // We cowardly refuse to think about how dilation composes with itself;
+ // bail if both the kPad and conv have dilation on this dimension.
+ if (w.window_dilation() > 1) {
+ return false;
}
+ CHECK_EQ(w.window_dilation(), 1);
+ w.set_window_dilation(1 + p.interior_padding());
+ w.set_size(rhs->operand(0)->shape().dimensions(
+ dnums.kernel_spatial_dimensions(dim)));
+ }
- auto new_conv = convolution->CloneWithNewOperands(
- convolution->shape(), {lhs, rhs->mutable_operand(0)});
- new_conv->set_window(new_window);
- TF_RETURN_IF_ERROR(
- ReplaceWithNewInstruction(convolution, std::move(new_conv)));
- return true;
- }());
+ auto new_conv = convolution->CloneWithNewOperands(
+ convolution->shape(), {lhs, rhs->mutable_operand(0)});
+ new_conv->set_window(new_window);
+ TF_RETURN_IF_ERROR(
+ ReplaceWithNewInstruction(convolution, std::move(new_conv)));
+ return true;
+}
- if (folded_filter_pad) {
- return Status::OK();
- }
+StatusOr<bool> AlgebraicSimplifierVisitor::SimplifyConvToDot(
+ HloInstruction* convolution) {
+ auto* lhs = convolution->mutable_operand(0);
+ auto* rhs = convolution->mutable_operand(1);
+ const auto& window = convolution->window();
+ const ConvolutionDimensionNumbers& dnums =
+ convolution->convolution_dimension_numbers();
if (!enable_conv_simplification_) {
- return Status::OK();
+ return false;
}
- // HandleConvolution tries to replace a convolution with a DOT instruction.
- //
- // Only add when bitcasts can be used:
- // - if bitcasts are not supported, then reshapes could be used but will
- // end up with another copy.
- // - if bitcasts are supported, the simplifier will be called again with
- // bitcasts_ == true.
- // TODO(cwhipkey): b/31337498, make this layout insensitive.
+ // TODO(b/31337498): For now, we cowardly refuse to do this optimization in
+ // layout-insensitive mode, for fear of adding nontrivial reshapes.
if (!is_layout_sensitive_) {
- return Status::OK();
+ return false;
}
const Shape& input_shape = lhs->shape();
@@ -2388,7 +2382,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
// Require the spatial dimensions in the kernel to have a bound of one.
for (int64 i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) {
if (filter_shape.dimensions(dnums.kernel_spatial_dimensions(i)) != 1) {
- return Status::OK();
+ return false;
}
}
@@ -2399,7 +2393,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
// for a 1x1 window, so window dilation is no problem.
if (window_util::HasStride(window) || window_util::HasPadding(window) ||
window_util::HasBaseDilation(window)) {
- return Status::OK();
+ return false;
}
// Also, the shapes must align for a rowmajor matmul:
@@ -2425,7 +2419,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
dnums.kernel_input_feature_dimension()) <
PositionInContainer(LayoutUtil::MinorToMajor(filter_shape),
dnums.kernel_output_feature_dimension()))) {
- return Status::OK();
+ return false;
}
auto add_bitcast = [&](Shape shape, HloInstruction* operand) {
@@ -2467,7 +2461,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
if (!valid_bitcast_callback_(input_shape, new_input_shape) ||
!valid_bitcast_callback_(filter_shape, new_filter_shape) ||
!valid_bitcast_callback_(dot_output_shape, convolution_shape)) {
- return Status::OK();
+ return false;
}
auto new_lhs = add_bitcast(new_input_shape, lhs);
@@ -2479,7 +2473,44 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers,
convolution->precision_config()));
- return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot));
+ TF_RETURN_IF_ERROR(
+ ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)));
+ return true;
+}
+
+Status AlgebraicSimplifierVisitor::HandleConvolution(
+ HloInstruction* convolution) {
+ // Zero-sized input or filter.
+ if (ShapeUtil::IsZeroElementArray(convolution->operand(0)->shape()) ||
+ ShapeUtil::IsZeroElementArray(convolution->operand(1)->shape())) {
+ return ReplaceWithNewInstruction(
+ convolution,
+ HloInstruction::CreateBroadcast(
+ convolution->shape(),
+ computation_->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(convolution->shape().element_type()))),
+ {}));
+ }
+
+ // Try to merge padding/dilation of the input with the convolution's window.
+ TF_ASSIGN_OR_RETURN(bool folded_input_pad, FoldConvInputPad(convolution));
+ if (folded_input_pad) {
+ return Status::OK();
+ }
+
+ // Try to merge dilation of the filter with the convolution's window.
+ TF_ASSIGN_OR_RETURN(bool folded_filter_pad, FoldConvFilterPad(convolution));
+ if (folded_filter_pad) {
+ return Status::OK();
+ }
+
+ // Try to replace the convolution with a kDot instruction.
+ TF_ASSIGN_OR_RETURN(bool replaced_with_dot, SimplifyConvToDot(convolution));
+ if (replaced_with_dot) {
+ return Status::OK();
+ }
+
+ return Status::OK();
}
bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape(
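
The hunk above reworks HandleConvolution into a chain of StatusOr<bool> helpers (FoldConvInputPad, FoldConvFilterPad, SimplifyConvToDot), where the first helper that actually rewrites the convolution short-circuits the rest. As a minimal sketch of that chaining idiom (not part of this patch; the helper list is hypothetical and it assumes XLA's Status/StatusOr types and the TF_ASSIGN_OR_RETURN macro):

    // Illustrative sketch only, not part of the diff.
    #include <functional>
    #include <vector>

    Status RunFirstApplicableRewrite(
        HloInstruction* convolution,
        const std::vector<std::function<StatusOr<bool>(HloInstruction*)>>&
            rewrites) {
      for (const auto& rewrite : rewrites) {
        // Each rewriter returns true iff it replaced the convolution.
        TF_ASSIGN_OR_RETURN(bool changed, rewrite(convolution));
        if (changed) {
          return Status::OK();  // First successful rewrite wins.
        }
      }
      return Status::OK();  // Nothing applied; leave the convolution as is.
    }
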
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index a0db4563fb..3fc1ba2427 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -2932,9 +2932,9 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) {
HloComputation::Builder builder(TestName());
const float constant_scalar = 7.3f;
std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
- std::unique_ptr<Literal> value = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(constant_scalar).get(),
- LiteralUtil::CreateR1<float>(constant_vector).get()});
+ Literal elements[] = {LiteralUtil::CreateR0<float>(constant_scalar),
+ LiteralUtil::CreateR1<float>(constant_vector)};
+ Literal value = LiteralUtil::MakeTuple({&elements[0], &elements[1]});
builder.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
auto computation = module().AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc
index ec281ae68f..30d33e0d35 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc
@@ -205,11 +205,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
const Shape feature_shape = scale->shape();
auto zero_literal = LiteralUtil::CreateR0(0.0f);
- TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
+ TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype));
auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));
auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
- TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
+ TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
auto epsilon = add(HloInstruction::CreateBroadcast(
operand_shape,
add(HloInstruction::CreateConstant(std::move(epsilon_literal))), {}));
@@ -331,7 +331,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
const Shape feature_shape = scale->shape();
auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
- TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
+ TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast(
operand_shape,
computation_->AddInstruction(
@@ -464,11 +464,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
const int64 elements_per_feature_int64 = size_in_elements / feature_count;
auto zero_literal = LiteralUtil::CreateR0(0.0f);
- TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
+ TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype));
auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));
auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
- TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
+ TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
auto epsilon_scalar =
add(HloInstruction::CreateConstant(std::move(epsilon_literal)));
auto epsilon_activation = add(
@@ -560,7 +560,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
auto elements_per_feature_literal =
LiteralUtil::CreateR0<float>(elements_per_feature_int64);
TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
- elements_per_feature_literal->Convert(ptype));
+ elements_per_feature_literal.Convert(ptype));
auto elements_per_feature = add(
HloInstruction::CreateConstant(std::move(elements_per_feature_literal)));
auto i1 = add_binary(activation_shape, HloOpcode::kMultiply, grad_output,
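
The batchnorm_expander changes above follow the Literal migration from std::unique_ptr<Literal> to value semantics: LiteralUtil::CreateR0 now returns a Literal by value and Convert(ptype) returns StatusOr<Literal>, so TF_ASSIGN_OR_RETURN reassigns the same local and the arrow dereferences disappear. A small sketch of that pattern (the helper name is a hypothetical stand-in, and the XLA headers for Literal/HloInstruction are assumed):

    // Illustrative sketch only, not part of the diff.
    #include <utility>

    Status MakeTypedEpsilonConstant(HloComputation* computation,
                                    PrimitiveType ptype, float epsilon,
                                    HloInstruction** result) {
      // Literal by value, not std::unique_ptr<Literal>.
      auto epsilon_literal = LiteralUtil::CreateR0(epsilon);
      // Convert() returns StatusOr<Literal>; reassign the same variable.
      TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
      *result = computation->AddInstruction(
          HloInstruction::CreateConstant(std::move(epsilon_literal)));
      return Status::OK();
    }
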
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
index aba0d9bb5b..f7ac8f5482 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
@@ -29,14 +29,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
namespace {
-using BatchNormExpanderTest = HloTestBase;
+using BatchNormExpanderTest = HloVerifiedTestBase;
// Test that we expand BatchNormTraining.
TEST_F(BatchNormExpanderTest, BatchNormTraining) {
@@ -66,7 +66,7 @@ TEST_F(BatchNormExpanderTest, BatchNormTraining) {
BatchNormExpander rewriter(/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true);
- ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(rewriter.Run(module).ValueOrDie());
root = computation->root_instruction();
// Make sure this operation is expanded.
EXPECT_EQ(root->opcode(), HloOpcode::kTuple);
@@ -108,7 +108,7 @@ TEST_F(BatchNormExpanderTest, BatchNormGrad) {
BatchNormExpander rewriter(/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true);
- ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(rewriter.Run(module).ValueOrDie());
root = computation->root_instruction();
// Make sure this operation is expanded.
EXPECT_EQ(root->opcode(), HloOpcode::kTuple);
@@ -126,13 +126,13 @@ ENTRY entry {
epsilon=0.001, feature_index=1, sharding={maximal device=1}
})";
- TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(module_str));
+ ParseAndVerifyModule(module_str);
BatchNormExpander rewriter(/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true);
- ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(rewriter.Run(&module()).ValueOrDie());
- for (auto* instruction : module->entry_computation()->instructions()) {
+ for (auto* instruction : module().entry_computation()->instructions()) {
if (instruction->opcode() == HloOpcode::kParameter) {
continue;
}
diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
index 6363a21c3b..5f93740887 100644
--- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
@@ -65,8 +65,12 @@ class TestBFloat16Support : public BFloat16Support {
}
};
-class BFloat16ConversionFoldingTest : public HloTestBase {
+class BFloat16ConversionFoldingTest : public HloVerifiedTestBase {
protected:
+ BFloat16ConversionFoldingTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/true) {}
+
bool FoldConversions(HloModule* module) {
TestBFloat16Support bfloat16_support_;
BFloat16ConversionFolding fold(&bfloat16_support_);
@@ -102,7 +106,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldIfSupported) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(FoldConversions(module.get()));
+ EXPECT_TRUE(FoldConversions(module));
EXPECT_EQ(computation->root_instruction(), add1);
EXPECT_EQ(add0->shape().element_type(), BF16);
@@ -137,7 +141,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldIfUnsupported) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_FALSE(FoldConversions(module.get()));
+ EXPECT_FALSE(FoldConversions(module));
EXPECT_EQ(computation->root_instruction(), convert2);
EXPECT_EQ(mul0->shape().element_type(), F32);
@@ -172,7 +176,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldUnsupportedMixedPrecision) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_FALSE(FoldConversions(module.get()));
+ EXPECT_FALSE(FoldConversions(module));
EXPECT_EQ(computation->root_instruction(), convert2);
EXPECT_EQ(sub0->shape().element_type(), F32);
@@ -202,7 +206,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_FALSE(FoldConversions(module.get()));
+ EXPECT_FALSE(FoldConversions(module));
EXPECT_EQ(computation->root_instruction(), convert1);
EXPECT_EQ(gte->shape().element_type(), F32);
@@ -248,7 +252,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) {
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(FoldConversions(module.get()));
+ EXPECT_TRUE(FoldConversions(module));
EXPECT_EQ(computation->root_instruction(), tuple);
EXPECT_EQ(tuple->operand(0), gte_a);
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
index 933cf873e0..cef0eba14e 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
@@ -68,8 +68,12 @@ class TestBFloat16Support : public BFloat16Support {
}
};
-class BFloat16NormalizationTest : public HloTestBase {
+class BFloat16NormalizationTest : public HloVerifiedTestBase {
protected:
+ BFloat16NormalizationTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/true) {}
+
bool Normalize(HloModule* module) {
TestBFloat16Support bfloat16_support_;
BFloat16Normalization normalization(&bfloat16_support_);
@@ -105,7 +109,7 @@ TEST_F(BFloat16NormalizationTest, NoopIfSupported) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_FALSE(Normalize(module.get()));
+ EXPECT_FALSE(Normalize(module));
EXPECT_EQ(computation->root_instruction(), add1);
EXPECT_EQ(add0->shape().element_type(), BF16);
@@ -133,7 +137,7 @@ TEST_F(BFloat16NormalizationTest, ResolveIfUnsupportedBF16) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
EXPECT_EQ(computation->root_instruction()->operand(0), mul1);
@@ -163,7 +167,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionSubtraction) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
EXPECT_EQ(computation->root_instruction()->operand(0), sub1);
@@ -201,7 +205,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) {
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction(), reduce);
EXPECT_EQ(reduce->called_computations().size(), 1);
@@ -259,7 +263,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) {
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction(), gte);
EXPECT_EQ(gte->shape().element_type(), BF16);
@@ -286,7 +290,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) {
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction(), gte);
EXPECT_EQ(gte->shape().element_type(), BF16);
@@ -317,7 +321,7 @@ TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
EXPECT_EQ(dot->shape().element_type(), F32);
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
index 545a6ecfb1..58f78f8e24 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
@@ -675,10 +675,8 @@ Status BFloat16Propagation::ResolveConvertedConstants(HloModule* module) {
continue;
}
if (!ShapeUtil::Equal(hlo->literal().shape(), hlo->shape())) {
- TF_ASSIGN_OR_RETURN(
- auto converted_literal,
- hlo->literal().ConvertToShape(hlo->shape(),
- /*round_f32_to_bf16=*/true));
+ TF_ASSIGN_OR_RETURN(auto converted_literal,
+ hlo->literal().ConvertToShape(hlo->shape()));
auto new_constant = computation->AddInstruction(
HloInstruction::CreateConstant(std::move(converted_literal)));
TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant));
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
index 388fd5df99..e032b5c624 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
@@ -163,10 +163,10 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) {
EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant);
EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_a)),
+ LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_a)),
dot->operand(0)->literal()));
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_b)),
+ LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_b)),
dot->operand(1)->literal()));
}
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index 0f0af57626..65fa951afe 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -30,7 +30,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index 5a231c173d..795beb9ff5 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -30,11 +30,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
@@ -1245,9 +1245,10 @@ TEST_F(BufferAssignmentTest, TupleConstantAsOutput) {
// Test that a tuple constant which is forwarded to the computation output
// is properly handled.
auto builder = HloComputation::Builder(TestName());
+ Literal elements[] = {LiteralUtil::CreateR0<int64>(0),
+ LiteralUtil::CreateR0<int64>(1)};
builder.AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(0).get(),
- LiteralUtil::CreateR0<int64>(1).get()})));
+ LiteralUtil::MakeTuple({&elements[0], &elements[1]})));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
index 414bfe7999..17e5090505 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
@@ -440,15 +440,15 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) {
// computation. The buffer containing {0, 1} is copied by GetTupleElement, and
// the buffers containing {3} and 3 are dead.
auto builder = HloComputation::Builder(TestName());
- auto inner_tuple0 =
- LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(0).get(),
- LiteralUtil::CreateR0<int64>(1).get()});
- auto inner_tuple1 =
- LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(3).get()});
+ Literal elements0[] = {LiteralUtil::CreateR0<int64>(0),
+ LiteralUtil::CreateR0<int64>(1)};
+ auto inner_tuple0 = LiteralUtil::MakeTuple({&elements0[0], &elements0[1]});
+ Literal element1 = LiteralUtil::CreateR0<int64>(3);
+ auto inner_tuple1 = LiteralUtil::MakeTuple({&element1});
auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::MakeTuple({inner_tuple0.get(), inner_tuple1.get()})));
+ LiteralUtil::MakeTuple({&inner_tuple0, &inner_tuple1})));
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
- inner_tuple0->shape(), tuple_constant, 0));
+ inner_tuple0.shape(), tuple_constant, 0));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc
index cc80b74843..34f3f914d5 100644
--- a/tensorflow/compiler/xla/service/call_graph_test.cc
+++ b/tensorflow/compiler/xla/service/call_graph_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -31,7 +31,7 @@ namespace {
using ::testing::UnorderedElementsAre;
-class CallGraphTest : public HloTestBase {
+class CallGraphTest : public HloVerifiedTestBase {
protected:
// Build and return a trivial computation taking and returning a scalar.
std::unique_ptr<HloComputation> MakeScalarComputation(
@@ -96,7 +96,7 @@ TEST_F(CallGraphTest, SingletonComputation) {
auto module = CreateNewModule();
HloComputation* computation =
module->AddEntryComputation(MakeScalarComputation());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(1, call_graph->nodes().size());
EXPECT_TRUE(call_graph->IsFlattened());
@@ -118,7 +118,7 @@ TEST_F(CallGraphTest, UnreachableComputation) {
HloComputation* unreachable_computation =
module->AddEmbeddedComputation(MakeScalarComputation());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(2, call_graph->nodes().size());
const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
@@ -140,7 +140,7 @@ TEST_F(CallGraphTest, ParallelComputation) {
HloComputation* entry_computation = module->AddEntryComputation(
MakeMappingComputation(map_computation, /*callsites=*/5));
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(2, call_graph->nodes().size());
const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
@@ -169,7 +169,7 @@ TEST_F(CallGraphTest, SequentialComputations) {
HloComputation* entry_computation = module->AddEntryComputation(
MakeCallingComputation(called_computation, /*callsites=*/3));
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(2, call_graph->nodes().size());
// The called computation is only called from one other computation, but there
@@ -210,7 +210,7 @@ TEST_F(CallGraphTest, ContextBothComputations) {
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(2, call_graph->nodes().size());
EXPECT_FALSE(call_graph->IsFlattened());
@@ -259,7 +259,7 @@ TEST_F(CallGraphTest, ComputationWithConditional) {
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(3, call_graph->nodes().size());
@@ -328,7 +328,7 @@ TEST_F(CallGraphTest, ComplexGraph) {
entry_computation = module->AddEntryComputation(builder.Build());
}
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(5, call_graph->nodes().size());
EXPECT_FALSE(call_graph->IsFlattened());
@@ -452,7 +452,7 @@ TEST_F(CallGraphTest, ComplexGraphNearestAncestors) {
entry_computation = module->AddEntryComputation(builder.Build());
}
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(5, call_graph->nodes().size());
// Verify NearestAncestorsInSameComputation for various instructions in the
@@ -482,7 +482,7 @@ TEST_F(CallGraphTest, VisitSingletonComputation) {
auto module = CreateNewModule();
HloComputation* computation =
module->AddEntryComputation(MakeScalarComputation());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
std::vector<HloComputation*> visited;
TF_ASSERT_OK(call_graph->VisitNodes([&visited](const CallGraphNode& node) {
@@ -499,7 +499,7 @@ TEST_F(CallGraphTest, VisitUnreachableComputation) {
module->AddEntryComputation(MakeScalarComputation());
HloComputation* unreachable_computation =
module->AddEmbeddedComputation(MakeScalarComputation());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
// Test visitation of only reachable nodes.
{
@@ -533,7 +533,7 @@ TEST_F(CallGraphTest, VisitWithError) {
// Test that the call graph visitor properly propagates errors.
auto module = CreateNewModule();
module->AddEntryComputation(MakeScalarComputation());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
Status status = call_graph->VisitNodes(
[](const CallGraphNode&) { return InternalError("Visitation failed"); });
diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc
index 5d85a3f173..e6b5665435 100644
--- a/tensorflow/compiler/xla/service/call_inliner_test.cc
+++ b/tensorflow/compiler/xla/service/call_inliner_test.cc
@@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -40,7 +40,7 @@ namespace {
// Tests for call inlining that are most tractable at the HLO level (vs
// ComputationBuilder API in call_test.cc).
-using CallInlinerTest = HloTestBase;
+using CallInlinerTest = HloVerifiedTestBase;
TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
// "inner" computation just has a control dependency from the "zero" value to
@@ -64,7 +64,7 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
auto computation = module->AddEntryComputation(outer.Build());
CallInliner call_inliner;
- TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module));
ASSERT_TRUE(mutated);
EXPECT_THAT(computation->root_instruction(), op::Constant());
EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement<float>(),
@@ -92,6 +92,8 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) {
HloComputation::Builder call_false_builder(TestName() + ".call_false");
call_false_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, pred, "param"));
+ call_false_builder.AddInstruction(
HloInstruction::CreateCall(pred, {}, false_computation));
HloComputation* call_false =
module->AddEmbeddedComputation(call_false_builder.Build());
@@ -105,7 +107,7 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) {
auto computation = module->AddEntryComputation(outer.Build());
CallInliner call_inliner;
- TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module));
ASSERT_TRUE(mutated);
EXPECT_THAT(
computation->root_instruction()->while_condition()->root_instruction(),
@@ -161,7 +163,7 @@ TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) {
module->AddEntryComputation(outer.Build());
CallInliner call_inliner;
- TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module));
ASSERT_TRUE(mutated);
}
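
The test updates above migrate from HloTestBase to HloVerifiedTestBase: passes now take the module as a raw pointer (no more module.get()), and tests that parse HLO text go through ParseAndVerifyModule()/module(). A minimal sketch of an updated test under that base class (not part of this patch; the test name and HLO text are placeholders):

    // Illustrative sketch only, not part of the diff.
    class InlinerSmokeTest : public HloVerifiedTestBase {};

    TEST_F(InlinerSmokeTest, RunsOnVerifiedModule) {
      ParseAndVerifyModule(R"(
        HloModule Smoke
        ENTRY entry {
          ROOT param = f32[2] parameter(0)
        })");
      CallInliner call_inliner;
      // Run() takes the verified module by raw pointer; no .get() needed.
      TF_ASSERT_OK(call_inliner.Run(&module()).status());
    }
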
diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc
index e5a6c28478..96bd2616f5 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.cc
+++ b/tensorflow/compiler/xla/service/compile_only_service.cc
@@ -97,7 +97,7 @@ CompileOnlyService::CompileAheadOfTime(
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModule> hlo_module,
HloModule::CreateFromProto(instance.computation, *module_config));
- TF_RETURN_IF_ERROR(MaybeDumpHloModule(*hlo_module));
+ TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*hlo_module));
hlo_modules.push_back(std::move(hlo_module));
}
diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
index 0826380f65..0ac4a65ec6 100644
--- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
+++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
@@ -214,8 +214,8 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
expanded_filter = add(HloInstruction::CreateConcatenate(
expanded_filter_shape, concat_operands, input_feature_dim));
}
- auto zero = add(HloInstruction::CreateConstant(absl::make_unique<Literal>(
- LiteralUtil::Zero(expanded_filter_shape.element_type()))));
+ auto zero = add(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(expanded_filter_shape.element_type())));
auto zero_filter =
add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {}));
auto new_filter = add(
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 2368ac8c6a..8cc522a59e 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -122,7 +122,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:hlo_proto_util",
- "//tensorflow/compiler/xla/service:hlo_scheduling",
+ "//tensorflow/compiler/xla/service:hlo_memory_scheduler",
"//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/service:indexed_array_analysis",
@@ -801,6 +801,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
@@ -822,6 +823,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
@@ -946,6 +948,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo_graph_dumper",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
@@ -971,6 +974,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
index 05792795a1..2083f440fd 100644
--- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/test_helpers.h"
@@ -32,7 +32,7 @@ namespace cpu {
using ::testing::ElementsAre;
-class ConvCanonicalizationTest : public HloTestBase {
+class ConvCanonicalizationTest : public HloVerifiedTestBase {
public:
ConvCanonicalizationTest() {
for (int i = 0; i < 2; ++i) {
@@ -96,7 +96,7 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) {
return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
});
ConvCanonicalization conv_canonicalization(&target_machine_features);
- EXPECT_TRUE(conv_canonicalization.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(conv_canonicalization.Run(module).ValueOrDie());
const HloInstruction* output_reshape = entry_computation->root_instruction();
EXPECT_EQ(HloOpcode::kTranspose, output_reshape->opcode());
@@ -158,7 +158,7 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) {
return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
});
ConvCanonicalization conv_canonicalization(&target_machine_features);
- EXPECT_FALSE(conv_canonicalization.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(conv_canonicalization.Run(module).ValueOrDie());
}
} // namespace cpu
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index e7b6075994..18fc144efe 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -77,12 +77,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc
index 4db7fa446e..c9fb34be1c 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc
@@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test_benchmark.h"
@@ -52,7 +52,7 @@ int64 CountCopies(const HloModule& module) {
return count;
}
-class CpuCopyInsertionTest : public HloTestBase {
+class CpuCopyInsertionTest : public HloVerifiedTestBase {
protected:
void InsertCopies(HloModule* module) {
CpuCopyInsertion copy_insertion;
@@ -90,7 +90,7 @@ TEST_F(CpuCopyInsertionTest, WhileBodyWithConstantRoot) {
module->AddEntryComputation(builder.Build());
- InsertCopies(module.get());
+ InsertCopies(module);
EXPECT_EQ(CountCopies(*module), 3);
@@ -127,7 +127,7 @@ TEST_F(CpuCopyInsertionTest, TupleCall) {
module->AddEntryComputation(builder.Build());
- InsertCopies(module.get());
+ InsertCopies(module);
EXPECT_EQ(CountCopies(*subcomputation), 2);
EXPECT_THAT(subcomputation->root_instruction(),
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc
index 0f463e6de6..be1208fb2d 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -25,7 +25,7 @@ namespace {
using ::testing::HasSubstr;
-class CpuHloSupportCheckerTest : public HloTestBase {
+class CpuHloSupportCheckerTest : public HloVerifiedTestBase {
protected:
CpuHloSupportChecker& checker() { return checker_; }
@@ -45,7 +45,7 @@ TEST_F(CpuHloSupportCheckerTest, Add) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- TF_ASSERT_OK(checker().Run(module.get()).status());
+ TF_ASSERT_OK(checker().Run(module).status());
}
TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) {
@@ -60,7 +60,7 @@ TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- Status status = checker().Run(module.get()).status();
+ Status status = checker().Run(module).status();
ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
EXPECT_THAT(status.error_message(),
HasSubstr("CPU backend does not support"));
diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc
index 942e2ddd39..55d5925642 100644
--- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc
+++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc
@@ -37,21 +37,20 @@ int main(int argc, char** argv) {
xla::LocalClient* client(xla::ClientLibrary::LocalClientOrDie());
// Transfer parameters.
- std::unique_ptr<xla::Literal> param0_literal =
+ xla::Literal param0_literal =
xla::LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
std::unique_ptr<xla::GlobalData> param0_data =
- client->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client->TransferToServer(param0_literal).ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> param1_literal =
- xla::LiteralUtil::CreateR2<float>(
- {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}});
+ xla::Literal param1_literal = xla::LiteralUtil::CreateR2<float>(
+ {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}});
std::unique_ptr<xla::GlobalData> param1_data =
- client->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client->TransferToServer(param1_literal).ConsumeValueOrDie();
// Build computation.
xla::XlaBuilder builder("");
- auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Add(p1, p0, {0});
xla::StatusOr<xla::XlaComputation> computation_status = builder.Build();
@@ -59,17 +58,16 @@ int main(int argc, char** argv) {
// Execute and transfer result of computation.
xla::ExecutionProfile profile;
- xla::StatusOr<std::unique_ptr<xla::Literal>> result =
- client->ExecuteAndTransfer(
- computation,
- /*arguments=*/{param0_data.get(), param1_data.get()},
- /*execution_options=*/nullptr,
- /*execution_profile=*/&profile);
- std::unique_ptr<xla::Literal> actual = result.ConsumeValueOrDie();
+ xla::StatusOr<xla::Literal> result = client->ExecuteAndTransfer(
+ computation,
+ /*arguments=*/{param0_data.get(), param1_data.get()},
+ /*execution_options=*/nullptr,
+ /*execution_profile=*/&profile);
+ xla::Literal actual = result.ConsumeValueOrDie();
LOG(INFO) << absl::StrFormat("computation took %dns",
profile.compute_time_ns());
- LOG(INFO) << actual->ToString();
+ LOG(INFO) << actual.ToString();
return 0;
}
diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc
index 7d8e51f909..1a3d82de95 100644
--- a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc
@@ -19,14 +19,14 @@ limitations under the License.
#include <random>
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
namespace cpu {
namespace {
-class ShapePartitionAssignerTest : public HloTestBase {
+class ShapePartitionAssignerTest : public HloVerifiedTestBase {
protected:
typedef std::vector<int64> Vec;
@@ -91,7 +91,7 @@ TEST_F(ShapePartitionAssignerTest, Shape532WithLayout201) {
expected_partitions);
}
-class ShapePartitionIteratorTest : public HloTestBase {
+class ShapePartitionIteratorTest : public HloVerifiedTestBase {
protected:
typedef std::vector<std::pair<int64, int64>> Partition;
};
@@ -145,7 +145,7 @@ TEST_F(ShapePartitionIteratorTest, Shape532WithLayout210) {
}
}
-class RandomShapePartitionIteratorTest : public HloTestBase {
+class RandomShapePartitionIteratorTest : public HloVerifiedTestBase {
protected:
typedef std::vector<std::pair<int64, int64>> Partition;
RandomShapePartitionIteratorTest()
diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD
index f11aff0573..c55206eee7 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD
@@ -48,6 +48,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service/cpu:cpu_instruction_fusion",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
index 22721051e5..1deb412064 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
@@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test.h"
@@ -34,7 +34,7 @@ namespace xla {
namespace cpu {
namespace {
-class CpuFusionTest : public HloTestBase {
+class CpuFusionTest : public HloVerifiedTestBase {
protected:
CpuFusionTest() {}
@@ -45,7 +45,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) {
auto builder = HloComputation::Builder(TestName());
auto input_literal1 = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
auto input_literal2 = LiteralUtil::CreateR1<float>({-2.0, -42.0, 2.0});
- Shape vshape = input_literal1->shape();
+ Shape vshape = input_literal1.shape();
auto input1 = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal1)));
@@ -61,7 +61,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) {
module->AddEntryComputation(builder.Build());
CpuInstructionFusion fusion;
- EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(fusion.Run(module).ValueOrDie());
// The computation root instruction was fused. Verify the fusion instruction
// is now the root.
@@ -75,16 +75,16 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) {
EXPECT_EQ(4, fusion_instruction->fused_instruction_count());
// Compile and execute the computation.
- auto result = ExecuteAndTransfer(std::move(module), {});
+ auto result = ExecuteAndTransfer(module->Clone(), {});
// Check the output correctness.
- LiteralTestUtil::ExpectR1Near<float>({1.0, 40.0, -5.0}, *result, error_spec_);
+ LiteralTestUtil::ExpectR1Near<float>({1.0, 40.0, -5.0}, result, error_spec_);
}
TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
auto builder = HloComputation::Builder(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({-1.5, -2.5, -3.0});
- Shape vshape = input_literal->shape();
+ Shape vshape = input_literal.shape();
auto input = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
@@ -108,7 +108,7 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
module->AddEntryComputation(builder.Build());
CpuInstructionFusion fusion;
- EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(fusion.Run(module).ValueOrDie());
// The computation root instruction was fused. Verify the fusion instruction
// is now the root.
@@ -122,11 +122,10 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
EXPECT_EQ(8, fusion_instruction->fused_instruction_count());
// Compile and execute the computation.
- auto result = ExecuteAndTransfer(std::move(module), {});
+ auto result = ExecuteAndTransfer(module->Clone(), {});
// Check the output correctness.
- LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0}, *result,
- error_spec_);
+ LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0}, result, error_spec_);
}
TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
@@ -135,7 +134,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
auto module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({-1.5, -2.5, -3.0});
- Shape vshape = input_literal->shape();
+ Shape vshape = input_literal.shape();
auto input = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
@@ -184,7 +183,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
module->AddEntryComputation(builder.Build());
CpuInstructionFusion fusion;
- EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(fusion.Run(module).ValueOrDie());
// The computation root instruction was fused. Verify the fusion instruction
// is now the root.
@@ -209,11 +208,11 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
<< fusion_instruction2->fused_instructions_computation()->ToString();
// Compile and execute the computation.
- auto result = ExecuteAndTransfer(std::move(module), {});
+ auto result = ExecuteAndTransfer(module->Clone(), {});
// Check the output correctness.
LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0, 14.0, 40.0, 40.0},
- *result, error_spec_);
+ result, error_spec_);
}
TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) {
@@ -232,7 +231,7 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) {
// each fusion instruction to ensure that negate is not duplicated.
auto builder = HloComputation::Builder(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
- Shape vshape = input_literal->shape();
+ Shape vshape = input_literal.shape();
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
@@ -256,7 +255,7 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) {
// Run fusion.
CpuInstructionFusion fusion;
- EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(fusion.Run(module).ValueOrDie());
auto fusion1 = result->operand(0);
auto fusion2 = result->operand(1);
@@ -315,7 +314,7 @@ TEST_F(CpuFusionTest, DoNotDuplicateExpensiveOps) {
module->AddEntryComputation(builder.Build());
CpuInstructionFusion fusion;
- EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(fusion.Run(module).ValueOrDie());
// The only fusion instruction should be operand 0 of the tuple (formerly
// negate1).
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
index c35569c661..5cc6d01c0f 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
@@ -58,52 +58,52 @@ class InfeedTest : public ClientLibraryTestBase {
};
TEST_F(InfeedTest, SingleInfeedR0Bool) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR0<bool>(true));
+ TestInfeedRoundTrip(LiteralUtil::CreateR0<bool>(true));
}
TEST_F(InfeedTest, SingleInfeedR1U32) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR1<uint32>({1, 2, 3}));
+ TestInfeedRoundTrip(LiteralUtil::CreateR1<uint32>({1, 2, 3}));
}
TEST_F(InfeedTest, SingleInfeedR2F32) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
+ TestInfeedRoundTrip(LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
}
TEST_F(InfeedTest, SingleInfeedR3F32) {
TestInfeedRoundTrip(
- *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
- {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
+ LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+ {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
}
TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) {
const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2});
const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0});
- TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
+ TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
{{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
r3_dim0minor));
- TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
+ TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
{{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
r3_dim0major));
}
TEST_F(InfeedTest, SingleInfeedR4S32) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR4(
+ TestInfeedRoundTrip(LiteralUtil::CreateR4(
{{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
{{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
}
TEST_F(InfeedTest, SingleInfeedTuple) {
- TestInfeedRoundTrip(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<uint32>({1, 2, 3}).get(),
- LiteralUtil::CreateR0<bool>(false).get()}));
+ TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<uint32>({1, 2, 3}),
+ LiteralUtil::CreateR0<bool>(false)}));
}
TEST_F(InfeedTest, SingleInfeedEmptyTuple) {
- TestInfeedRoundTrip(*LiteralUtil::MakeTuple({}));
+ TestInfeedRoundTrip(LiteralUtil::MakeTuple({}));
}
// Tests Infeed operation used in a while loop, as in the code below. The
@@ -157,21 +157,21 @@ TEST_F(InfeedTest, DISABLED_SingleInfeedInWhile) {
// Send 5 Infeed data of shape F32[3].
ASSERT_IS_OK(
- client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({1, 2, 3})));
+ client_->TransferToInfeed(LiteralUtil::CreateR1<float>({1, 2, 3})));
ASSERT_IS_OK(
- client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({4, 5, 6})));
+ client_->TransferToInfeed(LiteralUtil::CreateR1<float>({4, 5, 6})));
ASSERT_IS_OK(
- client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({7, 8, 9})));
+ client_->TransferToInfeed(LiteralUtil::CreateR1<float>({7, 8, 9})));
ASSERT_IS_OK(
- client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({10, 11, 12})));
+ client_->TransferToInfeed(LiteralUtil::CreateR1<float>({10, 11, 12})));
ASSERT_IS_OK(
- client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({13, 14, 15})));
+ client_->TransferToInfeed(LiteralUtil::CreateR1<float>({13, 14, 15})));
delete computation_thread; // Joins the thread.
auto result_literal = client_->Transfer(*result).ConsumeValueOrDie();
// Only the first 3 infeed data should be added.
- LiteralTestUtil::ExpectR0Near<float>(45.0f, *result_literal, ErrorSpec{1e-7});
+ LiteralTestUtil::ExpectR0Near<float>(45.0f, result_literal, ErrorSpec{1e-7});
}
// Tests two Infeed operations with a total order. The order is enforced by
@@ -250,17 +250,17 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) {
// Send the first 4 Infeed data of shape Tuple(F32[2], PRED).
ASSERT_IS_OK(client_->TransferToInfeed(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2}).get(),
- LiteralUtil::CreateR0<bool>(true).get()})));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({1, 2}),
+ LiteralUtil::CreateR0<bool>(true)})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({3, 4}).get(),
- LiteralUtil::CreateR0<bool>(true).get()})));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({3, 4}),
+ LiteralUtil::CreateR0<bool>(true)})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({5, 6}).get(),
- LiteralUtil::CreateR0<bool>(true).get()})));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({5, 6}),
+ LiteralUtil::CreateR0<bool>(true)})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({7, 8}).get(),
- LiteralUtil::CreateR0<bool>(false).get()})));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({7, 8}),
+ LiteralUtil::CreateR0<bool>(false)})));
// Asynchronously launch the execution on the device.
std::unique_ptr<GlobalData> result;
@@ -275,21 +275,21 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) {
// Infeed data, and send the rest Infeed data of shape Tuple(F32[3], PRED).
sleep(1);
ASSERT_IS_OK(client_->TransferToInfeed(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2, 3}).get(),
- LiteralUtil::CreateR0<bool>(true).get()})));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({1, 2, 3}),
+ LiteralUtil::CreateR0<bool>(true)})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({7, 8, 9}).get(),
- LiteralUtil::CreateR0<bool>(false).get()})));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({7, 8, 9}),
+ LiteralUtil::CreateR0<bool>(false)})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({4, 5, 6}).get(),
- LiteralUtil::CreateR0<bool>(true).get()})));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({4, 5, 6}),
+ LiteralUtil::CreateR0<bool>(true)})));
// Wait for the execution to be done, and transfer the result.
delete computation_thread; // Joins the thread.
auto result_literal = client_->Transfer(*result).ConsumeValueOrDie();
// Only the first 6 infeed data should be added.
- LiteralTestUtil::ExpectR0Near<float>(66.0f, *result_literal, ErrorSpec{1e-7});
+ LiteralTestUtil::ExpectR0Near<float>(66.0f, result_literal, ErrorSpec{1e-7});
}
} // namespace
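
The infeed-test changes above use the two tuple-literal construction styles that remain after the migration: LiteralUtil::MakeTuple over pointers to already-built Literal values, and LiteralUtil::MakeTupleFromSlices over the element literals directly. A short sketch of both, assuming only the LiteralUtil calls that appear in this diff (the helper name is hypothetical):

    // Illustrative sketch only, not part of the diff.
    Literal MakeExampleTuple(bool use_slices) {
      if (use_slices) {
        // Element literals are passed directly; no pointer bookkeeping.
        return LiteralUtil::MakeTupleFromSlices(
            {LiteralUtil::CreateR1<uint32>({1, 2, 3}),
             LiteralUtil::CreateR0<bool>(false)});
      }
      // MakeTuple still takes pointers, now to value-typed Literals.
      Literal elements[] = {LiteralUtil::CreateR1<uint32>({1, 2, 3}),
                            LiteralUtil::CreateR0<bool>(false)};
      return LiteralUtil::MakeTuple({&elements[0], &elements[1]});
    }
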
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
index bb105194f1..7af51db55a 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
@@ -41,8 +41,7 @@ class CpuNoAliasTest : public CpuCodegenTest {};
TEST_F(CpuNoAliasTest, Concat) {
HloComputation::Builder builder(TestName());
- std::unique_ptr<Literal> literal =
- LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ Literal literal = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto param_shape = ShapeUtil::MakeShape(F32, {2, 2});
HloInstruction* param_x = builder.AddInstruction(
HloInstruction::CreateParameter(0, param_shape, "x"));
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
index 1b3be199f6..852f34e06d 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
@@ -56,9 +56,9 @@ ENTRY main {
}
)";
- std::unique_ptr<Literal> lhs = LiteralUtil::CreateR3<int32>({{{1}, {2}}});
- std::unique_ptr<Literal> rhs = LiteralUtil::CreateR3<int32>({{{3}, {4}}});
- RunTest(hlo_text, {lhs.get(), rhs.get()});
+ Literal lhs = LiteralUtil::CreateR3<int32>({{{1}, {2}}});
+ Literal rhs = LiteralUtil::CreateR3<int32>({{{3}, {4}}});
+ RunTest(hlo_text, {&lhs, &rhs});
}
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
index 8f6608241e..5fbd73a536 100644
--- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
+++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -30,7 +30,7 @@ limitations under the License.
namespace xla {
namespace {
-class FlattenCallGraphTest : public HloTestBase {
+class FlattenCallGraphTest : public HloVerifiedTestBase {
protected:
// Build and return a trivial computation taking and returning a scalar.
std::unique_ptr<HloComputation> MakeScalarComputation() {
@@ -139,9 +139,9 @@ TEST_F(FlattenCallGraphTest, ComplexGraph) {
}
{
- TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
EXPECT_TRUE(result);
- std::unique_ptr<CallGraph> flat_call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> flat_call_graph = CallGraph::Build(module);
const CallGraphNode& c_node = flat_call_graph->GetNode(c_computation);
EXPECT_EQ(1, c_node.caller_callsites().size());
}
@@ -176,15 +176,15 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) {
}
{
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);
EXPECT_EQ(2, cond_node.caller_callsites().size());
}
{
- TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
EXPECT_TRUE(result);
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);
EXPECT_EQ(1, cond_node.caller_callsites().size());
}
@@ -211,9 +211,9 @@ TEST_F(FlattenCallGraphTest, FlattenCalls) {
module->AddEntryComputation(
MakeCallingComputation(b_computation, /*callsites=*/2, ".Entry"));
- TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
EXPECT_TRUE(result);
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(7, module->computation_count());
const CallGraphNode& c_node = call_graph->GetNode(c_computation);
@@ -243,9 +243,9 @@ TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) {
module->AddEntryComputation(builder.Build());
EXPECT_EQ(2, module->computation_count());
- TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
EXPECT_TRUE(result);
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
// The true and false computations must now be different.
EXPECT_EQ(3, module->computation_count());
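The HloVerifiedTestBase migration above drops the `.get()` calls because, as the updated call sites suggest, the module is now a raw pointer owned and verified by the fixture rather than a std::unique_ptr held by the test. A minimal sketch of the resulting test shape, with a hypothetical pass and a placeholder computation builder:
// Sketch only: HypotheticalPass and BuildScalarComputation() are placeholders.
class HypotheticalPassTest : public HloVerifiedTestBase {};

TEST_F(HypotheticalPassTest, Runs) {
  HloModule* module = CreateNewModule();  // owned and verified by the fixture
  module->AddEntryComputation(BuildScalarComputation());
  TF_ASSERT_OK_AND_ASSIGN(bool changed, HypotheticalPass().Run(module));
  EXPECT_TRUE(changed);
  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);  // no .get()
  EXPECT_NE(call_graph, nullptr);
}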
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
index 4ed91ef187..bec02e14f9 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
@@ -125,7 +125,7 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync(
device_memory.size());
// Element is array-shaped: transfer array data to device buffer.
const auto subliteral = LiteralSlice(literal, index);
- std::unique_ptr<Literal> relayed_out_literal;
+ Literal relayed_out_literal;
const void* source;
if (LayoutUtil::Equal(device_subshape.layout(),
subliteral.shape().layout())) {
@@ -138,7 +138,7 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync(
// Relayout data before transferring.
relayed_out_literal = subliteral.Relayout(device_subshape.layout(),
/*shape_index=*/{});
- source = relayed_out_literal->untyped_data();
+ source = relayed_out_literal.untyped_data();
TF_RETURN_IF_ERROR(TransferBufferToDevice(
stream,
/*size=*/GetByteSizeRequirement(device_subshape), source,
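With Literal held by value, relayed_out_literal is declared before the branch so the relaid-out data stays alive while the buffer transfer is enqueued. A minimal sketch of that keep-alive pattern; `device_layout` and `subliteral` are placeholders standing in for the values used above:
// Sketch only.
Literal relayed;                // stays empty unless a relayout is required
const void* source = nullptr;
if (LayoutUtil::Equal(device_layout, subliteral.shape().layout())) {
  source = subliteral.untyped_data();            // no copy needed
} else {
  relayed = subliteral.Relayout(device_layout);  // value assignment, not a pointer
  source = relayed.untyped_data();               // valid as long as `relayed` lives
}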
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 6791e15ee0..64b9683628 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -108,6 +108,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
@@ -173,6 +174,7 @@ cc_library(
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:elemental_ir_emitter",
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:name_uniquer",
"//tensorflow/compiler/xla/service:while_loop_analysis",
"//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util",
@@ -370,6 +372,8 @@ cc_library(
srcs = ["ir_emission_utils.cc"],
hdrs = ["ir_emission_utils.h"],
deps = [
+ ":backend_configs",
+ ":cudnn_convolution_runner",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
@@ -395,6 +399,7 @@ cc_library(
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
@@ -813,9 +818,9 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/service:buffer_value",
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_memory_scheduler",
"//tensorflow/compiler/xla/service:hlo_ordering",
"//tensorflow/compiler/xla/service:hlo_reachability",
- "//tensorflow/compiler/xla/service:hlo_scheduling",
"@com_google_absl//absl/memory",
],
)
@@ -832,6 +837,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"@com_google_absl//absl/memory",
@@ -901,6 +907,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
index 05448d863d..3a23ac1d63 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
@@ -30,62 +31,32 @@ namespace gpu {
using se::dnn::AlgorithmDesc;
-ConvolutionThunk::ConvolutionThunk(
- CudnnConvKind convolution_kind, const BufferAllocation::Slice& input_buffer,
- const BufferAllocation::Slice& filter_buffer,
- const BufferAllocation::Slice& output_buffer,
- const BufferAllocation::Slice& tuple_result_buffer,
- const BufferAllocation::Slice& scratch_buffer, const Shape& input_shape,
- const Shape& filter_shape, const Shape& output_shape, const Window& window,
- const ConvolutionDimensionNumbers& dim_nums, int64 feature_group_count,
- int64 algorithm, bool tensor_ops_enabled, const HloInstruction* hlo)
- : Thunk(Kind::kConvolution, hlo),
- convolution_kind_(convolution_kind),
- input_buffer_(input_buffer),
- filter_buffer_(filter_buffer),
- output_buffer_(output_buffer),
- tuple_result_buffer_(tuple_result_buffer),
- scratch_buffer_(scratch_buffer),
- input_shape_(input_shape),
- filter_shape_(filter_shape),
- output_shape_(output_shape),
- window_(window),
- dim_nums_(dim_nums),
- feature_group_count_(feature_group_count),
- algorithm_(algorithm),
- tensor_ops_enabled_(tensor_ops_enabled) {}
-
Status ConvolutionThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream,
HloExecutionProfiler* profiler) {
- se::DeviceMemoryBase input_data =
- buffer_allocations.GetDeviceAddress(input_buffer_);
- se::DeviceMemoryBase filter_data =
- buffer_allocations.GetDeviceAddress(filter_buffer_);
- se::DeviceMemoryBase output_data =
- buffer_allocations.GetDeviceAddress(output_buffer_);
+ CudnnConvParams params;
+
+ params.input_buf = buffer_allocations.GetDeviceAddress(input_buffer_);
+ params.filter_buf = buffer_allocations.GetDeviceAddress(filter_buffer_);
+ params.output_buf = buffer_allocations.GetDeviceAddress(output_buffer_);
se::DeviceMemoryBase scratch =
buffer_allocations.GetDeviceAddress(scratch_buffer_);
- se::dnn::AlgorithmConfig algorithm_config(
- se::dnn::AlgorithmDesc(algorithm_, tensor_ops_enabled_));
+ TF_RETURN_IF_ERROR(PopulateCudnnConvParams(cudnn_call_, &params));
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
- TF_RETURN_IF_ERROR(RunCudnnConvolution(
- convolution_kind_, input_shape_, filter_shape_, output_shape_, input_data,
- filter_data, output_data, scratch, window_, dim_nums_,
- feature_group_count_, algorithm_config, stream));
+ TF_RETURN_IF_ERROR(RunCudnnConvolution(params, scratch, stream));
// Figure out which of output/input/filter is the result produced by
// this op, and write the result tuple.
void* result_ptr = [&] {
- switch (convolution_kind_) {
+ switch (params.kind) {
case CudnnConvKind::kForward:
- return output_data.opaque();
+ return params.output_buf.opaque();
case CudnnConvKind::kBackwardInput:
- return input_data.opaque();
+ return params.input_buf.opaque();
case CudnnConvKind::kBackwardFilter:
- return filter_data.opaque();
+ return params.filter_buf.opaque();
}
}();
void* ptrs[] = {result_ptr, scratch.opaque()};
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
index 68d67c40c5..d7d1f91fba 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
@@ -32,7 +33,7 @@ limitations under the License.
namespace xla {
namespace gpu {
-// This class stores everything that StreamExecutor needs to launch a BNN
+// This class stores everything that StreamExecutor needs to launch a DNN
// convolution. It is generated by IrEmitter.
//
// This is thread-compatible.
@@ -41,27 +42,24 @@ class ConvolutionThunk : public Thunk {
// Constructs a thunk for launching a DNN convolution. When run, it will
// write a tuple (result, scratch_memory) into `tuple_result_buffer`.
//
- // `algorithm` is a cudnn algorithm number. `algorithm == -1` indicates that
- // we should use the default (i.e. baseline) cudnn algorithm.
- //
// Note that "output" here doesn't refer to the output from running this
// thunk, but rather to the "output" of a hypothetical forward convolution
// that corresponds to this input+filter+output triple. That is, the result
// generated by this thunk is "output" for forward convs, "input" for
// backward-input convs, and "filter" for backward-filter convs.
- //
- // Semantics of null hlo_instruction argument are as in Thunk.
- ConvolutionThunk(CudnnConvKind convolution_kind,
- const BufferAllocation::Slice& input_buffer,
- const BufferAllocation::Slice& filter_buffer,
- const BufferAllocation::Slice& output_buffer,
- const BufferAllocation::Slice& tuple_result_buffer,
- const BufferAllocation::Slice& scratch_buffer,
- const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, const Window& window,
- const ConvolutionDimensionNumbers& dim_nums,
- int64 feature_group_count, int64 algorithm,
- bool tensor_ops_enabled, const HloInstruction* hlo);
+ ConvolutionThunk(const HloCustomCallInstruction* cudnn_call,
+ BufferAllocation::Slice input_slice,
+ BufferAllocation::Slice filter_slice,
+ BufferAllocation::Slice output_slice,
+ BufferAllocation::Slice scratch_slice,
+ BufferAllocation::Slice tuple_result_slice)
+ : Thunk(Kind::kConvolution, cudnn_call),
+ cudnn_call_(cudnn_call),
+ input_buffer_(std::move(input_slice)),
+ filter_buffer_(std::move(filter_slice)),
+ output_buffer_(std::move(output_slice)),
+ scratch_buffer_(std::move(scratch_slice)),
+ tuple_result_buffer_(std::move(tuple_result_slice)) {}
ConvolutionThunk(const ConvolutionThunk&) = delete;
ConvolutionThunk& operator=(const ConvolutionThunk&) = delete;
@@ -72,23 +70,12 @@ class ConvolutionThunk : public Thunk {
HloExecutionProfiler* profiler) override;
private:
- const CudnnConvKind convolution_kind_;
-
- const BufferAllocation::Slice input_buffer_;
- const BufferAllocation::Slice filter_buffer_;
- const BufferAllocation::Slice output_buffer_;
- const BufferAllocation::Slice tuple_result_buffer_;
- const BufferAllocation::Slice scratch_buffer_;
-
- const Shape input_shape_;
- const Shape filter_shape_;
- const Shape output_shape_;
-
- const Window window_;
- const ConvolutionDimensionNumbers dim_nums_;
- int64 feature_group_count_;
- int64 algorithm_;
- bool tensor_ops_enabled_;
+ const HloCustomCallInstruction* cudnn_call_;
+ BufferAllocation::Slice input_buffer_;
+ BufferAllocation::Slice filter_buffer_;
+ BufferAllocation::Slice output_buffer_;
+ BufferAllocation::Slice scratch_buffer_;
+ BufferAllocation::Slice tuple_result_buffer_;
};
} // namespace gpu
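For context, a minimal sketch of constructing the slimmed-down thunk: it now stores only buffer slices plus the custom-call instruction, and reads shapes, window, and algorithm from that instruction at execution time. The slice variables here are placeholders; the real call site added by this patch is in IrEmitterUnnested::HandleCustomCall further below.
// Sketch only: `custom_call` is a cudnn conv custom call; the slices come from
// buffer assignment (placeholders here).
auto thunk = absl::make_unique<ConvolutionThunk>(
    Cast<HloCustomCallInstruction>(custom_call),
    /*input_slice=*/input_slice, /*filter_slice=*/filter_slice,
    /*output_slice=*/output_slice, /*scratch_slice=*/scratch_slice,
    /*tuple_result_slice=*/tuple_result_slice);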
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
index 5c2555148a..f528e62b17 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/mutex.h"
@@ -176,10 +177,14 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) {
// caching would speed up compilation a lot.
StatusOr<std::tuple<int64, bool, int64>>
CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, const Window& window,
- const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
- HloInstruction* instr) {
+ const HloCustomCallInstruction* instr) {
+ CudnnConvParams params;
+ TF_RETURN_IF_ERROR(PopulateCudnnConvParams(instr, &params));
+
+ const Shape& input_shape = *params.input_shape;
+ const Shape& filter_shape = *params.filter_shape;
+ const Shape& output_shape = *params.output_shape;
+
CHECK_EQ(input_shape.element_type(), filter_shape.element_type());
CHECK_EQ(input_shape.element_type(), output_shape.element_type());
// TODO(timshen): for now only check fp16. It can be expanded to other types,
@@ -216,25 +221,12 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
allocator = &*se_allocator;
}
- // Allocate space for the input, filter, and output of the convolution. We
- // use a ScratchAllocator for this instead of calling allocator_ directly so
- // that our allocations don't leak.
- ScratchAllocator input_output_allocator(device_ordinal, allocator);
- TF_ASSIGN_OR_RETURN(DeviceMemoryBase input_buf,
- input_output_allocator.AllocateBytes(
- &stream, ShapeUtil::ByteSizeOf(input_shape)));
- TF_ASSIGN_OR_RETURN(DeviceMemoryBase filter_buf,
- input_output_allocator.AllocateBytes(
- &stream, ShapeUtil::ByteSizeOf(filter_shape)));
- TF_ASSIGN_OR_RETURN(DeviceMemoryBase output_buf,
- input_output_allocator.AllocateBytes(
- &stream, ShapeUtil::ByteSizeOf(output_shape)));
-
- if (cross_check_enabled) {
- // Broadcast a constant to the buffer, instead of zeroing the buffer. A
- // non-zero constant is useful for the cross checking, because zero-inputs
- // may not always reveal the bugs.
- const auto initialize_f16 = [&stream](DeviceMemoryBase buffer) {
+ const auto initialize_buffer = [&stream, cross_check_enabled](
+ DeviceMemoryBase buffer) {
+ if (cross_check_enabled) {
+ // Broadcast a constant to the buffer, instead of zeroing the buffer. A
+ // non-zero constant is useful for the cross checking, because zero-inputs
+ // may not always reveal the bugs.
CHECK_EQ(0, (uintptr_t)buffer.opaque() % 4);
size_t left_over_bytes = buffer.size() % 4;
CHECK_EQ(0, left_over_bytes % 2);
@@ -252,33 +244,46 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
DeviceMemoryBase left_over(
static_cast<char*>(buffer.opaque()) + aligned_size, left_over_bytes);
stream.ThenMemcpy(&left_over, halfs, left_over_bytes);
- };
- initialize_f16(input_buf);
- initialize_f16(filter_buf);
- initialize_f16(output_buf);
- } else {
- // Although we don't have evidence this matters, zero out the buffers before
- // autotuning. It's conceivable that using uninitialized memory as the
- // inputs might affect performance if e.g. the inputs contain denormals, and
- // this is easy enough.
- stream.ThenMemZero(&input_buf, input_buf.size())
- .ThenMemZero(&filter_buf, filter_buf.size())
- .ThenMemZero(&output_buf, output_buf.size());
- }
+ } else {
+ // Although we don't have evidence this matters, zero out the buffers
+ // before autotuning. It's conceivable that using uninitialized memory as
+ // the inputs might affect performance if e.g. the inputs contain
+ // denormals, and this is easy enough.
+ stream.ThenMemZero(&buffer, buffer.size());
+ }
+ };
+
+ // Allocate space for the input, filter, and output of the convolution. We
+ // use a ScratchAllocator for this instead of calling allocator_ directly so
+ // that our allocations don't leak.
+ ScratchAllocator input_output_allocator(device_ordinal, allocator);
+ TF_ASSIGN_OR_RETURN(params.input_buf,
+ input_output_allocator.AllocateBytes(
+ &stream, ShapeUtil::ByteSizeOf(input_shape)));
+ TF_ASSIGN_OR_RETURN(params.filter_buf,
+ input_output_allocator.AllocateBytes(
+ &stream, ShapeUtil::ByteSizeOf(filter_shape)));
+ TF_ASSIGN_OR_RETURN(params.output_buf,
+ input_output_allocator.AllocateBytes(
+ &stream, ShapeUtil::ByteSizeOf(output_shape)));
+
+ initialize_buffer(params.input_buf);
+ initialize_buffer(params.filter_buf);
+ initialize_buffer(params.output_buf);
DeviceMemoryBase* result_buf = [&] {
- switch (kind) {
+ switch (params.kind) {
case CudnnConvKind::kBackwardFilter:
- return &filter_buf;
+ return &params.filter_buf;
case CudnnConvKind::kBackwardInput:
- return &input_buf;
+ return &params.input_buf;
case CudnnConvKind::kForward:
- return &output_buf;
+ return &params.output_buf;
}
}();
const bool use_winograd_nonfused = ShouldIncludeWinogradNonfusedAlgo(
- input_shape, output_shape, dnums, stream_exec_);
+ input_shape, output_shape, *params.dnums, stream_exec_);
se::dnn::ProfileResult best_result;
int64 best_result_bytes_used = 0;
@@ -288,18 +293,16 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
// this algorithm considered correct, though.
optional<AlgorithmDesc> first_algorithm;
for (const AlgorithmDesc& alg :
- GetAlgorithms(kind, use_winograd_nonfused, stream_exec_)) {
+ GetAlgorithms(params.kind, use_winograd_nonfused, stream_exec_)) {
ScratchAllocator scratch_allocator(device_ordinal, allocator);
se::dnn::ProfileResult profile_result;
VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for "
<< instr->ToString();
- bool launch_ok =
- RunCudnnConvolution(
- kind, input_shape, filter_shape, output_shape, input_buf,
- filter_buf, output_buf, &scratch_allocator, window, dnums,
- feature_group_count, AlgorithmConfig(alg), &stream, &profile_result)
- .ok();
+ params.algorithm = AlgorithmConfig(alg);
+ bool launch_ok = RunCudnnConvolution(params, &scratch_allocator, &stream,
+ &profile_result)
+ .ok();
if (launch_ok && profile_result.is_valid()) {
const bool crash_on_checking_failure =
@@ -374,34 +377,8 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
HloInstruction* instr) {
CHECK(IsCustomCallToDnnConvolution(*instr));
- const auto& call_target = instr->custom_call_target();
- const auto& lhs_shape = instr->operand(0)->shape();
- const auto& rhs_shape = instr->operand(1)->shape();
- const auto& conv_result_shape = instr->shape().tuple_shapes(0);
- StatusOr<std::tuple<int64, bool, int64>> alg_scratch_and_tc;
- if (call_target == kCudnnConvForwardCallTarget) {
- alg_scratch_and_tc =
- PickBestAlgorithm(CudnnConvKind::kForward, /*input_shape=*/lhs_shape,
- /*filter_shape=*/rhs_shape,
- /*output_shape=*/conv_result_shape, instr->window(),
- instr->convolution_dimension_numbers(),
- instr->feature_group_count(), instr);
- } else if (call_target == kCudnnConvBackwardInputCallTarget) {
- alg_scratch_and_tc = PickBestAlgorithm(
- CudnnConvKind::kBackwardInput, /*input_shape=*/conv_result_shape,
- /*filter_shape=*/rhs_shape, /*output_shape=*/lhs_shape, instr->window(),
- instr->convolution_dimension_numbers(), instr->feature_group_count(),
- instr);
- } else if (call_target == kCudnnConvBackwardFilterCallTarget) {
- alg_scratch_and_tc = PickBestAlgorithm(
- CudnnConvKind::kBackwardFilter, /*input_shape=*/lhs_shape,
- /*filter_shape=*/conv_result_shape, /*output_shape=*/rhs_shape,
- instr->window(), instr->convolution_dimension_numbers(),
- instr->feature_group_count(), instr);
- } else {
- LOG(FATAL) << "Unknown custom call target for cudnn conv: "
- << instr->ToString();
- }
+ StatusOr<std::tuple<int64, bool, int64>> alg_scratch_and_tc =
+ PickBestAlgorithm(Cast<HloCustomCallInstruction>(instr));
if (!alg_scratch_and_tc.ok()) {
LOG(ERROR) << alg_scratch_and_tc.status();
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
index 0cb01161b0..f79b113f8f 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -49,10 +50,7 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface {
StatusOr<bool> RunOnComputation(HloComputation* computation);
StatusOr<bool> RunOnInstruction(HloInstruction* instr);
StatusOr<std::tuple<int64, bool, int64>> PickBestAlgorithm(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, const Window& window,
- const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
- HloInstruction* instr);
+ const HloCustomCallInstruction* instr);
se::StreamExecutor* stream_exec_; // never null
DeviceMemoryAllocator* allocator_; // may be null
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
index 9bf721ecd2..228379a248 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h"
+#include <cstdlib>
#include <numeric>
#include <vector>
@@ -59,8 +60,6 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardFilter(
HloInstruction* conv) {
const auto no_match_result =
std::make_tuple(false, Window(), ConvolutionDimensionNumbers());
- // TODO(b/31709653): Figure out if we can use grouped convolutions also on
- // backward filter.
if (conv->feature_group_count() > 1) {
return no_match_result;
}
@@ -218,13 +217,16 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardFilter(
// Try to match a backward input pattern that contains "conv".
// Precondition: "conv" is a kConvolution.
-std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardInput(
- HloInstruction* conv) {
+std::tuple<bool, Window, ConvolutionDimensionNumbers, HloInstruction*>
+MatchBackwardInput(HloInstruction* conv) {
const auto no_match_result =
- std::make_tuple(false, Window(), ConvolutionDimensionNumbers());
+ std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr);
- // TODO(b/31709653): Figure out if we can use grouped convolutions also on
- // backward input.
+ // TODO(b/31709653): Theoretically cuDNN supports grouped convolutions also
+ // for the backward input convolution, but at least for now with version 7.1.4
+ // it is slower. This needs to be re-evaluated for future cuDNN versions.
+ // Note that the necessary code already exists below; all that is needed to
+ // enable it is to remove the following early return.
if (conv->feature_group_count() > 1) {
return no_match_result;
}
@@ -232,51 +234,38 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardInput(
// Match instruction pattern.
CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
HloInstruction* reverse_filter = conv->mutable_operand(1);
-
- // Match the reverse of the filter.
ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers();
- const auto& kernel_spatial_dims = dnums.kernel_spatial_dimensions();
- if (reverse_filter->opcode() == HloOpcode::kReverse) {
- if (kernel_spatial_dims.size() != reverse_filter->dimensions().size() ||
- !std::is_permutation(kernel_spatial_dims.begin(),
- kernel_spatial_dims.end(),
- reverse_filter->dimensions().begin())) {
- VLOG(1)
- << "Backward input convolution should reverse all kernel dimensions.";
- return no_match_result;
- }
- } else if (reverse_filter->IsConstant()) {
- // If the filter is a constant, we're willing to pattern-match to a
- // backwards-input conv, on the theory that
- //
- // a) reversing a constant is free, and
- // b) even if the user specified this filter as reverse(constant), we would
- // long ago have constant-folded away the reverse.
- //
- // If the constant has any other uses, reversing it isn't entirely free,
- // since we'd now have two constants to keep in memory. But hopefully it's
- // free enough.
- //
- // TODO(jlebar): Should we do this even if the filter is not a constant?
- // Reversing a non-constant filter is probably cheaper than padding the
- // input!
-
- // Nothing to do, just fall through.
- } else {
- // Possibly 1x1 filter.
- for (int64 i = 0; i < kernel_spatial_dims.size(); ++i) {
- if (conv->window().dimensions(i).size() != 1) {
- VLOG(1) << "The reverse filter is neither a kReverse nor a 1x1 filter: "
- << reverse_filter->ToString();
- return no_match_result;
- }
- }
- if (!window_util::HasBaseDilation(conv->window())) {
- VLOG(1) << conv->ToString()
- << " is a regular forward convolution. No need "
- "to fold it to a backward input convolution.";
- return no_match_result;
- }
+
+ // We pattern-match to a backwards input conv if:
+ //
+ // - all spatial dims of the filter are reversed
+ //
+ // OR
+ //
+ // - filter is 1x1 or a constant AND
+ // - conv has base dilation (otherwise this is just a regular forward conv).
+ //
+ // The final criterion above is just for canonicalization; cudnn seems to run
+ // just as fast if we canonicalize 1x1/constant filters without base dilation
+ // to forward or backward convs. We canonicalize to forward conv because (a)
+ // it's more natural (constant filters usually show up when doing inference,
+ // and having backwards convolutions in inference graphs would be weird), and
+ // (b) cudnn has special fusions for forward conv plus bias and activation,
+ // and we want to pattern-match to that after running this pass.
+ bool is_reversed_filter =
+ reverse_filter->opcode() == HloOpcode::kReverse &&
+ absl::c_is_permutation(dnums.kernel_spatial_dimensions(),
+ reverse_filter->dimensions());
+ bool is_1x1_filter =
+ absl::c_all_of(conv->window().dimensions(),
+ [](const WindowDimension& d) { return d.size() == 1; });
+ if (!is_reversed_filter &&
+ !(window_util::HasBaseDilation(conv->window()) &&
+ (reverse_filter->IsConstant() || is_1x1_filter))) {
+ VLOG(1) << "Can't match to backwards convolution. Either filter is not "
+ "kReverse, or it's not a base-dilated conv with a 1x1 or "
+ "constant filter.";
+ return no_match_result;
}
// Match padding and dilation of the forward convolution.
@@ -401,26 +390,64 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardInput(
}
}
- // OK, it's a match! Canonicalize the conv's filter so that it's a reverse.
- // This simplifies things for our caller, and algebraic-simplifier will later
- // remove any unnecessary reverses.
- if (reverse_filter->opcode() != HloOpcode::kReverse) {
+ // OK, it's a match! Switch the input feature dimension with the output
+ // feature dimension. This is the way cuDNN expects it to be.
+ dnums.set_kernel_input_feature_dimension(
+ conv->convolution_dimension_numbers().kernel_output_feature_dimension());
+ dnums.set_kernel_output_feature_dimension(
+ conv->convolution_dimension_numbers().kernel_input_feature_dimension());
+
+ // If we matched against a constant, we need to add a reverse op that can be
+ // subsumed by the cuDNN call. algebraic-simplifier will later remove any
+ // unnecessary reverses.
+ if (reverse_filter->opcode() != HloOpcode::kReverse &&
+ reverse_filter->IsConstant()) {
// Create a double-reverse, which is a nop.
HloComputation* c = conv->parent();
- reverse_filter = c->AddInstruction(
- HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
- AsInt64Slice(kernel_spatial_dims)));
- reverse_filter = c->AddInstruction(
- HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
- AsInt64Slice(kernel_spatial_dims)));
+ reverse_filter = c->AddInstruction(HloInstruction::CreateReverse(
+ reverse_filter->shape(), reverse_filter,
+ AsInt64Slice(dnums.kernel_spatial_dimensions())));
+ reverse_filter = c->AddInstruction(HloInstruction::CreateReverse(
+ reverse_filter->shape(), reverse_filter,
+ AsInt64Slice(dnums.kernel_spatial_dimensions())));
TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_no=*/1, reverse_filter));
}
- dnums.set_kernel_input_feature_dimension(
- conv->convolution_dimension_numbers().kernel_output_feature_dimension());
- dnums.set_kernel_output_feature_dimension(
- conv->convolution_dimension_numbers().kernel_input_feature_dimension());
- return std::make_tuple(true, new_window, dnums);
+ // Calculate the 'rhs' that goes into the backward input convolution.
+ HloInstruction* rhs = reverse_filter;
+ // One reverse is subsumed by the cuDNN call.
+ if (rhs->opcode() == HloOpcode::kReverse) {
+ rhs = rhs->mutable_operand(0);
+ }
+ if (conv->feature_group_count() == 1) {
+ return std::make_tuple(true, new_window, dnums, rhs);
+ }
+
+ // Handle grouped convolutions. Because we swapped the input feature dimension
+ // with the output feature dimension, we need to also reshape the kernel so
+ // that the 'feature_group_count' parameter still makes sense. The
+ // 'feature_group_count' parameter essentially specifies how often the
+ // 'kernel_input_feature_dimension' is repeated. So when we swap these
+ // dimensions, we need to divide the new 'kernel_input_feature_dimension' by
+ // 'feature_group_count' and multiply the new
+ // 'kernel_output_feature_dimension' by 'feature_group_count'.
+ Shape new_shape = rhs->shape();
+ int64 input_feature_dimension = dnums.kernel_input_feature_dimension();
+ int64 output_feature_dimension = dnums.kernel_output_feature_dimension();
+
+ // In the backward convolution case, the spatial dimensions become the
+ // feature dimensions, and we are guaranteed that the spatial dimensions are
+ // adjacent.
+ CHECK_EQ(std::abs(input_feature_dimension - output_feature_dimension), 1LL);
+ int64 input_features = new_shape.dimensions(input_feature_dimension);
+ int64 output_features = new_shape.dimensions(output_feature_dimension);
+ new_shape.set_dimensions(input_feature_dimension,
+ input_features / conv->feature_group_count());
+ new_shape.set_dimensions(output_feature_dimension,
+ output_features * conv->feature_group_count());
+ HloComputation* c = conv->parent();
+ rhs = c->AddInstruction(HloInstruction::CreateReshape(new_shape, rhs));
+ return std::make_tuple(true, new_window, dnums, rhs);
}
// Tries to rewrite a single convolution into a call to cudnn.
@@ -431,6 +458,7 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
bool match;
Window window;
ConvolutionDimensionNumbers dnums;
+ HloInstruction* rhs;
std::tie(match, window, dnums) = MatchBackwardFilter(conv);
if (match) {
@@ -439,13 +467,8 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
window, dnums, conv->feature_group_count());
}
- std::tie(match, window, dnums) = MatchBackwardInput(conv);
+ std::tie(match, window, dnums, rhs) = MatchBackwardInput(conv);
if (match) {
- // Backward input conv subsumes the conv plus the reverse in operand 1.
- HloInstruction* reverse = conv->mutable_operand(1);
- CHECK_EQ(reverse->opcode(), HloOpcode::kReverse);
- HloInstruction* rhs = reverse->mutable_operand(0);
-
return CreateCudnnConvBackwardInput(conv->shape(),
conv->mutable_operand(0), rhs, window,
dnums, conv->feature_group_count());
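The grouped-convolution reshape above can be sanity-checked with a small worked example; the numbers below are illustrative only and do not come from this patch.
// Illustrative only: suppose that, after swapping the kernel input/output
// feature dimensions in dnums, the kernel has 32 elements along the (new)
// input feature dimension and 16 along the (new) output feature dimension,
// with feature_group_count = 4.
const int64 feature_group_count = 4;   // int64 is the XLA typedef
const int64 input_features = 32;       // dim at dnums.kernel_input_feature_dimension()
const int64 output_features = 16;      // dim at dnums.kernel_output_feature_dimension()
const int64 new_input_features = input_features / feature_group_count;    // 8
const int64 new_output_features = output_features * feature_group_count;  // 64
// The reshape preserves the element count: 32 * 16 == 8 * 64.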
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc
index bda8ebe579..d237f8930b 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc
@@ -590,7 +590,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveConstantFilter) {
Array4D<float> constant_arr(4, 4, 2, 2);
constant_arr.FillIota(0);
string constant_str =
- LiteralUtil::CreateR4FromArray4D(constant_arr)->ToString();
+ LiteralUtil::CreateR4FromArray4D(constant_arr).ToString();
ParseAndVerifyModule(absl::StrFormat(R"(
HloModule test
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
index 05125e9d1f..2a86ac265e 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
@@ -72,14 +72,22 @@ class ScratchBufAllocator : public se::ScratchAllocator {
};
template <typename T>
-Status RunCudnnConvolution(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, DeviceMemory<T> input_buf,
- DeviceMemory<T> filter_buf, DeviceMemory<T> output_buf,
- se::ScratchAllocator* scratch_allocator, const Window& window,
- const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
- AlgorithmConfig algorithm, Stream* stream,
- ProfileResult* profile_result /*= nullptr*/) {
+Status RunCudnnConvolutionImpl(CudnnConvParams params,
+ se::ScratchAllocator* scratch_allocator,
+ se::Stream* stream,
+ se::dnn::ProfileResult* profile_result) {
+ CudnnConvKind kind = params.kind;
+ const Shape& input_shape = *params.input_shape;
+ const Shape& filter_shape = *params.filter_shape;
+ const Shape& output_shape = *params.output_shape;
+ DeviceMemory<T> input_buf(params.input_buf);
+ DeviceMemory<T> filter_buf(params.filter_buf);
+ DeviceMemory<T> output_buf(params.output_buf);
+ const Window& window = *params.window;
+ const ConvolutionDimensionNumbers& dnums = *params.dnums;
+ int64 feature_group_count = params.feature_group_count;
+ AlgorithmConfig algorithm = params.algorithm;
+
VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm().algo_id();
VLOG(3) << "tensor_ops_enabled: "
<< algorithm.algorithm().tensor_ops_enabled();
@@ -219,54 +227,31 @@ string CudnnConvKindToString(CudnnConvKind kind) {
}
}
-Status RunCudnnConvolution(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, se::DeviceMemoryBase input_buf,
- se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
- se::DeviceMemoryBase scratch_buf, const Window& window,
- const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
- se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
- se::dnn::ProfileResult* profile_result) {
+Status RunCudnnConvolution(CudnnConvParams params,
+ se::DeviceMemoryBase scratch_buf, se::Stream* stream,
+ se::dnn::ProfileResult* profile_result) {
ScratchBufAllocator scratch_allocator(scratch_buf);
- return RunCudnnConvolution(
- kind, input_shape, filter_shape, output_shape, input_buf, filter_buf,
- output_buf, &scratch_allocator, window, dnums, feature_group_count,
- algorithm, stream, profile_result);
+ return RunCudnnConvolution(params, &scratch_allocator, stream,
+ profile_result);
}
-Status RunCudnnConvolution(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, se::DeviceMemoryBase input_buf,
- se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
- se::ScratchAllocator* scratch_allocator, const Window& window,
- const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
- se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
- se::dnn::ProfileResult* profile_result) {
- PrimitiveType output_primitive_type = output_shape.element_type();
+Status RunCudnnConvolution(CudnnConvParams params,
+ se::ScratchAllocator* scratch_allocator,
+ se::Stream* stream,
+ se::dnn::ProfileResult* profile_result) {
+ PrimitiveType output_primitive_type = params.output_shape->element_type();
switch (output_primitive_type) {
case F16:
- return RunCudnnConvolution(
- kind, input_shape, filter_shape, output_shape,
- se::DeviceMemory<Eigen::half>(input_buf),
- se::DeviceMemory<Eigen::half>(filter_buf),
- se::DeviceMemory<Eigen::half>(output_buf), scratch_allocator, window,
- dnums, feature_group_count, algorithm, stream, profile_result);
+ return RunCudnnConvolutionImpl<Eigen::half>(params, scratch_allocator,
+ stream, profile_result);
case F32:
- return RunCudnnConvolution(
- kind, input_shape, filter_shape, output_shape,
- se::DeviceMemory<float>(input_buf),
- se::DeviceMemory<float>(filter_buf),
- se::DeviceMemory<float>(output_buf), scratch_allocator, window, dnums,
- feature_group_count, algorithm, stream, profile_result);
+ return RunCudnnConvolutionImpl<float>(params, scratch_allocator, stream,
+ profile_result);
case F64:
- return RunCudnnConvolution(
- kind, input_shape, filter_shape, output_shape,
- se::DeviceMemory<double>(input_buf),
- se::DeviceMemory<double>(filter_buf),
- se::DeviceMemory<double>(output_buf), scratch_allocator, window,
- dnums, feature_group_count, algorithm, stream, profile_result);
+ return RunCudnnConvolutionImpl<double>(params, scratch_allocator, stream,
+ profile_result);
default:
- LOG(FATAL) << ShapeUtil::HumanString(output_shape);
+ LOG(FATAL) << ShapeUtil::HumanString(*params.output_shape);
}
}
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
index a1b4fc71d0..381aa37a1b 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
@@ -47,6 +47,20 @@ enum class CudnnConvKind {
kBackwardFilter, // input + output => filter
};
+struct CudnnConvParams {
+ CudnnConvKind kind;
+ const Shape* input_shape;
+ const Shape* filter_shape;
+ const Shape* output_shape;
+ se::DeviceMemoryBase input_buf;
+ se::DeviceMemoryBase filter_buf;
+ se::DeviceMemoryBase output_buf;
+ const Window* window;
+ const ConvolutionDimensionNumbers* dnums;
+ int64 feature_group_count;
+ se::dnn::AlgorithmConfig algorithm;
+};
+
// Converts a CudnnConvKind value to a string.
string CudnnConvKindToString(CudnnConvKind kind);
@@ -55,10 +69,9 @@ string CudnnConvKindToString(CudnnConvKind kind);
// Note that depending on the value of CudnnConvKind, the result of this call
// may be written into input_buf, filter_buf, or output_buf!
//
-// At the moment we only support cudnn convolutions over float and half, and
-// convolution with half data type is implemented with cudnn PSEUDO_HALF
-// configuration, that is, the input values are half and the internal
-// computation type is float.
+// At the moment convolution with half data type is implemented with cudnn
+// PSEUDO_HALF configuration, that is, the input values are half and the
+// internal computation type is float.
//
// We provide one overload which takes a scratch buffer, and another which takes
// an allocator which is responsible for allocating the scratch space. In
@@ -70,23 +83,14 @@ string CudnnConvKindToString(CudnnConvKind kind);
// allocator and take note of how much memory is used. The next time you call
// the same conv, you can provide an explicitly preallocated scratch buffer of
// that size, if you like.
-Status RunCudnnConvolution(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, se::DeviceMemoryBase input_buf,
- se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
- se::DeviceMemoryBase scratch_buf, const Window& window,
- const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
- se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
- se::dnn::ProfileResult* profile_result = nullptr);
+Status RunCudnnConvolution(CudnnConvParams params,
+ se::DeviceMemoryBase scratch_buf, se::Stream* stream,
+ se::dnn::ProfileResult* profile_result = nullptr);
-Status RunCudnnConvolution(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, se::DeviceMemoryBase input_buf,
- se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
- se::ScratchAllocator* scratch_allocator, const Window& window,
- const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
- se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
- se::dnn::ProfileResult* profile_result = nullptr);
+Status RunCudnnConvolution(CudnnConvParams params,
+ se::ScratchAllocator* scratch_allocator,
+ se::Stream* stream,
+ se::dnn::ProfileResult* profile_result = nullptr);
} // namespace gpu
} // namespace xla
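To make the new interface concrete, here is a minimal sketch of filling a CudnnConvParams and invoking the scratch-buffer overload. Every name and buffer below is a placeholder; real callers in this patch populate the params via PopulateCudnnConvParams() from the cudnn custom-call instruction.
// Sketch only: placeholder shapes, buffers, window, and dnums.
CudnnConvParams params;
params.kind = CudnnConvKind::kForward;
params.input_shape = &input_shape;    // Shape of the conv input
params.filter_shape = &filter_shape;  // Shape of the filter
params.output_shape = &output_shape;  // Shape of the conv result
params.input_buf = input_buf;         // se::DeviceMemoryBase
params.filter_buf = filter_buf;
params.output_buf = output_buf;
params.window = &window;
params.dnums = &dnums;
params.feature_group_count = 1;
params.algorithm = se::dnn::AlgorithmConfig(
    se::dnn::AlgorithmDesc(/*algo_id=*/0, /*use_tensor_ops=*/false));
TF_CHECK_OK(RunCudnnConvolution(params, scratch_buf, stream));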
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc
index ea9376e101..02a0d028c1 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc
@@ -21,9 +21,9 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc
index 59ade96f7d..b857fa775a 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc
@@ -24,14 +24,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
namespace gpu {
-class GpuHloScheduleTest : public HloTestBase {
+class GpuHloScheduleTest : public HloVerifiedTestBase {
protected:
using HloVec = std::vector<const HloInstruction*>;
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc
index 0a4089df4c..27a4d0b601 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -25,7 +25,7 @@ namespace {
using ::testing::HasSubstr;
-class GpuHloSupportCheckerTest : public HloTestBase {
+class GpuHloSupportCheckerTest : public HloVerifiedTestBase {
protected:
GpuHloSupportChecker& checker() { return checker_; }
@@ -45,7 +45,7 @@ TEST_F(GpuHloSupportCheckerTest, Add) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- TF_ASSERT_OK(checker().Run(module.get()).status());
+ TF_ASSERT_OK(checker().Run(module).status());
}
TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) {
@@ -60,7 +60,7 @@ TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- Status status = checker().Run(module.get()).status();
+ Status status = checker().Run(module).status();
ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
EXPECT_THAT(status.error_message(),
HasSubstr("GPU backend does not support"));
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 20d523abe0..22f43bc08b 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -287,5 +288,42 @@ llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
value->getType());
}
+Status PopulateCudnnConvParams(const HloCustomCallInstruction* custom_call,
+ CudnnConvParams* params) {
+ TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
+ custom_call->backend_config<CudnnConvBackendConfig>());
+ const auto& target = custom_call->custom_call_target();
+ const auto& lhs_shape = custom_call->operand(0)->shape();
+ const auto& rhs_shape = custom_call->operand(1)->shape();
+ const auto& conv_result_shape = custom_call->shape().tuple_shapes(0);
+
+ params->window = &custom_call->window();
+ params->dnums = &custom_call->convolution_dimension_numbers();
+ params->feature_group_count = custom_call->feature_group_count();
+ params->algorithm = se::dnn::AlgorithmConfig(se::dnn::AlgorithmDesc(
+ backend_config.algorithm(), backend_config.tensor_ops_enabled()));
+
+ if (target == kCudnnConvForwardCallTarget) {
+ params->kind = CudnnConvKind::kForward;
+ params->input_shape = &lhs_shape;
+ params->filter_shape = &rhs_shape;
+ params->output_shape = &conv_result_shape;
+ } else if (target == kCudnnConvBackwardInputCallTarget) {
+ params->kind = CudnnConvKind::kBackwardInput;
+ params->input_shape = &conv_result_shape;
+ params->filter_shape = &rhs_shape;
+ params->output_shape = &lhs_shape;
+ } else if (target == kCudnnConvBackwardFilterCallTarget) {
+ params->kind = CudnnConvKind::kBackwardFilter;
+ params->input_shape = &lhs_shape;
+ params->filter_shape = &conv_result_shape;
+ params->output_shape = &rhs_shape;
+ } else {
+ LOG(FATAL) << "Unexpected custom call target: "
+ << custom_call->custom_call_target();
+ }
+ return Status::OK();
+}
+
} // namespace gpu
} // namespace xla
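A brief sketch of the intended call pattern for the helper above; `instr` is a placeholder for any HloInstruction already known to be a cudnn convolution custom call.
// Sketch only: callers first check IsCustomCallToDnnConvolution(*instr), then:
const auto* custom_call = Cast<HloCustomCallInstruction>(instr);
CudnnConvParams params;
TF_CHECK_OK(PopulateCudnnConvParams(custom_call, &params));
// params.kind, *params.input_shape, etc. are now set from the custom call;
// the device-memory fields (input_buf and friends) are left for the caller.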
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
index 59c65fc268..09c455cc1e 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
@@ -20,7 +20,9 @@ limitations under the License.
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
// TODO(jlebar): Move functions related to cublas/cudnn to a separate file; they
// don't belong in "ir_emission_utils".
@@ -148,6 +150,11 @@ llvm::Value* EmitPrintf(absl::string_view fmt,
llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
llvm::IRBuilder<>* builder);
+// Populates params using conv, which must be a custom-call to a cudnn
+// convolution. Does not modify any buffers in the params.
+Status PopulateCudnnConvParams(const HloCustomCallInstruction* custom_call,
+ CudnnConvParams* params);
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index f91cc00d71..b669881026 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -61,6 +61,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/while_thunk.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -464,67 +465,35 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
if (IsCustomCallToDnnConvolution(*custom_call)) {
const auto& assn = ir_emitter_context_->buffer_assignment();
- const auto& lhs_shape = custom_call->operand(0)->shape();
- const auto& rhs_shape = custom_call->operand(1)->shape();
- const auto& conv_result_shape = custom_call->shape().tuple_shapes(0);
auto lhs_slice = GetAllocationSlice(*custom_call->operand(0));
auto rhs_slice = GetAllocationSlice(*custom_call->operand(1));
auto tuple_result_slice = GetAllocationSlice(*custom_call);
auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie();
auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
- TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
- custom_call->backend_config<CudnnConvBackendConfig>());
const auto& target = custom_call->custom_call_target();
- std::unique_ptr<ConvolutionThunk> thunk;
+ BufferAllocation::Slice input_slice, filter_slice, output_slice;
+
if (target == kCudnnConvForwardCallTarget) {
- thunk = absl::make_unique<ConvolutionThunk>(
- CudnnConvKind::kForward,
- /*input_buffer=*/lhs_slice,
- /*filter_buffer=*/rhs_slice,
- /*output_buffer=*/conv_result_slice,
- /*tuple_result_buffer=*/tuple_result_slice,
- /*scratch_buffer=*/scratch_slice,
- /*input_shape=*/lhs_shape,
- /*filter_shape=*/rhs_shape,
- /*output_shape=*/conv_result_shape, //
- custom_call->window(), custom_call->convolution_dimension_numbers(),
- custom_call->feature_group_count(), backend_config.algorithm(),
- backend_config.tensor_ops_enabled(), custom_call);
+ input_slice = lhs_slice;
+ filter_slice = rhs_slice;
+ output_slice = conv_result_slice;
} else if (target == kCudnnConvBackwardInputCallTarget) {
- thunk = absl::make_unique<ConvolutionThunk>(
- CudnnConvKind::kBackwardInput,
- /*input_buffer=*/conv_result_slice,
- /*filter_buffer=*/rhs_slice,
- /*output_buffer=*/lhs_slice,
- /*tuple_result_buffer=*/tuple_result_slice,
- /*scratch_buffer=*/scratch_slice,
- /*input_shape=*/conv_result_shape,
- /*filter_shape=*/rhs_shape,
- /*output_shape=*/lhs_shape, //
- custom_call->window(), custom_call->convolution_dimension_numbers(),
- custom_call->feature_group_count(), backend_config.algorithm(),
- backend_config.tensor_ops_enabled(), custom_call);
+ input_slice = conv_result_slice;
+ filter_slice = rhs_slice;
+ output_slice = lhs_slice;
} else if (target == kCudnnConvBackwardFilterCallTarget) {
- thunk = absl::make_unique<ConvolutionThunk>(
- CudnnConvKind::kBackwardFilter,
- /*input_buffer=*/lhs_slice,
- /*filter_buffer=*/conv_result_slice,
- /*output_buffer=*/rhs_slice,
- /*tuple_result_buffer=*/tuple_result_slice,
- /*scratch_buffer=*/scratch_slice,
- /*input_shape=*/lhs_shape,
- /*filter_shape=*/conv_result_shape,
- /*output_shape=*/rhs_shape, //
- custom_call->window(), custom_call->convolution_dimension_numbers(),
- custom_call->feature_group_count(), backend_config.algorithm(),
- backend_config.tensor_ops_enabled(), custom_call);
+ input_slice = lhs_slice;
+ filter_slice = conv_result_slice;
+ output_slice = rhs_slice;
} else {
LOG(FATAL) << "Unexpected custom call target: "
<< custom_call->custom_call_target();
}
- thunk_sequence_->emplace_back(std::move(thunk));
+ thunk_sequence_->emplace_back(absl::make_unique<ConvolutionThunk>(
+ Cast<HloCustomCallInstruction>(custom_call), input_slice, filter_slice,
+ output_slice, scratch_slice, tuple_result_slice));
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index f6325b3368..dfdcf1875d 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -208,10 +208,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
/*allow_mixed_precision=*/false);
pipeline.AddPass<CudnnConvolutionRewriter>();
- // CudnnConvolutionRewriter may add instructions of the form
- // reverse(constant), which it expects will be simplified by constant
- // folding.
- pipeline.AddPass<HloConstantFolding>();
pipeline.AddPass<PadInsertion>();
if (IsVoltaOrLater(*stream_exec)) {
pipeline.AddPass<PadForTensorCores>();
@@ -219,6 +215,9 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
// pairs that TupleSimplifier fixes.
pipeline.AddPass<TupleSimplifier>();
}
+ // CudnnConvolutionRewriter, PadInsertion and PadForTensorCores may add
+ // instructions which can be simplified by constant folding.
+ pipeline.AddPass<HloConstantFolding>();
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
}
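As a sketch of the resulting pass ordering (simplified; the real pipeline in this file adds more passes and invariant checkers than shown here):
HloPassPipeline pipeline("conv_canonicalization");  // name is illustrative
pipeline.AddPass<CudnnConvolutionRewriter>();
pipeline.AddPass<PadInsertion>();
// PadForTensorCores (plus TupleSimplifier) is only added on Volta and later.
// Constant folding now runs once, after all passes that may emit foldable
// instructions such as reverse(constant) or pad(constant).
pipeline.AddPass<HloConstantFolding>();
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());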
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
index fa84d77223..b0061fa655 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
@@ -23,7 +23,6 @@ limitations under the License.
namespace xla {
namespace gpu {
-
// We want the input/output feature counts of an f16 conv to be factors of 8,
// because without this cudnn can't use tensor cores on the conv.
static constexpr int64 kDesiredNumFeaturesFactor = 8;
@@ -63,8 +62,8 @@ static HloInstruction* PadInstruction(HloInstruction* instr,
HloComputation* comp = instr->parent();
const Shape& shape = instr->shape();
- auto* zero = comp->AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::Zero(shape.element_type()).CloneToUnique()));
+ auto* zero = comp->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type())));
PaddingConfig pad_config = MakeNoPaddingConfig(ShapeUtil::Rank(shape));
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index 9d85d746d8..2a6415d0b6 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -68,9 +68,8 @@ HloInstruction* MaybePaddedAndSlicedInput(
conv_window.dimensions(i).base_dilation() - 1);
}
PrimitiveType element_type = input->shape().element_type();
- HloInstruction* padding =
- computation->AddInstruction(HloInstruction::CreateConstant(
- absl::make_unique<Literal>(LiteralUtil::Zero(element_type))));
+ HloInstruction* padding = computation->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
input = MakePadHlo(input, padding, padding_config).ValueOrDie();
}
@@ -125,9 +124,8 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window,
HloComputation* computation = kernel->parent();
PrimitiveType element_type = kernel->shape().element_type();
- HloInstruction* padding =
- computation->AddInstruction(HloInstruction::CreateConstant(
- absl::make_unique<Literal>(LiteralUtil::Zero(element_type))));
+ HloInstruction* padding = computation->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
return MakePadHlo(kernel, padding, padding_config).ValueOrDie();
}
} // namespace
@@ -236,9 +234,9 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
// Create a new backward convolution replacing the old one.
HloComputation* computation = backward_conv->parent();
HloInstruction* output = backward_conv->mutable_operand(1);
- HloInstruction* padding = computation->AddInstruction(
- HloInstruction::CreateConstant(absl::make_unique<Literal>(
- LiteralUtil::Zero(input->shape().element_type()))));
+ HloInstruction* padding =
+ computation->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(input->shape().element_type())));
HloInstruction* padded_input =
MakePadHlo(input, padding, input_padding_config).ValueOrDie();
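
The pad_for_tensor_cores.cc and pad_insertion.cc hunks are all the same mechanical change: HloInstruction::CreateConstant now takes a Literal by value instead of a std::unique_ptr<Literal>, so the absl::make_unique<Literal>(...) wrapper disappears. The pattern, as a sketch:

    // Before: the zero literal had to be heap-allocated and handed over as a
    // unique_ptr:
    //   HloInstruction* padding =
    //       computation->AddInstruction(HloInstruction::CreateConstant(
    //           absl::make_unique<Literal>(LiteralUtil::Zero(element_type))));
    // After: LiteralUtil::Zero returns a Literal value that is moved straight
    // into the constant instruction.
    HloInstruction* padding = computation->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));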
diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
index 8f0dedfa40..c4f43cc9a6 100644
--- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
@@ -21,14 +21,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
namespace gpu {
-class StreamAssignmentTest : public HloTestBase {
+class StreamAssignmentTest : public HloVerifiedTestBase {
protected:
std::unique_ptr<HloModule> CreateNewModule() {
HloModuleConfig config;
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc
index 4550f36fdf..780539c164 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc
@@ -38,8 +38,7 @@ class GpuCopyTest : public GpuCodegenTest {};
TEST_F(GpuCopyTest, UseMemcpy) {
HloComputation::Builder builder(TestName());
- std::unique_ptr<Literal> literal =
- LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ Literal literal = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
HloInstruction* constant = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
builder.AddInstruction(HloInstruction::CreateUnary(
diff --git a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc
index 9072b30317..f8120a5fa0 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc
@@ -53,40 +53,40 @@ class InfeedTest : public ClientLibraryTestBase {
};
TEST_F(InfeedTest, SingleInfeedR0Bool) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR0<bool>(true));
+ TestInfeedRoundTrip(LiteralUtil::CreateR0<bool>(true));
}
TEST_F(InfeedTest, SingleInfeedR1U32) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR1<uint32>({1, 2, 3}));
+ TestInfeedRoundTrip(LiteralUtil::CreateR1<uint32>({1, 2, 3}));
}
TEST_F(InfeedTest, SingleInfeedR2F32) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
+ TestInfeedRoundTrip(LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
}
TEST_F(InfeedTest, SingleInfeedR3F32) {
TestInfeedRoundTrip(
- *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
- {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
+ LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+ {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
}
TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) {
const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2});
const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0});
- TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
+ TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
{{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
r3_dim0minor));
- TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
+ TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
{{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
r3_dim0major));
}
TEST_F(InfeedTest, SingleInfeedR4S32) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR4(
+ TestInfeedRoundTrip(LiteralUtil::CreateR4(
{{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
{{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
}
@@ -95,26 +95,26 @@ TEST_F(InfeedTest, SingleInfeedR4S32) {
TEST_F(InfeedTest, LargeInfeed) {
Array4D<float> array(80, 100, 8, 128);
array.FillIota(1.0f);
- TestInfeedRoundTrip(*LiteralUtil::CreateR4FromArray4D<float>(array));
+ TestInfeedRoundTrip(LiteralUtil::CreateR4FromArray4D<float>(array));
}
TEST_F(InfeedTest, SingleInfeedTuple) {
- TestInfeedRoundTrip(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<uint32>({1, 2, 3}).get(),
- LiteralUtil::CreateR0<bool>(false).get()}));
+ TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<uint32>({1, 2, 3}),
+ LiteralUtil::CreateR0<bool>(false)}));
}
TEST_F(InfeedTest, SingleInfeedEmptyTuple) {
- TestInfeedRoundTrip(*LiteralUtil::MakeTuple({}));
+ TestInfeedRoundTrip(LiteralUtil::MakeTuple({}));
}
// Tests that a large tuple infeed can be handled.
TEST_F(InfeedTest, SingleInfeedLargeTuple) {
Array4D<float> array(40, 100, 8, 128);
array.FillIota(1.0f);
- TestInfeedRoundTrip(*LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR4FromArray4D<float>(array).get(),
- LiteralUtil::CreateR0<int32>(5).get()}));
+ TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR4FromArray4D<float>(array),
+ LiteralUtil::CreateR0<int32>(5)}));
}
} // namespace
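
All of the infeed tests change for the same reason: the LiteralUtil factories now return Literal by value, so TestInfeedRoundTrip receives the literal directly instead of a dereferenced unique_ptr, and tuple literals are assembled with MakeTupleFromSlices from the element literals themselves. A condensed sketch of the new style (the names flag, ints and tuple are illustrative):

    Literal flag  = LiteralUtil::CreateR0<bool>(true);
    Literal ints  = LiteralUtil::CreateR1<uint32>({1, 2, 3});
    Literal tuple = LiteralUtil::MakeTupleFromSlices({flag, ints});
    TestInfeedRoundTrip(flag);
    TestInfeedRoundTrip(tuple);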
diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
index 40183de96e..9a61f8ac5a 100644
--- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
@@ -26,9 +26,6 @@ limitations under the License.
namespace xla {
namespace {
-using ::testing::Eq;
-using ::testing::HasSubstr;
-
class WhileTransformerTest : public HloTestBase {
protected:
WhileTransformerTest()
diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc
index 00a25db467..957c4a6891 100644
--- a/tensorflow/compiler/xla/service/heap_simulator_test.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc
@@ -29,14 +29,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_value.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
namespace {
-class MinimumMemoryForSequenceTest : public HloTestBase {};
+class MinimumMemoryForSequenceTest : public HloVerifiedTestBase {};
TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
auto module = CreateNewModule();
@@ -86,7 +86,7 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
};
- HloSchedule schedule(module.get());
+ HloSchedule schedule(module);
schedule.set_sequence(cond_computation,
{cond_param, cond_iter, cond_data, cond_lt});
schedule.set_sequence(body_computation, {body_param});
@@ -233,7 +233,7 @@ class HeapSimulatorTracker {
HeapSimulator::Result result_;
};
-class HeapSimulatorTest : public HloTestBase {
+class HeapSimulatorTest : public HloVerifiedTestBase {
protected:
HeapSimulatorTest() {}
~HeapSimulatorTest() override {}
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index 93ec2c9438..b19ec12638 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -309,6 +309,13 @@ message HeapSimulatorTrace {
bool whole_module_simulation = 2;
}
+// An abstraction representing a set of HLO modules built to run concurrently
+// across different devices.
+message HloModuleGroupProto {
+ string name = 1;
+ repeated HloModuleProto hlo_modules = 2;
+}
+
// Serialization of BufferAssignment.
message BufferAssignmentProto {
// Alias represents a source LogicalBuffer, and the buffer location that
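
The HloModuleGroupProto added above is a named collection of HloModuleProto messages intended to run concurrently across devices. Assuming the standard protoc-generated C++ accessors for these fields (set_name, add_hlo_modules, hlo_modules_size) and placeholder HloModuleProto values for the two devices, populating one would look like:

    #include "tensorflow/compiler/xla/service/hlo.pb.h"

    xla::HloModuleGroupProto group;
    group.set_name("two_device_step");           // illustrative name
    *group.add_hlo_modules() = module_for_dev0;  // an xla::HloModuleProto
    *group.add_hlo_modules() = module_for_dev1;  // an xla::HloModuleProto
    CHECK_EQ(group.hlo_modules_size(), 2);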
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 233d2199d1..8c6903d766 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -562,9 +562,11 @@ HloComputation::CreateFromProto(
return to_proto_id[a.get()] < to_proto_id[b.get()];
});
- return absl::WrapUnique(new HloComputation(proto.name(), parameter_count,
- &instructions, root,
- /*fusion_instruction=*/nullptr));
+ auto computation = absl::WrapUnique(
+ new HloComputation(proto.name(), parameter_count, &instructions, root,
+ /*fusion_instruction=*/nullptr));
+ computation->unique_id_ = proto.id();
+ return std::move(computation);
}
void HloComputation::FuseInstructionsInto(
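
Besides carrying the proto id onto the deserialized computation (computation->unique_id_ = proto.id()), the hunk changes the return to build the object first, mutate it, and then move it into the StatusOr return type; the explicit std::move converts the named unique_ptr into the wrapper. A self-contained illustration of that construct-mutate-move shape, with a hypothetical Result class standing in for StatusOr:

    #include <memory>
    #include <utility>

    template <typename T>
    class Result {  // stand-in for StatusOr<T>
     public:
      Result(T value) : value_(std::move(value)) {}
      T& value() { return value_; }

     private:
      T value_;
    };

    Result<std::unique_ptr<int>> MakeAnswer() {
      auto p = std::make_unique<int>(41);
      *p += 1;              // mutate before handing it off, as CreateFromProto
                            // does when it records the proto id on the object
      return std::move(p);  // move the named unique_ptr into the Result wrapper
    }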
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
index 8a45939c61..f837816cea 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
@@ -76,10 +76,10 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
continue;
}
- std::unique_ptr<Literal> result = evaluator->TryEvaluate(instruction);
+ Literal result;
// Currently we skip unimplemented operations.
// TODO(b/35975797): Fold constant computations for more operations.
- if (result == nullptr) {
+ if (!evaluator->TryEvaluate(instruction, &result)) {
VLOG(2) << "Constant folding failed for instruction: "
<< instruction->ToString();
continue;
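
The folding loop now follows the new TryEvaluate contract: the caller supplies a Literal and gets a bool back instead of testing a returned unique_ptr for null. A hedged sketch of a typical caller; ReplaceWithNewInstruction is one plausible way to commit the folded value, mirroring what constant-folding passes usually do rather than quoting code from this hunk:

    HloEvaluator evaluator;
    Literal folded;
    if (evaluator.TryEvaluate(instruction, &folded)) {
      // `folded` now owns the constant value; swap the instruction for a
      // constant holding it.
      TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
          instruction, HloInstruction::CreateConstant(std::move(folded))));
    }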
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
index 07cd1efc12..3e0def5d26 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
@@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/types.h"
@@ -37,7 +37,7 @@ namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {
-using HloConstantFoldingTest = HloTestBase;
+using HloConstantFoldingTest = HloVerifiedTestBase;
TEST_F(HloConstantFoldingTest, ConvertF32ToS64) {
HloComputation::Builder builder(TestName());
@@ -52,7 +52,7 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) {
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);
EXPECT_THAT(computation->root_instruction(), op::Constant());
@@ -73,7 +73,7 @@ TEST_F(HloConstantFoldingTest, ConvertS64ToF32) {
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);
EXPECT_THAT(computation->root_instruction(), op::Constant());
@@ -94,7 +94,7 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) {
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);
EXPECT_THAT(computation->root_instruction(), op::Constant());
@@ -134,7 +134,7 @@ TEST_F(HloConstantFoldingTest, Concatenate) {
auto computation = module->AddEntryComputation(builder.Build());
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);
HloInstruction* root = computation->root_instruction();
@@ -161,7 +161,7 @@ TEST_F(HloConstantFoldingTest, Slice) {
auto computation = module->AddEntryComputation(builder.Build());
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);
HloInstruction* root = computation->root_instruction();
@@ -175,7 +175,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
TF_ASSERT_OK_AND_ASSIGN(auto literal,
LiteralUtil::CreateRandomLiteral<F32>(
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
- auto literal_clone = literal->Literal::CloneToUnique();
+ auto literal_clone = literal.Clone();
HloInstruction* literal_instruction = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5});
@@ -186,7 +186,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
auto computation = module->AddEntryComputation(builder.Build());
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);
HloInstruction* root = computation->root_instruction();
@@ -198,7 +198,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
root->literal().EachCell<NativeT>(
[&](absl::Span<const int64> indices, NativeT value) {
std::vector<int64> rindexes = Permute(permutation, indices);
- matched = matched && (value == literal_clone->Get<NativeT>(rindexes));
+ matched = matched && (value == literal_clone.Get<NativeT>(rindexes));
});
EXPECT_TRUE(matched);
}
@@ -219,28 +219,27 @@ const char* const kConstantFoldReduce = R"(
})";
TEST_F(HloConstantFoldingTest, ConstantFoldReduce) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseHloString(kConstantFoldReduce));
+ ParseAndVerifyModule(kConstantFoldReduce);
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(&module()));
EXPECT_TRUE(result);
- EXPECT_EQ(6, module->entry_computation()
+ EXPECT_EQ(6, module()
+ .entry_computation()
->root_instruction()
->literal()
.GetFirstElement<int32>());
}
TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseHloString(kConstantFoldReduce));
- HloInstruction* add = module->computations().begin()->root_instruction();
+ ParseAndVerifyModule(kConstantFoldReduce);
+ HloInstruction* add = module().computations().begin()->root_instruction();
LayoutUtil::ClearLayout(add->mutable_shape());
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(&module()));
EXPECT_FALSE(result);
- EXPECT_THAT(module->entry_computation()->root_instruction(), op::Reduce());
+ EXPECT_THAT(module().entry_computation()->root_instruction(), op::Reduce());
}
} // namespace
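
The fixture changes in this file show the HloVerifiedTestBase workflow end to end: ParseAndVerifyModule parses and verifies the module inside the fixture, module() hands back a reference rather than a unique_ptr, and passes run against &module(). Reduced to a skeleton, where MyFoldingTest and kSomeHlo are hypothetical names for a fixture and an HLO text string:

    class MyFoldingTest : public HloVerifiedTestBase {};

    TEST_F(MyFoldingTest, FoldsReduce) {
      ParseAndVerifyModule(kSomeHlo);   // fixture owns and verifies the module
      HloConstantFolding const_folder;
      TF_ASSERT_OK_AND_ASSIGN(bool changed, const_folder.Run(&module()));
      EXPECT_TRUE(changed);
      EXPECT_THAT(module().entry_computation()->root_instruction(),
                  op::Constant());
    }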
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
index a3fcc0fefa..b76c50bb5b 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
@@ -321,18 +321,17 @@ StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
padding_config_dim.set_edge_padding_high(zeros_to_append);
*padding_config.add_dimensions() = padding_config_dim;
- HloInstruction* zero = computation->AddInstruction(
- HloInstruction::CreateConstant(absl::make_unique<Literal>(
- LiteralUtil::Zero(operand->shape().element_type()))));
+ HloInstruction* zero =
+ computation->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(operand->shape().element_type())));
return MakePadHlo(operand, zero, padding_config);
}
StatusOr<HloInstruction*> BroadcastZeros(
HloComputation* computation, PrimitiveType element_type,
absl::Span<const int64> broadcast_dimensions) {
- HloInstruction* zero =
- computation->AddInstruction(HloInstruction::CreateConstant(
- absl::make_unique<Literal>(LiteralUtil::Zero(element_type))));
+ HloInstruction* zero = computation->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{},
/*result_shape_bounds=*/broadcast_dimensions);
}
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
index eb6affadc8..e07a196d11 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
@@ -57,10 +57,10 @@ TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) {
entry_computation->set_root_instruction(first_1_dims_collapsed);
HloEvaluator evaluator;
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
+ TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
+ evaluator.Evaluate<Literal>(
*module, {LiteralUtil::CreateR1<int32>({3, 4})}));
- CHECK_EQ(*result_literal, *LiteralUtil::CreateR1<int32>({3, 4}));
+ CHECK_EQ(result_literal, LiteralUtil::CreateR1<int32>({3, 4}));
}
TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) {
@@ -78,13 +78,13 @@ TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) {
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
+ Literal result_literal,
+ evaluator.Evaluate<Literal>(
*module,
{LiteralUtil::CreateR3<int32>(
{{{1, 2}, {3, 4}, {5, 6}}, {{-1, -2}, {-3, -4}, {-5, -6}}})}));
- CHECK_EQ(*result_literal,
- *LiteralUtil::CreateR2<int32>(
+ CHECK_EQ(result_literal,
+ LiteralUtil::CreateR2<int32>(
{{1, 2}, {3, 4}, {5, 6}, {-1, -2}, {-3, -4}, {-5, -6}}));
}
@@ -103,10 +103,10 @@ TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) {
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {LiteralUtil::CreateR1<int32>({9, 10})}));
- CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{9, 10}}));
+ Literal result_literal,
+ evaluator.Evaluate<Literal>(*module,
+ {LiteralUtil::CreateR1<int32>({9, 10})}));
+ CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32>({{9, 10}}));
}
TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) {
@@ -124,10 +124,10 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) {
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {LiteralUtil::CreateR1<int32>({9, 10})}));
- CHECK_EQ(*result_literal, *LiteralUtil::CreateR3<int32>({{{9, 10}}}));
+ Literal result_literal,
+ evaluator.Evaluate<Literal>(*module,
+ {LiteralUtil::CreateR1<int32>({9, 10})}));
+ CHECK_EQ(result_literal, LiteralUtil::CreateR3<int32>({{{9, 10}}}));
}
TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) {
@@ -144,10 +144,10 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) {
entry_computation->set_root_instruction(with_2_degenerate_dims_prepended);
HloEvaluator evaluator;
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {LiteralUtil::CreateR0<int32>(9)}));
- CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{9}}));
+ TF_ASSERT_OK_AND_ASSIGN(
+ Literal result_literal,
+ evaluator.Evaluate<Literal>(*module, {LiteralUtil::CreateR0<int32>(9)}));
+ CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32>({{9}}));
}
TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) {
@@ -165,11 +165,11 @@ TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) {
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
+ Literal result_literal,
+ evaluator.Evaluate<Literal>(
*module, {LiteralUtil::CreateR1<int32>({1, 2, 3, 4, 5, 6})}));
- CHECK_EQ(*result_literal,
- *LiteralUtil::CreateR3<int32>({{{1, 2}}, {{3, 4}}, {{5, 6}}}));
+ CHECK_EQ(result_literal,
+ LiteralUtil::CreateR3<int32>({{{1, 2}}, {{3, 4}}, {{5, 6}}}));
}
TEST_F(HloCreationUtilsTest, PadVectorWithZeros) {
@@ -187,10 +187,10 @@ TEST_F(HloCreationUtilsTest, PadVectorWithZeros) {
entry_computation->set_root_instruction(zero_padded_param);
HloEvaluator evaluator;
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
+ TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
+ evaluator.Evaluate<Literal>(
*module, {LiteralUtil::CreateR1<int32>({3, 4})}));
- CHECK_EQ(*result_literal, *LiteralUtil::CreateR1<int32>({0, 0, 0, 3, 4, 0}));
+ CHECK_EQ(result_literal, LiteralUtil::CreateR1<int32>({0, 0, 0, 3, 4, 0}));
}
TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) {
@@ -208,10 +208,10 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) {
entry_computation->set_root_instruction(zeros);
HloEvaluator evaluator;
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {LiteralUtil::CreateR0<int32>(0)}));
- CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}}));
+ TF_ASSERT_OK_AND_ASSIGN(
+ Literal result_literal,
+ evaluator.Evaluate<Literal>(*module, {LiteralUtil::CreateR0<int32>(0)}));
+ CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}}));
}
TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) {
@@ -229,11 +229,11 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) {
entry_computation->set_root_instruction(zeros);
HloEvaluator evaluator;
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
+ TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
+ evaluator.Evaluate<Literal>(
*module, {LiteralUtil::CreateR0<float>(0.0f)}));
- CHECK_EQ(*result_literal,
- *LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
+ CHECK_EQ(result_literal,
+ LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
}
} // namespace
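
Every test in this file is updated the same way: HloEvaluator::Evaluate<Literal> takes the argument literals by value (as an absl::Span<const Literal>, so a braced list of temporaries works) and yields a Literal that is compared directly. The common shape, reusing the inputs from the tests above:

    HloEvaluator evaluator;
    TF_ASSERT_OK_AND_ASSIGN(
        Literal result,
        evaluator.Evaluate<Literal>(*module,
                                    {LiteralUtil::CreateR1<int32>({3, 4})}));
    CHECK_EQ(result, LiteralUtil::CreateR1<int32>({3, 4}));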
diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc
index e09d5868f2..9b18b0284f 100644
--- a/tensorflow/compiler/xla/service/hlo_cse_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc
@@ -73,7 +73,7 @@ TEST_F(HloCseTest, CombineTwoConstants) {
auto result = ExecuteAndTransfer(module->Clone(), {});
auto expected = LiteralUtil::CreateR0<float>(84.0);
- EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4)));
}
TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
@@ -105,7 +105,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
auto result = ExecuteAndTransfer(module->Clone(), {});
auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
- EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4)));
}
TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
@@ -135,7 +135,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
auto result = ExecuteAndTransfer(module->Clone(), {});
auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
- EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4)));
}
TEST_F(HloCseTest, ConstantsSameValueDifferentType) {
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index d0d955fea8..06b6d5b559 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -54,9 +54,8 @@ namespace xla {
namespace {
template <typename OperandT>
-StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
- LiteralSlice lhs_literal,
- LiteralSlice rhs_literal) {
+StatusOr<Literal> Compare(const Shape& shape, HloOpcode opcode,
+ LiteralSlice lhs_literal, LiteralSlice rhs_literal) {
std::function<bool(OperandT, OperandT)> compare_op;
switch (opcode) {
case HloOpcode::kEq:
@@ -94,9 +93,9 @@ StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
<< HloOpcodeString(opcode);
}
- auto result = absl::make_unique<Literal>(shape);
+ Literal result(shape);
TF_RETURN_IF_ERROR(
- result->Populate<bool>([&](absl::Span<const int64> multi_index) {
+ result.Populate<bool>([&](absl::Span<const int64> multi_index) {
return compare_op(lhs_literal.Get<OperandT>(multi_index),
rhs_literal.Get<OperandT>(multi_index));
}));
@@ -105,9 +104,9 @@ StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
}
template <>
-StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
- const Shape& shape, HloOpcode opcode, LiteralSlice lhs_literal,
- LiteralSlice rhs_literal) {
+StatusOr<Literal> Compare<complex64>(const Shape& shape, HloOpcode opcode,
+ LiteralSlice lhs_literal,
+ LiteralSlice rhs_literal) {
std::function<bool(complex64, complex64)> compare_op;
switch (opcode) {
case HloOpcode::kEq:
@@ -125,9 +124,9 @@ StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
<< HloOpcodeString(opcode);
}
- auto result = absl::make_unique<Literal>(shape);
+ Literal result(shape);
TF_RETURN_IF_ERROR(
- result->Populate<bool>([&](absl::Span<const int64> multi_index) {
+ result.Populate<bool>([&](absl::Span<const int64> multi_index) {
return compare_op(lhs_literal.Get<complex64>(multi_index),
rhs_literal.Get<complex64>(multi_index));
}));
@@ -193,7 +192,7 @@ HloEvaluator::HloEvaluator(int64 max_loop_iterations)
}
template <typename LiteralPtr>
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
+StatusOr<Literal> HloEvaluator::Evaluate(
const HloModule& module, absl::Span<const LiteralPtr> arg_literals) {
XLA_VLOG_LINES(2, "HloEvaluator::Evaluate module:\n" + module.ToString());
@@ -206,11 +205,21 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
TF_RETURN_IF_ERROR(module.entry_computation()->Accept(this));
return GetEvaluatedLiteralFor(module.entry_computation()->root_instruction())
- .CloneToUnique();
+ .Clone();
+}
+
+template <>
+StatusOr<Literal> HloEvaluator::Evaluate<Literal>(
+ const HloModule& module, absl::Span<const Literal> arg_literals) {
+ std::vector<const Literal*> arg_literal_ptrs;
+ for (const auto& literal_ptr : arg_literals) {
+ arg_literal_ptrs.push_back(&literal_ptr);
+ }
+ return Evaluate<const Literal*>(module, arg_literal_ptrs);
}
template <typename LiteralPtr>
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
+StatusOr<Literal> HloEvaluator::Evaluate(
const HloComputation& computation,
absl::Span<const LiteralPtr> arg_literals) {
CHECK(computation.parent() != nullptr);
@@ -224,11 +233,21 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
}
TF_RETURN_IF_ERROR(computation.Accept(this));
- return GetEvaluatedLiteralFor(computation.root_instruction()).CloneToUnique();
+ return GetEvaluatedLiteralFor(computation.root_instruction()).Clone();
+}
+
+template <>
+StatusOr<Literal> HloEvaluator::Evaluate<Literal>(
+ const HloComputation& computation, absl::Span<const Literal> arg_literals) {
+ std::vector<const Literal*> arg_literal_ptrs;
+ for (const auto& literal_ptr : arg_literals) {
+ arg_literal_ptrs.push_back(&literal_ptr);
+ }
+ return Evaluate<const Literal*>(computation, arg_literal_ptrs);
}
template <typename LiteralPtr>
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
+StatusOr<Literal> HloEvaluator::Evaluate(
HloInstruction* instruction, absl::Span<const LiteralPtr> arg_literals) {
TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction));
@@ -247,18 +266,27 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
<< input_literal->ToString();
TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape()));
- evaluated_[operand] = input_literal->CloneToUnique();
+ evaluated_[operand] = input_literal->Clone();
}
}
TF_RETURN_IF_ERROR(Preprocess(instruction));
TF_RETURN_IF_ERROR(instruction->Visit(this));
TF_RETURN_IF_ERROR(Postprocess(instruction));
- return GetEvaluatedLiteralFor(instruction).CloneToUnique();
+ return GetEvaluatedLiteralFor(instruction).Clone();
+}
+
+template <>
+StatusOr<Literal> HloEvaluator::Evaluate<Literal>(
+ HloInstruction* instruction, absl::Span<const Literal> arg_literals) {
+ std::vector<const Literal*> arg_literal_ptrs;
+ for (const auto& literal : arg_literals) {
+ arg_literal_ptrs.push_back(&literal);
+ }
+ return Evaluate<const Literal*>(instruction, arg_literal_ptrs);
}
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
- HloInstruction* instruction) {
+StatusOr<Literal> HloEvaluator::Evaluate(HloInstruction* instruction) {
if (instruction->opcode() == HloOpcode::kParameter) {
return tensorflow::errors::FailedPrecondition(
"Cannot evaluate a parameter.");
@@ -274,21 +302,22 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
TF_RETURN_IF_ERROR(Preprocess(instruction));
TF_RETURN_IF_ERROR(instruction->Visit(this));
TF_RETURN_IF_ERROR(Postprocess(instruction));
- return GetEvaluatedLiteralFor(instruction).CloneToUnique();
+ return GetEvaluatedLiteralFor(instruction).Clone();
}
-std::unique_ptr<Literal> HloEvaluator::TryEvaluate(
- HloInstruction* instruction) {
+bool HloEvaluator::TryEvaluate(HloInstruction* instruction, Literal* result) {
+ CHECK(result != nullptr);
auto result_or = Evaluate(instruction);
if (!result_or.ok()) {
VLOG(1) << "TryEvaluate failed:" << result_or.status();
- return nullptr;
+ return false;
}
- return result_or.ConsumeValueOrDie();
+ *result = result_or.ConsumeValueOrDie();
+ return true;
}
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions(
+StatusOr<Literal> HloEvaluator::EvaluateWithSubstitutions(
const HloInstruction* instruction,
const std::unordered_map<const HloInstruction*, const Literal*>&
substitutions) {
@@ -299,7 +328,7 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions(
owned_operands.push_back(operand->Clone());
} else {
owned_operands.push_back(
- HloInstruction::CreateConstant(it->second->CloneToUnique()));
+ HloInstruction::CreateConstant(it->second->Clone()));
}
}
@@ -316,12 +345,12 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions(
return result;
}
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseBinaryOp(
+StatusOr<Literal> HloEvaluator::EvaluateElementwiseBinaryOp(
HloOpcode opcode, const Literal& lhs, const Literal& rhs) {
std::unique_ptr<HloInstruction> lhs_instr =
- HloInstruction::CreateConstant(lhs.CloneToUnique());
+ HloInstruction::CreateConstant(lhs.Clone());
std::unique_ptr<HloInstruction> rhs_instr =
- HloInstruction::CreateConstant(rhs.CloneToUnique());
+ HloInstruction::CreateConstant(rhs.Clone());
std::unique_ptr<HloInstruction> cloned_instruction =
HloInstruction::CreateBinary(lhs.shape(), opcode, lhs_instr.get(),
@@ -331,10 +360,10 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseBinaryOp(
return result;
}
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseUnaryOp(
+StatusOr<Literal> HloEvaluator::EvaluateElementwiseUnaryOp(
HloOpcode opcode, const Literal& operand) {
std::unique_ptr<HloInstruction> operand_instr =
- HloInstruction::CreateConstant(operand.CloneToUnique());
+ HloInstruction::CreateConstant(operand.Clone());
std::unique_ptr<HloInstruction> cloned_instruction =
HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get());
@@ -343,14 +372,14 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseUnaryOp(
return result;
}
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateDotOp(
+StatusOr<Literal> HloEvaluator::EvaluateDotOp(
const DotDimensionNumbers& dim_numbers,
const PrecisionConfig& precision_config, const Literal& lhs,
const Literal& rhs) {
std::unique_ptr<HloInstruction> lhs_instr =
- HloInstruction::CreateConstant(lhs.CloneToUnique());
+ HloInstruction::CreateConstant(lhs.Clone());
std::unique_ptr<HloInstruction> rhs_instr =
- HloInstruction::CreateConstant(rhs.CloneToUnique());
+ HloInstruction::CreateConstant(rhs.Clone());
TF_ASSIGN_OR_RETURN(
Shape dot_shape,
@@ -371,7 +400,7 @@ Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
<< ", but input literal shape is: "
<< ShapeUtil::HumanString(input_literal->shape());
- evaluated_[parameter] = input_literal->CloneToUnique();
+ evaluated_[parameter] = input_literal->Clone();
return Status::OK();
}
@@ -421,7 +450,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
for (auto operand : operands) {
const Shape& operand_shape = operand->shape();
- TF_RETURN_IF_ERROR(result_literal->CopySliceFrom(
+ TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
GetEvaluatedLiteralFor(operand), source_indices, dest_indices,
AsInt64Slice(operand_shape.dimensions())));
dest_indices[concat_dim] +=
@@ -824,7 +853,7 @@ class OutputOffsetIndexToInputIndex {
// there is one) to `reshaped_start_indices`.
static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices(
int64 index_vector_dim, const Literal& start_indices,
- std::unique_ptr<Literal>* reshaped_start_indices) {
+ Literal* reshaped_start_indices) {
if (start_indices.shape().dimensions_size() != index_vector_dim) {
return std::cref(start_indices);
}
@@ -834,16 +863,16 @@ static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices(
new_shape.push_back(1);
TF_ASSIGN_OR_RETURN(*reshaped_start_indices,
start_indices.Reshape(new_shape));
- return std::cref(**reshaped_start_indices);
+ return std::cref(*reshaped_start_indices);
}
Status HloEvaluator::HandleGather(HloInstruction* gather) {
- std::unique_ptr<Literal> result = Literal::CreateFromShape(gather->shape());
+ Literal result = Literal::CreateFromShape(gather->shape());
const Shape& shape = gather->shape();
const GatherDimensionNumbers& dim_numbers =
gather->gather_dimension_numbers();
const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0));
- std::unique_ptr<Literal> reshaped_start_indices;
+ Literal reshaped_start_indices;
TF_ASSIGN_OR_RETURN(
const Literal& start_indices,
ReshapedGatherIndices(dim_numbers.index_vector_dim(),
@@ -908,7 +937,7 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) {
DCHECK_LT(input_index[i], operand_shape.dimensions(i));
}
TF_RETURN_IF_ERROR(
- result->CopyElementFrom(operand, input_index, output_index));
+ result.CopyElementFrom(operand, input_index, output_index));
return true;
};
@@ -940,8 +969,14 @@ Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) {
// Checks that operand's dimensions are the same as the broadcast's
// dimensions along the dimensions to be broadcasted.
for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
- TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) ==
- operand.shape().dimensions(i));
+ auto operand_dim_size = operand.shape().dimensions(i);
+ auto broadcast_dim_size =
+ broadcast->shape().dimensions(broadcast->dimensions(i));
+ TF_RET_CHECK(operand_dim_size == broadcast_dim_size) << absl::StreamFormat(
+ "Operand dimension %d is broadcast to output dimension %d, but the "
+ "sizes of these two dims do not match (%d vs %d): %s",
+ i, broadcast->dimensions(i), operand_dim_size, broadcast_dim_size,
+ broadcast->ToString());
}
TF_ASSIGN_OR_RETURN(
@@ -971,18 +1006,16 @@ Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) {
const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand);
- evaluated_[get_tuple_element] = absl::make_unique<Literal>(
- ShapeUtil::GetTupleElementShape(operand->shape(), index));
- return evaluated_[get_tuple_element]->CopyFrom(operand_tuple_literal,
- /*dest_shape_index=*/{},
- /*src_shape_index=*/{index});
+ evaluated_[get_tuple_element] =
+ Literal(ShapeUtil::GetTupleElementShape(operand->shape(), index));
+ return evaluated_[get_tuple_element].CopyFrom(operand_tuple_literal,
+ /*dest_shape_index=*/{},
+ /*src_shape_index=*/{index});
}
Status HloEvaluator::HandleCopy(HloInstruction* copy) {
TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape()));
-
- auto result = GetEvaluatedLiteralFor(copy->operand(0)).CloneToUnique();
- evaluated_[copy] = std::move(result);
+ evaluated_[copy] = GetEvaluatedLiteralFor(copy->operand(0)).Clone();
return Status::OK();
}
@@ -998,7 +1031,7 @@ Status HloEvaluator::HandleCall(HloInstruction* call) {
}
HloEvaluator embedded_evaluator;
- std::unique_ptr<Literal> result =
+ Literal result =
embedded_evaluator.Evaluate<const Literal*>(*computation, arg_literals)
.ConsumeValueOrDie();
@@ -1030,7 +1063,7 @@ Status HloEvaluator::HandleFusion(HloInstruction* fusion) {
}
HloEvaluator embedded_evaluator;
- std::unique_ptr<Literal> result =
+ Literal result =
embedded_evaluator
.Evaluate<const Literal*>(*readded_computation, arg_literals)
.ConsumeValueOrDie();
@@ -1050,7 +1083,7 @@ Status HloEvaluator::HandleConditional(HloInstruction* conditional) {
auto* false_computation = conditional->false_computation();
HloEvaluator embedded_evaluator;
- std::unique_ptr<Literal> result;
+ Literal result;
if (pred.Get<bool>({})) {
result = embedded_evaluator
.Evaluate<const Literal*>(*true_computation,
@@ -1075,9 +1108,9 @@ Status HloEvaluator::HandleSelect(HloInstruction* select) {
// If predicate is of scalar type, no element-wise selection would be needed.
if (ShapeUtil::IsScalar(pred.shape())) {
if (pred.Get<bool>({})) {
- evaluated_[select] = on_true.CloneToUnique();
+ evaluated_[select] = on_true.Clone();
} else {
- evaluated_[select] = on_false.CloneToUnique();
+ evaluated_[select] = on_false.Clone();
}
return Status::OK();
}
@@ -1091,9 +1124,9 @@ Status HloEvaluator::HandleTupleSelect(HloInstruction* tuple_select) {
const auto& on_false = GetEvaluatedLiteralFor(tuple_select->operand(2));
if (pred.Get<bool>({})) {
- evaluated_[tuple_select] = on_true.CloneToUnique();
+ evaluated_[tuple_select] = on_true.Clone();
} else {
- evaluated_[tuple_select] = on_false.CloneToUnique();
+ evaluated_[tuple_select] = on_false.Clone();
}
return Status::OK();
}
@@ -1102,7 +1135,7 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
HloComputation* cond_comp = while_hlo->while_condition();
HloComputation* body_comp = while_hlo->while_body();
// Initialize the loop carried valued with the input to the While instruction.
- auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).CloneToUnique();
+ auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).Clone();
bool keep_going = true;
int64 iteration_count = 0;
HloEvaluator cond_evaluator(max_loop_iterations_);
@@ -1112,13 +1145,13 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
return InvalidArgument("Loop %s exceeded loop iteration limit (%d).",
while_hlo->name(), max_loop_iterations_);
}
- TF_ASSIGN_OR_RETURN(auto cond_val, cond_evaluator.Evaluate<Literal*>(
- *cond_comp, {lcv.get()}));
- keep_going = cond_val->GetFirstElement<bool>();
+ TF_ASSIGN_OR_RETURN(auto cond_val,
+ cond_evaluator.Evaluate<Literal*>(*cond_comp, {&lcv}));
+ keep_going = cond_val.GetFirstElement<bool>();
if (keep_going) {
TF_ASSIGN_OR_RETURN(auto body_val, loop_body_evaluator.Evaluate<Literal*>(
- *body_comp, {lcv.get()}));
- VLOG(3) << "Loop iteration result: " << body_val->ToString();
+ *body_comp, {&lcv}));
+ VLOG(3) << "Loop iteration result: " << body_val.ToString();
lcv = std::move(body_val);
cond_evaluator.ResetVisitStates();
loop_body_evaluator.ResetVisitStates();
@@ -1133,9 +1166,9 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
// hoops to make this work.
namespace {
template <typename KeyType, typename ValueType>
-StatusOr<std::unique_ptr<Literal>> EvaluateSortInternal(
- HloInstruction* sort, const Literal& keys_literal,
- const Literal& values_literal) {
+StatusOr<Literal> EvaluateSortInternal(HloInstruction* sort,
+ const Literal& keys_literal,
+ const Literal& values_literal) {
auto rank = ShapeUtil::Rank(keys_literal.shape());
TF_RET_CHECK(
ShapeUtil::SameDimensions(keys_literal.shape(), values_literal.shape()))
@@ -1173,57 +1206,55 @@ StatusOr<std::unique_ptr<Literal>> EvaluateSortInternal(
result_keys.push_back(key_value.first);
result_values.push_back(key_value.second);
}
- auto result_keys_literal = absl::make_unique<Literal>(keys_literal.shape());
- result_keys_literal->PopulateR1(absl::Span<const KeyType>(result_keys));
- auto result_values_literal =
- absl::make_unique<Literal>(values_literal.shape());
- result_values_literal->PopulateR1(
+ Literal result_keys_literal(keys_literal.shape());
+ result_keys_literal.PopulateR1(absl::Span<const KeyType>(result_keys));
+ Literal result_values_literal(values_literal.shape());
+ result_values_literal.PopulateR1(
absl::Span<const ValueType>(result_values));
return std::make_pair(std::move(result_keys_literal),
std::move(result_values_literal));
};
- std::unique_ptr<Literal> result_tuple;
+ Literal result_tuple;
if (rank == 1) {
auto result_pair = sort_r1(keys_literal, values_literal);
- result_tuple = LiteralUtil::MakeTuple(
- {result_pair.first.get(), result_pair.second.get()});
+ result_tuple =
+ LiteralUtil::MakeTuple({&result_pair.first, &result_pair.second});
} else {
// For R2 sort, the desired semantics are to sort each matrix row
// independently.
- auto keys_result_literal = absl::make_unique<Literal>(keys_literal.shape());
- auto values_result_literal =
- absl::make_unique<Literal>(values_literal.shape());
+ Literal keys_result_literal(keys_literal.shape());
+ Literal values_result_literal(values_literal.shape());
int64 r1_length = keys_literal.shape().dimensions(1);
for (int64 row = 0; row < keys_literal.shape().dimensions(0); ++row) {
TF_ASSIGN_OR_RETURN(auto keys_r1_slice,
keys_literal.Slice({row, 0}, {row + 1, r1_length})
- ->Reshape({r1_length}));
+ .Reshape({r1_length}));
TF_ASSIGN_OR_RETURN(auto values_r1_slice,
values_literal.Slice({row, 0}, {row + 1, r1_length})
- ->Reshape({r1_length}));
- auto r1_result_pair = sort_r1(*keys_r1_slice, *values_r1_slice);
+ .Reshape({r1_length}));
+ auto r1_result_pair = sort_r1(keys_r1_slice, values_r1_slice);
TF_ASSIGN_OR_RETURN(auto sorted_keys,
- r1_result_pair.first->Reshape({1, r1_length}));
+ r1_result_pair.first.Reshape({1, r1_length}));
TF_ASSIGN_OR_RETURN(auto sorted_values,
- r1_result_pair.second->Reshape({1, r1_length}));
- TF_RETURN_IF_ERROR(keys_result_literal->CopySliceFrom(
- *sorted_keys, {0, 0}, {row, 0}, {1, r1_length}));
- TF_RETURN_IF_ERROR(values_result_literal->CopySliceFrom(
- *sorted_values, {0, 0}, {row, 0}, {1, r1_length}));
+ r1_result_pair.second.Reshape({1, r1_length}));
+ TF_RETURN_IF_ERROR(keys_result_literal.CopySliceFrom(
+ sorted_keys, {0, 0}, {row, 0}, {1, r1_length}));
+ TF_RETURN_IF_ERROR(values_result_literal.CopySliceFrom(
+ sorted_values, {0, 0}, {row, 0}, {1, r1_length}));
}
- result_tuple = LiteralUtil::MakeTuple(
- {keys_result_literal.get(), values_result_literal.get()});
+ result_tuple =
+ LiteralUtil::MakeTuple({&keys_result_literal, &values_result_literal});
}
- VLOG(3) << "HandleSort result_tuple: " << result_tuple->ToString();
+ VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString();
return std::move(result_tuple);
}
template <typename KeyType>
-StatusOr<std::unique_ptr<Literal>> EvaluateSortCurried(
- HloInstruction* sort, const Literal& keys_literal,
- const Literal& values_literal) {
+StatusOr<Literal> EvaluateSortCurried(HloInstruction* sort,
+ const Literal& keys_literal,
+ const Literal& values_literal) {
switch (sort->operand(1)->shape().element_type()) {
case F32:
return EvaluateSortInternal<KeyType, float>(sort, keys_literal,
@@ -1242,9 +1273,9 @@ StatusOr<std::unique_ptr<Literal>> EvaluateSortCurried(
}
}
-StatusOr<std::unique_ptr<Literal>> EvaluateSort(HloInstruction* sort,
- const Literal& keys_literal,
- const Literal& values_literal) {
+StatusOr<Literal> EvaluateSort(HloInstruction* sort,
+ const Literal& keys_literal,
+ const Literal& values_literal) {
switch (sort->operand(0)->shape().element_type()) {
case F32:
return EvaluateSortCurried<float>(sort, keys_literal, values_literal);
@@ -1308,33 +1339,25 @@ Status HloEvaluator::Preprocess(HloInstruction* hlo) {
Status HloEvaluator::Postprocess(HloInstruction* hlo) {
VLOG(2) << "Finished visiting " << hlo->ToString()
<< "; evaluated value is: " << GetEvaluatedLiteralFor(hlo).ToString();
+ // For convenience, the literal may have been produced with a different
+ // layout. Relayout as indicated by the HLO instruction.
+ if (!LayoutUtil::LayoutsInShapesEqual(GetEvaluatedLiteralFor(hlo).shape(),
+ hlo->shape())) {
+ evaluated_.at(hlo) = evaluated_.at(hlo).Relayout(hlo->shape());
+ }
return Status::OK();
}
// Explicit instantiation of templatized Evaluate* methods.
//
-template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<const Literal*>(
+template StatusOr<Literal> HloEvaluator::Evaluate<const Literal*>(
const HloModule& module, absl::Span<const Literal* const> arg_literals);
-template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
- const HloModule& module,
- absl::Span<const std::unique_ptr<Literal>> arg_literals);
-
-template StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate<
- const Literal*>(const HloComputation& computation,
- absl::Span<const Literal* const> arg_literals);
-template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
+
+template StatusOr<Literal> HloEvaluator::Evaluate<const Literal*>(
const HloComputation& computation,
- absl::Span<const std::unique_ptr<Literal>> arg_literals);
+ absl::Span<const Literal* const> arg_literals);
-template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<const Literal*>(
+template StatusOr<Literal> HloEvaluator::Evaluate<const Literal*>(
HloInstruction* instruction, absl::Span<const Literal* const> arg_literals);
-template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
- HloInstruction* instruction,
- absl::Span<const std::unique_ptr<Literal>> arg_literals);
} // namespace xla
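
Summing up the hlo_evaluator.cc changes: every Evaluate overload now returns StatusOr<Literal>, the std::unique_ptr<Literal> instantiations are gone, and a new Evaluate<Literal> specialization lets callers pass argument literals by value (it builds a vector of const Literal* internally and forwards to the pointer instantiation). The two equivalent call styles after this change, sketched with module standing for any HloModule built elsewhere:

    HloEvaluator evaluator;

    // (1) Pointer arguments: the instantiation that stays explicit below.
    Literal arg = LiteralUtil::CreateR1<int32>({3, 4});
    std::vector<const Literal*> ptr_args = {&arg};
    TF_ASSIGN_OR_RETURN(Literal via_ptrs,
                        evaluator.Evaluate<const Literal*>(*module, ptr_args));

    // (2) Value arguments: handled by the new Evaluate<Literal>
    //     specialization, which forwards to (1).
    TF_ASSIGN_OR_RETURN(
        Literal via_values,
        evaluator.Evaluate<Literal>(*module,
                                    {LiteralUtil::CreateR1<int32>({3, 4})}));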
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index 72252bafc7..21e676d671 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -47,11 +47,11 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// Precondition: The indices of arg_literals correspond to the parameter
// numbers of the HLO parameters in the computation. See comment below for an
// example.
- // `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal*
+ // `LiteralPtr` accepts either Literal or const Literal*
// type.
template <typename LiteralPtr>
- StatusOr<std::unique_ptr<Literal>> Evaluate(
- const HloModule& module, absl::Span<const LiteralPtr> arg_literals);
+ StatusOr<Literal> Evaluate(const HloModule& module,
+ absl::Span<const LiteralPtr> arg_literals);
// Evaluates an HLO computation and an array of pointers to literals.
// Returns the evaluated result as a literal if successful.
@@ -69,12 +69,11 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// where Parameter0 has parameter_number 0 and Parameter1 has parameter_number
// 1 in this computation. The input literals array will then have its first
// literal map to Parameter0 and the second map to Parameter1.
- // `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal*
+ // `LiteralPtr` accepts either Literal or const Literal*
// type.
template <typename LiteralPtr>
- StatusOr<std::unique_ptr<Literal>> Evaluate(
- const HloComputation& computation,
- absl::Span<const LiteralPtr> arg_literals);
+ StatusOr<Literal> Evaluate(const HloComputation& computation,
+ absl::Span<const LiteralPtr> arg_literals);
// Evaluates a single HLO instruction and an array of pointers to literals.
// Return the evaluated result as literal if successful.
@@ -82,42 +81,43 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// 1. argument literals correspond to the input instruction's parameters in
// their post-ordering.
// 2. the instruction's operands must be of either Parameter or Constant type.
- // `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal*
+ // `LiteralPtr` accepts either Literal or const Literal*
// type.
template <typename LiteralPtr>
- StatusOr<std::unique_ptr<Literal>> Evaluate(
- HloInstruction* instruction, absl::Span<const LiteralPtr> arg_literals);
+ StatusOr<Literal> Evaluate(HloInstruction* instruction,
+ absl::Span<const LiteralPtr> arg_literals);
// Evaluates a single HLO instruction with constant operands.
// Returns the evaluated result as literal if successful.
// Precondition:
// 1. all operands of the input instruction are constants.
// 2. the instruction is not a Parameter operation.
- StatusOr<std::unique_ptr<Literal>> Evaluate(HloInstruction* instruction);
+ StatusOr<Literal> Evaluate(HloInstruction* instruction);
- // Same as Evaluate, except returning nullptr on error.
- std::unique_ptr<Literal> TryEvaluate(HloInstruction* instruction);
+ // Same as Evaluate, except it returns false on error and accepts an
+ // output pointer.
+ bool TryEvaluate(HloInstruction* instruction, Literal* result);
// Evaluates a single HLO instruction, substituting the given literals for
// some of the instruction's operands.
//
// For example, given instruction = op(A, B, C) and the map
// {A = x, C = y}, this evaluates op(x, B, y).
- StatusOr<std::unique_ptr<Literal>> EvaluateWithSubstitutions(
+ StatusOr<Literal> EvaluateWithSubstitutions(
const HloInstruction* instruction,
const std::unordered_map<const HloInstruction*, const Literal*>&
substitutions);
- StatusOr<std::unique_ptr<Literal>> EvaluateElementwiseBinaryOp(
- HloOpcode opcode, const Literal& lhs, const Literal& rhs);
+ StatusOr<Literal> EvaluateElementwiseBinaryOp(HloOpcode opcode,
+ const Literal& lhs,
+ const Literal& rhs);
- StatusOr<std::unique_ptr<Literal>> EvaluateElementwiseUnaryOp(
- HloOpcode opcode, const Literal& operand);
+ StatusOr<Literal> EvaluateElementwiseUnaryOp(HloOpcode opcode,
+ const Literal& operand);
- StatusOr<std::unique_ptr<Literal>> EvaluateDotOp(
- const DotDimensionNumbers& dim_numbers,
- const PrecisionConfig& precision_config, const Literal& lhs,
- const Literal& rhs);
+ StatusOr<Literal> EvaluateDotOp(const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfig& precision_config,
+ const Literal& lhs, const Literal& rhs);
protected:
// Make HloEvaluatorTypedVisitor a friend because it is logically part of this
@@ -197,7 +197,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
auto it = evaluated_.find(hlo);
CHECK(it != evaluated_.end())
<< "could not find evaluated value for: " << hlo->ToString();
- return *(it->second);
+ return it->second;
}
// Tracks the HLO instruction and its evaluated literal result.
@@ -205,12 +205,13 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// that are no longer a parent for any other subsequent instruction in
// post-orderring.
// Must be cleared for each evaluation.
- tensorflow::gtl::FlatMap<const HloInstruction*, std::unique_ptr<Literal>>
- evaluated_;
+ // Storing Literal in place requires the container to have pointer stability so
+ // we cannot use FlatMap any more.
+ std::unordered_map<const HloInstruction*, Literal> evaluated_;
private:
template <typename ReturnT, typename NativeT>
- static StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOpImpl(
+ static StatusOr<Literal> ElementWiseUnaryOpImpl(
HloInstruction* instruction,
const std::function<ReturnT(NativeT)>& unary_op,
const Literal& operand_literal) {
@@ -227,9 +228,9 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
ShapeUtil::HumanString(operand->shape()));
}
- auto result = absl::make_unique<Literal>(shape);
+ Literal result(shape);
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
+ result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
return unary_op(operand_literal.Get<NativeT>(multi_index));
}));
return std::move(result);
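
The comment added next to evaluated_ is the key constraint behind switching from gtl::FlatMap to std::unordered_map: GetEvaluatedLiteralFor returns a reference into the map, and that reference must survive later insertions. std::unordered_map guarantees that rehashing never invalidates references to elements, whereas an open-addressing map such as FlatMap stores values inline and may relocate them as it grows. A standalone illustration using only standard containers:

    #include <cassert>
    #include <string>
    #include <unordered_map>

    int main() {
      std::unordered_map<int, std::string> evaluated;
      evaluated[1] = "constant 42";
      const std::string& ref = evaluated.at(1);  // reference into the map
      for (int i = 2; i < 10000; ++i) {
        evaluated[i] = "filler";                 // forces several rehashes
      }
      assert(ref == "constant 42");  // still valid: std::unordered_map never
                                     // invalidates references on rehash
      return 0;
    }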
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index 102ebb24ab..01e88566a5 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -56,8 +56,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
evaluator_ = absl::make_unique<HloEvaluator>();
}
- std::unique_ptr<Literal> Evaluate(
- absl::Span<const Literal* const> arg_literals = {}) {
+ Literal Evaluate(absl::Span<const Literal* const> arg_literals = {}) {
if (use_bfloat16_) {
// In BF16 mode, we convert all F32 type to BF16 and evaluate the module.
auto type_converter = HloElementTypeConverter(F32, BF16);
@@ -69,39 +68,37 @@ class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
std::unique_ptr<HloEvaluator> evaluator_;
- void TestUnaryOp(HloOpcode opcode, std::unique_ptr<Literal> expected,
- std::unique_ptr<Literal> input, float aabs = 0) {
+ void TestUnaryOp(HloOpcode opcode, Literal expected, Literal input,
+ float aabs = 0) {
HloComputation::Builder b(TestName());
auto c1 =
b.AddInstruction(HloInstruction::CreateConstant(std::move(input)));
- b.AddInstruction(
- HloInstruction::CreateUnary(expected->shape(), opcode, c1));
+ b.AddInstruction(HloInstruction::CreateUnary(expected.shape(), opcode, c1));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
- auto element_type = expected->shape().element_type();
+ auto element_type = expected.shape().element_type();
if (element_type == F32 || element_type == F64) {
ErrorSpec error(aabs);
- EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, error));
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, result, error));
} else {
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
}
- void TestBinaryOp(HloOpcode opcode, std::unique_ptr<Literal> expected,
- std::unique_ptr<Literal> lhs,
- std::unique_ptr<Literal> rhs) {
+ void TestBinaryOp(HloOpcode opcode, Literal expected, Literal lhs,
+ Literal rhs) {
HloComputation::Builder b(TestName());
auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs)));
auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs)));
b.AddInstruction(
- HloInstruction::CreateBinary(expected->shape(), opcode, c1, c2));
+ HloInstruction::CreateBinary(expected.shape(), opcode, c1, c2));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
bool use_bfloat16_;
@@ -117,7 +114,7 @@ TEST_P(HloEvaluatorTest, DoesClamp) {
auto value = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
auto high = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
- Shape shape = low->shape();
+ Shape shape = low.shape();
HloComputation::Builder b(TestName());
auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low)));
auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
@@ -126,11 +123,11 @@ TEST_P(HloEvaluatorTest, DoesClamp) {
HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<float>({{0, 4}, {2, 4}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) {
@@ -138,7 +135,7 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) {
auto value = LiteralUtil::CreateR2<float>({{-1.f, 0.f}, {1.f, 2.f}});
auto high = LiteralUtil::CreateR0<float>(1.f);
- Shape shape = value->shape();
+ Shape shape = value.shape();
HloComputation::Builder b(TestName());
auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low)));
auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
@@ -147,11 +144,11 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) {
HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<float>({{0, 0}, {1, 1}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
// Verifies that HloEvaluator evaluates a HLO instruction that performs select
@@ -161,7 +158,7 @@ TEST_P(HloEvaluatorTest, DoesSelect) {
auto on_true = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
auto on_false = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
- Shape shape = on_true->shape();
+ Shape shape = on_true.shape();
HloComputation::Builder b(TestName());
auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(pred)));
auto c2 =
@@ -172,11 +169,11 @@ TEST_P(HloEvaluatorTest, DoesSelect) {
HloInstruction::CreateTernary(shape, HloOpcode::kSelect, c1, c2, c3));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate({});
+ Literal result = Evaluate({});
auto expected = LiteralUtil::CreateR2<float>({{2, 5}, {0, 4}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
// Verifies that HloEvaluator evaluates a HLO instruction that performs
@@ -295,7 +292,7 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) {
auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
auto rhs2 = LiteralUtil::CreateR2<int64>({{1, -20}, {-100, 4}});
- std::vector<const Literal*> args = {lhs.get(), rhs.get(), rhs2.get()};
+ std::vector<const Literal*> args = {&lhs, &rhs, &rhs2};
Shape shape = ShapeUtil::MakeShape(S64, {2, 2});
@@ -313,11 +310,11 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) {
lhs_instruction, param_rhs2));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate(args);
+ Literal result = Evaluate(args);
auto expected = LiteralUtil::CreateR2<int64>({{4, -16}, {-196, 12}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
// Verifies Reshape operation is correctly evaluated.
@@ -327,7 +324,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) {
TF_ASSERT_OK_AND_ASSIGN(auto literal,
LiteralUtil::CreateRandomLiteral<F32>(
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
- auto literal_clone = literal->CloneToUnique();
+ auto literal_clone = literal.Clone();
HloInstruction* literal_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
@@ -337,14 +334,13 @@ TEST_P(HloEvaluatorTest, DoesReshape) {
HloInstruction::CreateTranspose(shape, literal_instruction, permutation));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate({});
+ Literal result = Evaluate({});
using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
- result->EachCell<NativeT>(
- [&](absl::Span<const int64> indices, NativeT value) {
- std::vector<int64> rindexes = Permute(permutation, indices);
- EXPECT_NEAR(value, literal_clone->Get<NativeT>(rindexes), 0.031250);
- });
+ result.EachCell<NativeT>([&](absl::Span<const int64> indices, NativeT value) {
+ std::vector<int64> rindexes = Permute(permutation, indices);
+ EXPECT_NEAR(value, literal_clone.Get<NativeT>(rindexes), 0.031250);
+ });
}
// Verifies Broadcast operation is correctly evaluated.
@@ -356,12 +352,12 @@ TEST_P(HloEvaluatorTest, DoesBroadcast) {
HloInstruction* literal_instruction = b.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
b.AddInstruction(HloInstruction::CreateBroadcast(
- output_literal->shape(), literal_instruction, {1, 2}));
+ output_literal.shape(), literal_instruction, {1, 2}));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate({});
+ Literal result = Evaluate({});
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal));
}
TEST_P(HloEvaluatorTest, DoesBroadcastScalar) {
@@ -374,13 +370,13 @@ TEST_P(HloEvaluatorTest, DoesBroadcastScalar) {
HloInstruction::CreateConstant(std::move(input_literal)));
// Broadcast dimension should be empty in the case of scalars.
b.AddInstruction(HloInstruction::CreateBroadcast(
- output_literal->shape(), literal_instruction,
+ output_literal.shape(), literal_instruction,
/*broadcast_dimensions=*/{}));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate({});
+ Literal result = Evaluate({});
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal));
}
TEST_P(HloEvaluatorTest, DoesConcatenateSimple) {
@@ -398,11 +394,11 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<int64>(
{{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) {
@@ -420,10 +416,10 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR1<int64>({100, 200});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, ConvertWithSameLayout) {
@@ -432,17 +428,17 @@ TEST_P(HloEvaluatorTest, ConvertWithSameLayout) {
auto input_literal = LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}});
auto expected =
LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}});
- ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(),
- expected->shape()));
+ ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal.shape(),
+ expected.shape()));
HloInstruction* constant = b.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
- b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant));
+ b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) {
@@ -452,17 +448,17 @@ TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) {
{{1, 2}, {3, 4}, {5, 6}}, LayoutUtil::MakeLayout({0, 1}));
auto expected = LiteralUtil::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}, LayoutUtil::MakeLayout({1, 0}));
- ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(),
- expected->shape()));
+ ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal.shape(),
+ expected.shape()));
HloInstruction* constant = b.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
- b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant));
+ b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
PaddingConfig CreatePaddingConfig(
@@ -495,12 +491,12 @@ TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) {
shape, operand_instruction, padding_value_instruction, padding_config));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<int32>(
{{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) {
@@ -522,7 +518,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) {
shape, input_instruction, pad_instruction, r4_padding_on_dim0_dim1));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected_array = absl::make_unique<Array4D<float>>(8, 5, 1, 1);
expected_array->Fill(kPadValue);
@@ -535,7 +531,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) {
auto expected = LiteralUtil::CreateR4FromArray4D<float>(*expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, NegativePadding2D) {
@@ -566,7 +562,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
// f32[1,5] { 7.0, 2.718, 2.718, 2.718, 2.718 }
auto expected_array = absl::make_unique<Array2D<float>>(1, 5);
@@ -577,7 +573,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) {
(*expected_array)(0, 4) = 2.718f;
auto expected = LiteralUtil::CreateR2FromArray2D<float>(*expected_array);
- EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(0.031250)));
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(0.031250)));
}
TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
@@ -611,12 +607,12 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected_array = absl::make_unique<Array2D<float>>(0, 9);
auto expected = LiteralUtil::CreateR2FromArray2D<float>(*expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
@@ -650,7 +646,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
// clang-format off
auto expected_array = Array2D<float>({
@@ -662,7 +658,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
// clang-format on
auto expected = LiteralUtil::CreateR2FromArray2D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DotRank1AndRank2) {
@@ -696,11 +692,11 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) {
DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR1<float>({22.f, 28.f});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
@@ -740,7 +736,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected_array = Array2D<float>({
{22.f, 28.f},
@@ -750,7 +746,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
});
auto expected = LiteralUtil::CreateR2FromArray2D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, SimpleConv1D) {
@@ -794,12 +790,12 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) {
window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
Array3D<float> expected_array = {{{11.f, 18.f, 9.f}}};
auto expected = LiteralUtil::CreateR3FromArray3D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
@@ -849,7 +845,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
Array4D<float> expected_array(1, 1, 4, 4);
// clang-format off
@@ -862,7 +858,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
// clang-format on
auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
@@ -933,7 +929,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
// clang-format off
// Result dimensions: [feature=1, height=1, batch=1, width=2]
@@ -943,7 +939,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
auto expected = LiteralUtil::CreateR4FromArray4D<float>(
use_bfloat16_ ? expected_array_bf16 : expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) {
@@ -1011,7 +1007,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) {
window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
// clang-format off
// Result dimensions: [feature=1, height=1, batch=1, width=2]
@@ -1021,7 +1017,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) {
auto expected = LiteralUtil::CreateR4FromArray4D<float>(
use_bfloat16_ ? expected_array_bf16 : expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
@@ -1071,7 +1067,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
Array4D<float> expected_array(1, 1, 7, 7);
expected_array.FillWithYX(Array2D<float>({
@@ -1085,7 +1081,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
}));
auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
@@ -1135,7 +1131,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
Array4D<float> expected_array(1, 1, 8, 8);
expected_array.FillWithYX(Array2D<float>({
@@ -1150,7 +1146,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
}));
auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest,
@@ -1207,7 +1203,7 @@ TEST_P(HloEvaluatorTest,
window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
Array4D<float> expected_array(1, 1, 9, 3);
expected_array.FillWithYX(Array2D<float>({
@@ -1223,7 +1219,7 @@ TEST_P(HloEvaluatorTest,
}));
auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) {
@@ -1261,14 +1257,14 @@ TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) {
std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape));
std::iota(input_elems.begin(), input_elems.end(), -7);
auto input_r1 = LiteralUtil::CreateR1<float>(input_elems);
- auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+ auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
HloInstruction* lhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(input_r4)));
std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape));
std::iota(filter_elems.begin(), filter_elems.end(), -31);
auto filter_r1 = LiteralUtil::CreateR1<float>(filter_elems);
- auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+ auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
HloInstruction* rhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(filter_r4)));
@@ -1278,13 +1274,13 @@ TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) {
/*feature_group_count=*/2, window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
Array4D<float> expected_array(1, 1, 1, 8);
expected_array.FillWithYX(
Array2D<float>({{668, 664, 660, 656, 668, 680, 692, 704}}));
auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {};
@@ -1317,9 +1313,8 @@ TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) {
module().AddEntryComputation(b.Build());
HloEvaluator hlo_eval;
- std::unique_ptr<Literal> result =
- hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie();
- LiteralTestUtil::ExpectR0Equal<float>(kNumElements, *result);
+ Literal result = hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie();
+ LiteralTestUtil::ExpectR0Equal<float>(kNumElements, result);
}
// Reducing many numbers should be fast because it doesn't create
@@ -1396,11 +1391,11 @@ TEST_P(HloEvaluatorTest, ReduceAdd) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR1<float>({6, 18});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, ReduceWindowMax) {
@@ -1448,10 +1443,10 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<float>({{6, 7}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, ReduceWindowAdd) {
@@ -1505,10 +1500,10 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<float>({{1, 3, 5}, {5, 11, 13}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
@@ -1516,7 +1511,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
// arg: f32[4,4,4,4,4,4] full of ones. Using small dims to limit run-time.
std::vector<int64> input_dims(6, 4);
- std::unique_ptr<Literal> arg_literal =
+ Literal arg_literal =
LiteralUtil::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
HloInstruction* arg_instruction =
@@ -1566,12 +1561,12 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
std::vector<int64> output_dims = {4, 3, 3, 3, 4, 4};
- std::unique_ptr<Literal> result_literal =
+ Literal result_literal =
LiteralUtil::CreateFullWithDescendingLayout<float>(output_dims, 8.0f);
- EXPECT_TRUE(LiteralTestUtil::Equal(*result_literal, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result_literal, result));
}
TEST_P(HloEvaluatorTest, StridedSlice) {
@@ -1598,14 +1593,14 @@ TEST_P(HloEvaluatorTest, StridedSlice) {
/*strides=*/{2, 3}));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<float>({
{3},
{19},
});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DynamicSlice) {
@@ -1632,14 +1627,14 @@ TEST_P(HloEvaluatorTest, DynamicSlice) {
start_indices, {2, 3}));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<float>({
{2, 3, 4},
{6, 7, 8},
});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
// Verifies that the HloEvaluator's implementation goes along with existing
@@ -1668,14 +1663,14 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) {
start_indices, {2, 3}));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<float>({
{2, 3, 4},
{6, 7, 8},
});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DynamicSliceUpdate) {
@@ -1705,14 +1700,14 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) {
shape, operand, update, start_indices));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<double>({
{1, -2, -3},
{5, -6, -7},
});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, SetAndGetTuples) {
@@ -1741,14 +1736,14 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<double>({
{1, 2, 3},
{5, 6, 7},
});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) {
@@ -1780,16 +1775,14 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto result_inner_literal =
LiteralUtil::CreateR2FromArray2D<double>(*operand_array);
- auto expected = LiteralUtil::MakeTuple({
- result_inner_literal.get(),
- result_inner_literal.get(),
- });
+ auto expected =
+ LiteralUtil::MakeTuple({&result_inner_literal, &result_inner_literal});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, Reverse) {
@@ -1820,7 +1813,7 @@ TEST_P(HloEvaluatorTest, Reverse) {
b.AddInstruction(HloInstruction::CreateReverse(shape, operand, {0, 1}));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
// clang-format off
auto expected = LiteralUtil::CreateR4FromArray4D<float>({
@@ -1842,7 +1835,7 @@ TEST_P(HloEvaluatorTest, Reverse) {
});
// clang-format on
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) {
@@ -1858,12 +1851,13 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) {
// Evaluate add with param0 = {1, 2, 3, 4}, square = {10, 20, 30, 40}.
HloEvaluator evaluator;
+ Literal param0_literal = LiteralUtil::CreateR1<float>({1, 2, 3, 4});
+ Literal square_literal = LiteralUtil::CreateR1<float>({10, 20, 30, 40});
auto result = evaluator.EvaluateWithSubstitutions(
- add, {{param0, LiteralUtil::CreateR1<float>({1, 2, 3, 4}).get()},
- {square, LiteralUtil::CreateR1<float>({10, 20, 30, 40}).get()}});
+ add, {{param0, &param0_literal}, {square, &square_literal}});
TF_ASSERT_OK(result.status());
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR1<float>({11, 22, 33, 44}), *result.ValueOrDie()));
+ LiteralUtil::CreateR1<float>({11, 22, 33, 44}), result.ValueOrDie()));
}
// Check that EvaluateWithSubstitutions works if one of the operands to the op
@@ -1883,11 +1877,12 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) {
// Evaluate add with square = {10, 20, 30, 40}.
HloEvaluator evaluator;
- auto result = evaluator.EvaluateWithSubstitutions(
- add, {{square, LiteralUtil::CreateR1<float>({10, 20, 30, 40}).get()}});
+ Literal square_literal = LiteralUtil::CreateR1<float>({10, 20, 30, 40});
+ auto result =
+ evaluator.EvaluateWithSubstitutions(add, {{square, &square_literal}});
TF_ASSERT_OK(result.status());
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR1<float>({11, 22, 33, 44}), *result.ValueOrDie()));
+ LiteralUtil::CreateR1<float>({11, 22, 33, 44}), result.ValueOrDie()));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) {
@@ -1906,12 +1901,12 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) {
@@ -1930,12 +1925,12 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ LiteralUtil::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) {
@@ -1954,14 +1949,13 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR3<int32>(
+ LiteralUtil::CreateR3<int32>(
{{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) {
@@ -1980,15 +1974,14 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{-1, 1}, {-4, 4}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{-1, 1}, {-4, 4}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest,
@@ -2008,15 +2001,14 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{-2, 2}, {-1, 1}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{-2, 2}, {-1, 1}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) {
@@ -2035,12 +2027,11 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({1, 1});
- EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{5}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ Literal start_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{5}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) {
@@ -2059,13 +2050,12 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR3<int32>({{{8}}, {{5}}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR3<int32>({{{8}}, {{5}}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) {
@@ -2084,11 +2074,10 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
- EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{}, {}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ Literal operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
+ Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{}, {}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) {
@@ -2108,12 +2097,12 @@ ENTRY main {
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
- std::unique_ptr<Literal> start_indices =
+ Literal operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
+ Literal start_indices =
LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{0, 1}, {2, 1}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{0, 1}, {2, 1}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) {
@@ -2138,15 +2127,13 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}}),
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ LiteralUtil::CreateR2<int32>({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}}),
+ Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV2_Update) {
@@ -2171,15 +2158,14 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates =
LiteralUtil::CreateR2<int32>({{10, 30}, {40, 60}, {70, 90}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}}),
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ LiteralUtil::CreateR2<int32>({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}}),
+ Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Add) {
@@ -2205,15 +2191,13 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}}),
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ LiteralUtil::CreateR2<int32>({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}}),
+ Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Mul) {
@@ -2239,15 +2223,13 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{10, 40, 90}, {4, 5, 6}, {490, 640, 810}}),
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ LiteralUtil::CreateR2<int32>({{10, 40, 90}, {4, 5, 6}, {490, 640, 810}}),
+ Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_F32) {
@@ -2273,17 +2255,15 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<float>(
+ Literal operand = LiteralUtil::CreateR2<float>(
{{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({2, 1});
- std::unique_ptr<Literal> updates =
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({2, 1});
+ Literal updates =
LiteralUtil::CreateR2<float>({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}});
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>(
+ LiteralUtil::CreateR2<float>(
{{1.1, 2.2, 3.3}, {6.7, 8.6, 8.2}, {8.1, 9.9, 10.6}}),
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()}),
- ErrorSpec{0.1, 0.01}));
+ Evaluate({&operand, &scatter_indices, &updates}), ErrorSpec{0.1, 0.01}));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_RepeatedIndices) {
@@ -2309,15 +2289,13 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({1, 1});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}}),
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}}),
+ Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_MultipleBatchDims) {
@@ -2343,15 +2321,14 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
+ Literal updates = LiteralUtil::CreateR3<int32>(
{{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{11, 7, 38}, {44, 10, 71}, {77, 13, 104}}),
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ LiteralUtil::CreateR2<int32>({{11, 7, 38}, {44, 10, 71}, {77, 13, 104}}),
+ Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd) {
@@ -2376,21 +2353,18 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}});
- std::unique_ptr<Literal> expected =
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ Literal updates = LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}});
+ Literal expected =
LiteralUtil::CreateR3<int32>({{{-10, 10}, {-2, 2}, {-3, 3}}, //
{{-40, 40}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *expected,
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ expected, Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest,
@@ -2416,21 +2390,18 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}});
- std::unique_ptr<Literal> expected =
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ Literal updates = LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}});
+ Literal expected =
LiteralUtil::CreateR3<int32>({{{-20, 20}, {-10, 10}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *expected,
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ expected, Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_DynamicUpdateSlice) {
@@ -2455,16 +2426,14 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({1, 1});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{10}});
- std::unique_ptr<Literal> expected =
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10}});
+ Literal expected =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 10, 6}, {7, 8, 9}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *expected,
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ expected, Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_BatchDynamicUpdateSlice) {
@@ -2489,17 +2458,14 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR3<int32>({{{10}}, {{20}}});
- std::unique_ptr<Literal> expected =
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
+ Literal updates = LiteralUtil::CreateR3<int32>({{{10}}, {{20}}});
+ Literal expected =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 20, 6}, {7, 10, 9}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *expected,
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ expected, Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_ZeroDimBounds) {
@@ -2524,13 +2490,11 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{}, {}});
+ Literal operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{}, {}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *operand,
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ operand, Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_NoUpdateWindowDims) {
@@ -2557,16 +2521,13 @@ ENTRY main {
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
- std::unique_ptr<Literal> scatter_indices =
+ Literal operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
+ Literal scatter_indices =
LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}});
- std::unique_ptr<Literal> expected =
- LiteralUtil::CreateR1<int32>({10, 61, 32});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}});
+ Literal expected = LiteralUtil::CreateR1<int32>({10, 61, 32});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *expected,
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ expected, Evaluate({&operand, &scatter_indices, &updates})));
}
// Verifies that HloEvaluator evaluates a HLO instruction that performs
@@ -2603,11 +2564,29 @@ ENTRY main {
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> arg = LiteralUtil::CreateR1<bfloat16>(
+ Literal arg = LiteralUtil::CreateR1<bfloat16>(
{bfloat16(1.0f), bfloat16(3.0f), bfloat16(-2.0f), bfloat16(42.0f)});
- std::unique_ptr<Literal> expected =
- LiteralUtil::CreateR0<bfloat16>(bfloat16(44.0f));
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *Evaluate({arg.get()})));
+ Literal expected = LiteralUtil::CreateR0<bfloat16>(bfloat16(44.0f));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, Evaluate({&arg})));
+}
+
+TEST_P(HloEvaluatorTest, SliceWithDifferentLayout) {
+ // Regression test for b/114735354.
+ const string hlo_text = R"(
+HloModule SliceWithDifferentLayout
+
+ENTRY main {
+ arg = f32[2,2,2]{0,1,2} parameter(0)
+ ROOT %slice = f32[2,2,2]{1,0,2} slice(f32[2,2,2]{0,1,2} %arg), slice={[0:2], [0:2], [0:2]}
+}
+)";
+ ParseAndVerifyModule(hlo_text);
+
+ Literal arg = LiteralUtil::CreateR3WithLayout<float>(
+ {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
+ LayoutUtil::MakeLayout({0, 1, 2}));
+ Literal actual = Evaluate({&arg});
+ EXPECT_TRUE(LiteralTestUtil::Equal(arg, actual));
}
INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest,
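The test-file hunks above show the call-site half of the same change: Evaluate() now hands back a Literal by value, and argument lists that used to pull raw pointers out of unique_ptrs with .get() now simply take the address of stack-allocated Literals, while comparisons drop the dereference (LiteralTestUtil::Equal(expected, result) instead of Equal(*expected, *result)). A rough stand-alone sketch of that call-site shape follows, with std::vector<const int*> standing in for absl::Span<const Literal* const>; the names are made up for illustration.

#include <cassert>
#include <memory>
#include <vector>

// Hypothetical consumer that borrows its arguments, analogous in shape to
// Evaluate(absl::Span<const Literal* const>).
int SumArgs(const std::vector<const int*>& args) {
  int total = 0;
  for (const int* arg : args) total += *arg;
  return total;
}

int main() {
  // Old style: owning unique_ptrs, raw pointers extracted with .get().
  auto a = std::make_unique<int>(1);
  auto b = std::make_unique<int>(2);
  assert(SumArgs({a.get(), b.get()}) == 3);

  // New style: plain stack values, addresses taken directly at the call site.
  int c = 1, d = 2;
  assert(SumArgs({&c, &d}) == 3);
  return 0;
}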
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 63303aef1e..8fb17a0033 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -246,32 +246,21 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
Status HandleConvert(HloInstruction* convert) override {
const HloInstruction* operand = convert->operand(0);
TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape()));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result,
+ TF_ASSIGN_OR_RETURN(Literal result,
parent_->GetEvaluatedLiteralFor(operand).Convert(
convert->shape().element_type()));
-
- if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) {
- parent_->evaluated_[convert] = std::move(result);
- } else {
- parent_->evaluated_[convert] =
- result->Relayout(convert->shape().layout());
- }
+ parent_->evaluated_[convert] = std::move(result);
return Status::OK();
}
Status HandleBitcastConvert(HloInstruction* convert) override {
const HloInstruction* operand = convert->operand(0);
TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape()));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result,
+ TF_ASSIGN_OR_RETURN(Literal result,
parent_->GetEvaluatedLiteralFor(operand).BitcastConvert(
convert->shape().element_type()));
- if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) {
- parent_->evaluated_[convert] = std::move(result);
- } else {
- parent_->evaluated_[convert] =
- result->Relayout(convert->shape().layout());
- }
+ parent_->evaluated_[convert] = std::move(result);
return Status::OK();
}
@@ -978,10 +967,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
<< ShapeUtil::HumanString(inferred_return_shape);
const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
- auto result = absl::make_unique<Literal>(result_shape);
+ Literal result(result_shape);
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> out_index) {
+ result.Populate<ReturnT>([&](absl::Span<const int64> out_index) {
std::vector<int64> from_index(out_index.begin(), out_index.end());
for (const int64 dim : reverse_dimensions) {
from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim];
@@ -1157,8 +1146,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return static_cast<ReturnT>(result_val);
};
- auto result = absl::make_unique<Literal>(result_shape);
- TF_RETURN_IF_ERROR(result->PopulateParallel<ReturnT>(func));
+ Literal result(result_shape);
+ TF_RETURN_IF_ERROR(result.PopulateParallel<ReturnT>(func));
parent_->evaluated_[conv] = std::move(result);
return Status::OK();
@@ -1231,9 +1220,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
}
- auto result = absl::make_unique<Literal>(dot->shape());
+ Literal result(dot->shape());
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> result_index) {
+ result.Populate<ReturnT>([&](absl::Span<const int64> result_index) {
ElementwiseT result_val = static_cast<ElementwiseT>(0);
for (int64 i = 0; i < result_index.size(); i++) {
@@ -1280,8 +1269,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Create new HLO of padded shape with padding value.
ReturnT scalar =
parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get<ReturnT>({});
- auto result = absl::make_unique<Literal>(pad->shape());
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+ Literal result(pad->shape());
+ TF_RETURN_IF_ERROR(result.Populate<ReturnT>(
[&scalar](absl::Span<const int64> multi_index) { return scalar; }));
const Literal& evaluated_operand =
@@ -1289,7 +1278,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
std::vector<int64> input_index(ShapeUtil::Rank(evaluated_operand.shape()),
0);
- std::vector<int64> target_index(ShapeUtil::Rank(result->shape()), 0);
+ std::vector<int64> target_index(ShapeUtil::Rank(result.shape()), 0);
    // Loop through each element of the operand, assigning it to the
// corresponding index of the resulting padded literal.
@@ -1311,8 +1300,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return true;
}
}
- result->Set<ReturnT>(target_index,
- evaluated_operand.Get<ReturnT>(input_index));
+ result.Set<ReturnT>(target_index,
+ evaluated_operand.Get<ReturnT>(input_index));
return true;
};
@@ -1439,16 +1428,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
template <typename NativeT>
- StatusOr<std::unique_ptr<Literal>> MapImpl(HloInstruction* map) {
+ StatusOr<Literal> MapImpl(HloInstruction* map) {
auto operands = map->operands();
HloComputation* computation = map->to_apply();
- auto result = absl::make_unique<Literal>(map->shape());
+ Literal result(map->shape());
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
- std::vector<std::unique_ptr<Literal>> arg_literals;
+ result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
+ std::vector<Literal> arg_literals;
arg_literals.reserve(operands.size());
// Construct scalar literal parameters to be passed to the map
@@ -1463,16 +1452,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
arg_literals.push_back(std::move(curr_val_literal));
}
- std::unique_ptr<Literal> computed_result =
- embedded_evaluator
- .Evaluate<std::unique_ptr<Literal>>(*computation,
- arg_literals)
+ Literal computed_result =
+ embedded_evaluator.Evaluate<Literal>(*computation, arg_literals)
.ConsumeValueOrDie();
            // Clear visit states so that we can use the evaluator again on
// the same computation.
embedded_evaluator.ResetVisitStates();
- return computed_result->Get<ReturnT>({});
+ return computed_result.Get<ReturnT>({});
}));
return std::move(result);
}
@@ -1557,9 +1544,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
[](const ReturnT& a, const ReturnT& b) {
return SafeLess<ReturnT>(a, b);
});
- auto result_literal = absl::make_unique<Literal>(keys_literal.shape());
- result_literal->PopulateR1(absl::Span<const ReturnT>(result_data));
- VLOG(3) << "HandleSort result_literal: " << result_literal->ToString();
+ Literal result_literal(keys_literal.shape());
+ result_literal.PopulateR1(absl::Span<const ReturnT>(result_data));
+ VLOG(3) << "HandleSort result_literal: " << result_literal.ToString();
return result_literal;
};
@@ -1568,16 +1555,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
} else {
// For R2 sort, the desired semantics are to sort each matrix row
// independently.
- auto result_literal = absl::make_unique<Literal>(keys_literal.shape());
+ Literal result_literal(keys_literal.shape());
int64 r1_length = keys->shape().dimensions(1);
for (int64 row = 0; row < keys->shape().dimensions(0); ++row) {
TF_ASSIGN_OR_RETURN(auto r1_slice,
keys_literal.Slice({row, 0}, {row + 1, r1_length})
- ->Reshape({r1_length}));
- auto r1_result = sort_r1(*r1_slice);
- TF_ASSIGN_OR_RETURN(r1_result, r1_result->Reshape({1, r1_length}));
- TF_RETURN_IF_ERROR(result_literal->CopySliceFrom(
- *r1_result, {0, 0}, {row, 0}, {1, r1_length}));
+ .Reshape({r1_length}));
+ auto r1_result = sort_r1(r1_slice);
+ TF_ASSIGN_OR_RETURN(r1_result, r1_result.Reshape({1, r1_length}));
+ TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
+ r1_result, {0, 0}, {row, 0}, {1, r1_length}));
}
parent_->evaluated_[sort] = std::move(result_literal);
}
@@ -1651,9 +1638,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
- absl::InlinedVector<std::unique_ptr<Literal>, 1> results(num_args);
+ absl::InlinedVector<Literal, 1> results(num_args);
for (int64 i = 0; i < num_args; ++i) {
- results[i] = absl::make_unique<Literal>(result_shape);
+ results[i] = Literal(result_shape);
}
Status eval_status;
@@ -1667,7 +1654,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
for (int64 input = 0; input < num_args; ++input) {
- TF_RETURN_IF_ERROR(results[input]->Populate<ReturnT>(
+ TF_RETURN_IF_ERROR(results[input].Populate<ReturnT>(
[&](absl::Span<const int64> multi_index) {
if (!eval_status.ok()) {
return init_scalars[input];
@@ -1703,8 +1690,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
// Evaluate computation with specified literal operands.
- absl::InlinedVector<std::unique_ptr<Literal>, 1>
- embedded_operands;
+ absl::InlinedVector<Literal, 1> embedded_operands;
for (ReturnT value : result_values) {
embedded_operands.push_back(
LiteralUtil::CreateR0<ReturnT>(value));
@@ -1717,11 +1703,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
embedded_operands.size());
std::transform(embedded_operands.begin(), embedded_operands.end(),
embedded_operands_ptrs.begin(),
- [](const std::unique_ptr<Literal>& ptr) {
- return ptr.get();
- });
+ [](Literal& literal) { return &literal; });
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> computed_result,
+ TF_ASSIGN_OR_RETURN(Literal computed_result,
embedded_evaluator.Evaluate<const Literal*>(
*function, embedded_operands_ptrs));
// Clear visit states so that we can use the evaluator again on
@@ -1729,10 +1713,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
embedded_evaluator.ResetVisitStates();
// Assign computed result to result_val.
if (!has_tuple_output) {
- result_values[0] = computed_result->Get<ReturnT>({});
+ result_values[0] = computed_result.Get<ReturnT>({});
} else {
for (int64 i = 0; i < num_args; ++i) {
- result_values[i] = computed_result->Get<ReturnT>(
+ result_values[i] = computed_result.Get<ReturnT>(
/*multi_index=*/{}, /*shape_index=*/{i});
}
}
@@ -1748,9 +1732,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
if (!has_tuple_output) {
parent_->evaluated_[reduce] = std::move(results[0]);
} else {
- auto tuple_result = absl::make_unique<Literal>(reduce->shape());
+ Literal tuple_result(reduce->shape());
for (int64 i = 0; i < num_args; ++i) {
- TF_CHECK_OK(tuple_result->MoveFrom(std::move(*results[i]), {i}));
+ TF_CHECK_OK(tuple_result.MoveFrom(std::move(results[i]), {i}));
}
parent_->evaluated_[reduce] = std::move(tuple_result);
}
@@ -1781,10 +1765,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
auto init_scalar = init_literal.Get<ReturnT>({});
- auto result = absl::make_unique<Literal>(select_and_scatter->shape());
+ Literal result(select_and_scatter->shape());
// Initialize result array with the init value.
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+ TF_RETURN_IF_ERROR(result.Populate<ReturnT>(
[&](absl::Span<const int64> output_index) { return init_scalar; }));
std::vector<int64> window_dimension_sizes;
@@ -1834,15 +1818,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
selected_val = curr_val;
selected_index = operand_index;
}
- curr_val_literal->Set({}, curr_val);
- selected_val_literal->Set({}, *selected_val);
- std::unique_ptr<Literal> computed_result =
+ curr_val_literal.Set({}, curr_val);
+ selected_val_literal.Set({}, *selected_val);
+ Literal computed_result =
embedded_evaluator
.Evaluate<const Literal*>(
- *select,
- {selected_val_literal.get(), curr_val_literal.get()})
+ *select, {&selected_val_literal, &curr_val_literal})
.ConsumeValueOrDie();
- bool selected = !computed_result->Get<bool>({});
+ bool selected = !computed_result.Get<bool>({});
if (selected) {
selected_val = curr_val;
selected_index = operand_index;
@@ -1856,16 +1839,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
if (std::equal(operand_index.begin(), operand_index.end(),
selected_index->begin())) {
auto source = source_literal.Get<ReturnT>(source_index);
- auto scattered = result->Get<ReturnT>(operand_index);
- source_literal_scatter->Set({}, source);
- scattered_literal->Set({}, scattered);
- std::unique_ptr<Literal> computed_result =
+ auto scattered = result.Get<ReturnT>(operand_index);
+ source_literal_scatter.Set({}, source);
+ scattered_literal.Set({}, scattered);
+ Literal computed_result =
embedded_evaluator
- .Evaluate<const Literal*>(*scatter,
- {source_literal_scatter.get(),
- scattered_literal.get()})
+ .Evaluate<const Literal*>(
+ *scatter,
+ {&source_literal_scatter, &scattered_literal})
.ConsumeValueOrDie();
- result->Set(operand_index, computed_result->Get<ReturnT>({}));
+ result.Set(operand_index, computed_result.Get<ReturnT>({}));
// Clear visit states so that we can use the evaluator again
// on the same computation.
embedded_evaluator.ResetVisitStates();
@@ -1916,10 +1899,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape()));
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
- auto result = absl::make_unique<Literal>(reduce_window->shape());
+ Literal result(reduce_window->shape());
// For each resulting dimension, calculate and assign computed value.
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> output_index) {
+ result.Populate<ReturnT>([&](absl::Span<const int64> output_index) {
ReturnT result_val = init_scalar;
std::fill(window_index.begin(), window_index.end(), 0);
@@ -1935,18 +1918,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
LiteralUtil::CreateR0<ReturnT>(curr_val);
const auto result_val_literal =
LiteralUtil::CreateR0<ReturnT>(result_val);
- std::unique_ptr<Literal> computed_result =
+ Literal computed_result =
embedded_evaluator
.Evaluate<const Literal*>(
- *function,
- {result_val_literal.get(), curr_val_literal.get()})
+ *function, {&result_val_literal, &curr_val_literal})
.ConsumeValueOrDie();
// Clear visit states so that we can use the evaluator again
// on the same computation.
embedded_evaluator.ResetVisitStates();
- result_val = computed_result->Get<ReturnT>({});
+ result_val = computed_result.Get<ReturnT>({});
});
return result_val;
@@ -1961,7 +1943,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// literal (if there is one) to `reshaped_indices`.
StatusOr<std::reference_wrapper<const Literal>> ReshapedScatterIndices(
int64 index_vector_dim, const Literal& indices,
- std::unique_ptr<Literal>* reshaped_indices) {
+ Literal* reshaped_indices) {
if (indices.shape().dimensions_size() != index_vector_dim) {
return std::cref(indices);
}
@@ -1970,7 +1952,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
indices.shape().dimensions().end());
new_shape.push_back(1);
TF_ASSIGN_OR_RETURN(*reshaped_indices, indices.Reshape(new_shape));
- return std::cref(**reshaped_indices);
+ return std::cref(*reshaped_indices);
}
// Returns a ShapeUtil::IndexIterationSpace that iterates over the update
@@ -2230,7 +2212,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
scatter->scatter_dimension_numbers();
const Literal& operand =
parent_->GetEvaluatedLiteralFor(scatter->operand(0));
- std::unique_ptr<Literal> reshaped_scatter_indices;
+ Literal reshaped_scatter_indices;
TF_ASSIGN_OR_RETURN(const Literal& scatter_indices,
ReshapedScatterIndices(dim_numbers.index_vector_dim(),
parent_->GetEvaluatedLiteralFor(
@@ -2260,7 +2242,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Initialize the result with the operand. This makes it easier to handle
// the updates even when the indices are repeated.
- std::unique_ptr<Literal> result = operand.CloneToUnique();
+ Literal result = operand.Clone();
HloEvaluator embedded_evaluator;
auto scatter_inner_loop_body =
[&](absl::Span<const int64> update_window_index,
@@ -2299,19 +2281,19 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
auto result_value_literal =
- LiteralUtil::CreateR0<ReturnT>(result->Get<ReturnT>(input_index));
+ LiteralUtil::CreateR0<ReturnT>(result.Get<ReturnT>(input_index));
auto update_value_literal =
LiteralUtil::CreateR0<ReturnT>(updates.Get<ReturnT>(update_index));
- std::unique_ptr<Literal> updated_result =
+ Literal updated_result =
embedded_evaluator
.Evaluate<const Literal*>(
*scatter->to_apply(),
- {result_value_literal.get(), update_value_literal.get()})
+ {&result_value_literal, &update_value_literal})
.ConsumeValueOrDie();
// Clear visit states so that we can use the evaluator again on the
// same computation.
embedded_evaluator.ResetVisitStates();
- result->Set<ReturnT>(input_index, updated_result->Get<ReturnT>({}));
+ result.Set<ReturnT>(input_index, updated_result.Get<ReturnT>({}));
return true;
};
@@ -2359,9 +2341,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return operand_literal.Get<ReturnT>(operand_index);
};
- auto result = LiteralUtil::CreateFromDimensions(
- shape.element_type(), AsInt64Slice(shape.dimensions()));
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func));
+ Literal result(shape);
+ TF_RETURN_IF_ERROR(result.Populate<ReturnT>(func));
parent_->evaluated_[slice] = std::move(result);
return Status::OK();
}
@@ -2575,7 +2556,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
if (ShapeUtil::Rank(iota->shape()) > 1) {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[iota],
- result->Broadcast(iota->shape(), {iota->iota_dimension()}));
+ result.Broadcast(iota->shape(), {iota->iota_dimension()}));
} else {
TF_RET_CHECK(ShapeUtil::Rank(iota->shape()) == 1);
parent_->evaluated_[iota] = std::move(result);
@@ -2645,9 +2626,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
template <typename IndexT>
- StatusOr<std::unique_ptr<Literal>> DynamicSlice(
- const Literal& operand_literal, const Literal& start_indices_literal,
- const Shape& result_shape) {
+ StatusOr<Literal> DynamicSlice(const Literal& operand_literal,
+ const Literal& start_indices_literal,
+ const Shape& result_shape) {
auto start_indices_typed = start_indices_literal.data<IndexT>();
std::vector<int64> start(start_indices_typed.begin(),
start_indices_typed.end());
@@ -2660,9 +2641,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
std::vector<int64> operand_indices(start.size());
- auto result = absl::make_unique<Literal>(result_shape);
+ Literal result(result_shape);
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
+ result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
for (int64 i = 0; i < operand_indices.size(); ++i) {
CHECK_GE(multi_index[i] + start[i], 0);
operand_indices[i] = multi_index[i] + start[i];
@@ -2676,12 +2657,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
template <typename IndexT>
- StatusOr<std::unique_ptr<Literal>> DynamicUpdateSlice(
- const Literal& operand_literal, const Literal& update_literal,
- const Literal& start_indices_literal) {
- auto result = operand_literal.CloneToUnique();
+ StatusOr<Literal> DynamicUpdateSlice(const Literal& operand_literal,
+ const Literal& update_literal,
+ const Literal& start_indices_literal) {
+ auto result = operand_literal.Clone();
auto start_indices_typed = start_indices_literal.data<IndexT>();
- const auto rank = ShapeUtil::Rank(result->shape());
+ const auto rank = ShapeUtil::Rank(result.shape());
std::vector<int64> start(start_indices_typed.begin(),
start_indices_typed.end());
// Clamp the update start indices so the slice is in-bounds w.r.t the
@@ -2689,15 +2670,15 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
for (int64 i = 0; i < rank; ++i) {
start[i] = std::min<int64>(
std::max<int64>(0, start[i]),
- result->shape().dimensions(i) - update_literal.shape().dimensions(i));
+ result.shape().dimensions(i) - update_literal.shape().dimensions(i));
}
std::vector<int64> result_index(rank, 0);
auto func = [&](absl::Span<const int64> update_index) {
std::transform(update_index.begin(), update_index.end(), start.begin(),
result_index.begin(), std::plus<int64>());
- result->Set<ReturnT>(result_index,
- update_literal.Get<ReturnT>(update_index));
+ result.Set<ReturnT>(result_index,
+ update_literal.Get<ReturnT>(update_index));
return true;
};
@@ -2710,7 +2691,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return std::move(result);
}
- StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOp(
+ StatusOr<Literal> ElementWiseUnaryOp(
HloInstruction* instruction,
const std::function<ElementwiseT(ElementwiseT)>& unary_op) {
const Literal& operand_literal =
@@ -2723,7 +2704,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return std::move(result_literal);
}
- StatusOr<std::unique_ptr<Literal>> ElementWiseBinaryOp(
+ StatusOr<Literal> ElementWiseBinaryOp(
HloInstruction* instruction,
const std::function<ElementwiseT(ElementwiseT, ElementwiseT)>&
binary_op) {
@@ -2745,10 +2726,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
- auto result = absl::make_unique<Literal>(shape);
+ Literal result(shape);
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
+ result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
return ConvertBinaryFunction(binary_op)(
lhs_literal.Get<ReturnT>(multi_index),
rhs_literal.Get<ReturnT>(multi_index));
@@ -2757,7 +2738,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
template <typename LhsType, typename RhsType, typename EhsType>
- StatusOr<std::unique_ptr<Literal>> ElementwiseTernaryOp(
+ StatusOr<Literal> ElementwiseTernaryOp(
HloInstruction* instruction,
const std::function<ReturnT(LhsType, RhsType, EhsType)>& ternary_op) {
const auto shape = instruction->shape();
@@ -2782,10 +2763,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs);
- auto result = absl::make_unique<Literal>(shape);
+ Literal result(shape);
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
+ result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
return ternary_op(lhs_literal.Get<LhsType>(multi_index),
rhs_literal.Get<RhsType>(multi_index),
ehs_literal.Get<EhsType>(multi_index));
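The hunks above migrate the evaluator from heap-allocated std::unique_ptr<Literal> to value-semantic Literal objects. A minimal sketch of the new pattern, using only the Literal(const Shape&) constructor and Populate API exercised in this diff; the helper name and element type are illustrative, not part of the patch:

    // Sketch: build a rank-1 f32 literal by value and return it by move,
    // with no unique_ptr wrapper.
    #include "absl/types/span.h"
    #include "tensorflow/compiler/xla/literal.h"
    #include "tensorflow/compiler/xla/shape_util.h"
    #include "tensorflow/core/lib/core/status.h"

    xla::Literal MakeZeros(xla::int64 n) {
      xla::Literal result(xla::ShapeUtil::MakeShape(xla::F32, {n}));
      TF_CHECK_OK(result.Populate<float>(
          [](absl::Span<const xla::int64>) { return 0.0f; }));
      return result;  // Literal is movable; the caller takes ownership directly.
    }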
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 0345a2a5f8..287ba84b3b 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -123,6 +123,10 @@ class NodeFilter {
// We arbitrarily set this as the boundary between "large" and "small"
// instructions.
bool IsSmall(const HloInstruction* instr) {
+ if (ShapeUtil::HasPrimitiveType(instr->shape(), OPAQUE) ||
+ ShapeUtil::HasPrimitiveType(instr->shape(), TOKEN)) {
+ return true;
+ }
return ShapeUtil::ElementsInRecursive(instr->shape()) < 4096;
}
@@ -465,9 +469,8 @@ stylesheet=<
string graph_label =
StrCat(label_, "<br/>Computation ", computation_->name());
if (computation_->IsFusionComputation()) {
- StrAppend(&graph_label,
- StrCat(" (in fusion instruction ",
- computation_->FusionInstruction()->name(), ")"));
+ StrAppend(&graph_label, " (in fusion instruction ",
+ computation_->FusionInstruction()->name(), ")");
}
if (profile_ != nullptr) {
auto cycles = profile_->total_cycles_executed(*computation_);
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 25ae344ea5..e905f2983a 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -250,7 +250,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
TF_RET_CHECK(proto.has_literal());
TF_ASSIGN_OR_RETURN(auto literal,
Literal::CreateFromProto(proto.literal()));
- instruction = CreateTrace(literal->GetR1U8AsString(), operands(0));
+ instruction = CreateTrace(literal.GetR1U8AsString(), operands(0));
break;
}
case HloOpcode::kFusion: {
@@ -505,6 +505,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
instruction->SetAndSanitizeName(proto.name());
instruction->metadata_ = proto.metadata();
instruction->backend_config_ = proto.backend_config();
+ instruction->unique_id_ = proto.id();
if (proto.has_sharding()) {
TF_ASSIGN_OR_RETURN(const auto& sharding,
@@ -527,7 +528,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConstant(
- std::unique_ptr<Literal> literal) {
+ Literal literal) {
return absl::make_unique<HloConstantInstruction>(std::move(literal));
}
@@ -2096,7 +2097,7 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
if (has_sharding()) {
extra.push_back(StrCat("sharding=", sharding().ToString()));
}
- if (!control_predecessors_.empty()) {
+ if (options.print_control_dependencies() && !control_predecessors_.empty()) {
extra.push_back(StrCat("control-predecessors={",
StrJoin(control_predecessors_, ", ",
[&](string* out, HloInstruction* pre) {
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 5581c17c2d..4f6cac1396 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -82,6 +82,7 @@ class HloPrintOptions {
print_operand_shape_(true),
print_program_shape_(true),
print_percent_(true),
+ print_control_dependencies_(true),
canonicalize_instruction_names_(false),
indent_amount_(0),
is_in_nested_computation_(false) {}
@@ -94,7 +95,8 @@ class HloPrintOptions {
.set_print_backend_config(false)
.set_print_operand_shape(false)
.set_print_program_shape(false)
- .set_print_percent(false);
+ .set_print_percent(false)
+ .set_print_control_dependencies(false);
}
// Options to produce the canonical string representing an isomorphic
@@ -108,6 +110,7 @@ class HloPrintOptions {
.set_print_operand_shape(true)
.set_print_program_shape(false)
.set_print_percent(false)
+ .set_print_control_dependencies(false)
.set_canonicalize_instruction_names(true);
}
@@ -153,6 +156,12 @@ class HloPrintOptions {
return *this;
}
+ // If true, control dependencies will be printed.
+ HloPrintOptions& set_print_control_dependencies(bool value) {
+ print_control_dependencies_ = value;
+ return *this;
+ }
+
// If true, only a part of operands will be printed out, and their names will
// be omitted (note that in this case the text will not be parsable).
HloPrintOptions& set_compact_operands(bool value) {
@@ -190,6 +199,9 @@ class HloPrintOptions {
bool print_operand_shape() const { return print_operand_shape_; }
bool print_program_shape() const { return print_program_shape_; }
bool print_percent() const { return print_percent_; }
+ bool print_control_dependencies() const {
+ return print_control_dependencies_;
+ }
bool canonicalize_instruction_names() const {
return canonicalize_instruction_names_;
}
@@ -205,6 +217,7 @@ class HloPrintOptions {
bool print_operand_shape_;
bool print_program_shape_;
bool print_percent_;
+ bool print_control_dependencies_;
bool canonicalize_instruction_names_;
int indent_amount_;
bool is_in_nested_computation_;
@@ -346,8 +359,7 @@ class HloInstruction {
const string& name);
// Creates a literal constant instruction.
- static std::unique_ptr<HloInstruction> CreateConstant(
- std::unique_ptr<Literal> literal);
+ static std::unique_ptr<HloInstruction> CreateConstant(Literal literal);
// Creates an Iota instruction.
static std::unique_ptr<HloInstruction> CreateIota(const Shape& shape,
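The print_control_dependencies flag added here composes with the existing fluent setters on HloPrintOptions. A short usage sketch; the module pointer is assumed to be an HloModule* already in scope:

    // Sketch: print a module without control-predecessor annotations, using the
    // setter introduced in this change.
    xla::HloPrintOptions options;
    options.set_print_metadata(false)
        .set_print_control_dependencies(false);
    const std::string text = module->ToString(options);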
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index fb7345a2ad..e92882c22a 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -845,8 +845,8 @@ std::unique_ptr<HloInstruction> HloSliceInstruction::CloneWithNewOperandsImpl(
shape, new_operands[0], slice_starts_, slice_limits_, slice_strides_);
}
-HloConstantInstruction::HloConstantInstruction(std::unique_ptr<Literal> literal)
- : HloInstruction(HloOpcode::kConstant, CHECK_NOTNULL(literal)->shape()),
+HloConstantInstruction::HloConstantInstruction(Literal literal)
+ : HloInstruction(HloOpcode::kConstant, literal.shape()),
literal_(std::move(literal)) {}
HloConstantInstruction::HloConstantInstruction(const Shape& shape)
@@ -854,7 +854,7 @@ HloConstantInstruction::HloConstantInstruction(const Shape& shape)
HloInstructionProto HloConstantInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
- if (literal_ != nullptr) {
+ if (literal_.has_value()) {
*proto.mutable_literal() = literal_->ToProto();
}
return proto;
@@ -876,7 +876,7 @@ void HloConstantInstruction::RelayoutConstant(const Layout& new_layout,
if (!mutable_array_subshape->has_layout() ||
!LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) {
- literal_ = literal_->Relayout(new_layout, shape_index);
+ *literal_ = literal_->Relayout(new_layout, shape_index);
*mutable_array_subshape->mutable_layout() = new_layout;
}
}
@@ -893,7 +893,8 @@ std::unique_ptr<HloInstruction>
HloConstantInstruction::CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
- return absl::make_unique<HloConstantInstruction>(literal_->CloneToUnique());
+ CHECK(literal_.has_value());
+ return absl::make_unique<HloConstantInstruction>(literal_->Clone());
}
string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
@@ -901,7 +902,7 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
CanonicalNameMap* canonical_name_map) const {
string operands;
// For constants, show the actual value in place of an empty operand list.
- if (literal_ != nullptr &&
+ if (literal_.has_value() &&
((ShapeUtil::IsArray(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) ||
options.print_large_constants())) {
// Literal::ToString emits multidimensional arrays over multiple
@@ -936,7 +937,7 @@ HloTraceInstruction::HloTraceInstruction(const string& tag,
HloInstructionProto HloTraceInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
- *proto.mutable_literal() = literal_->ToProto();
+ *proto.mutable_literal() = literal_.ToProto();
return proto;
}
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index c3a7801164..2d7bc83855 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -580,13 +580,13 @@ class HloSliceInstruction : public HloInstruction {
class HloConstantInstruction : public HloInstruction {
public:
- explicit HloConstantInstruction(std::unique_ptr<Literal> literal);
+ explicit HloConstantInstruction(Literal literal);
// Used when the literal is too large and dropped.
explicit HloConstantInstruction(const Shape& shape);
// Returns the literal associated with this instruction.
const Literal& literal() const { return *literal_; }
// Returns whether there is a literal associated with this instruction.
- bool HasLiteral() const { return literal_ != nullptr; }
+ bool HasLiteral() const { return literal_.has_value(); }
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
@@ -610,15 +610,14 @@ class HloConstantInstruction : public HloInstruction {
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
- // TODO(b/36360764): Remove unique_ptr wrapping.
- std::unique_ptr<Literal> literal_;
+ absl::optional<Literal> literal_;
};
class HloTraceInstruction : public HloInstruction {
public:
explicit HloTraceInstruction(const string& tag, HloInstruction* operand);
// Returns a tag to be used in tracing.
- string TracingTag() const { return literal_->GetR1U8AsString(); }
+ string TracingTag() const { return literal_.GetR1U8AsString(); }
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
@@ -631,8 +630,7 @@ class HloTraceInstruction : public HloInstruction {
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
- // TODO(b/36360764): Remove unique_ptr wrapping.
- std::unique_ptr<Literal> literal_;
+ Literal literal_;
};
class HloFusionInstruction : public HloInstruction {
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
index 9bfb0af96c..c7ec88d450 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling.cc
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include <map>
#include <queue>
@@ -582,4 +582,22 @@ StatusOr<HloInstructionSequence> ScheduleComputation(
size_function, nullptr, empty_map);
}
+HloMemoryScheduler::HloMemoryScheduler(
+ const LogicalBuffer::SizeFunction& size_function,
+ const MemorySchedulerAlgorithm& algorithm)
+ : size_function_(size_function), algorithm_(algorithm) {}
+
+StatusOr<bool> HloMemoryScheduler::Run(HloModule* module) {
+ TF_ASSIGN_OR_RETURN(HloSchedule schedule,
+ ScheduleModule(*module, size_function_, algorithm_));
+ TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
+ return true;
+}
+
+StatusOr<bool> HloDescheduler::Run(HloModule* module) {
+ bool changed = module->has_schedule();
+ module->clear_schedule();
+ return changed;
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h
index 54e32340ba..5e02868eba 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling.h
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h
@@ -13,14 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_
#include <vector>
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
@@ -86,6 +87,37 @@ StatusOr<HloInstructionSequence> ScheduleComputation(
const HloComputation& computation,
const LogicalBuffer::SizeFunction& size_function);
+// A pass which schedules the HLO instructions in a module. The HloModule's
+// schedule field is set to the resulting HloSchedule using
+// HloModule::set_schedule.
+class HloMemoryScheduler : public HloPassInterface {
+ public:
+ // size_function is the function returning the number of bytes required for a
+ // LogicalBuffer. algorithm is the memory scheduling algorithm to use. If not
+ // specified, then DefaultMemoryScheduler is used.
+ HloMemoryScheduler(const LogicalBuffer::SizeFunction& size_function,
+ const MemorySchedulerAlgorithm& algorithm = {});
+ ~HloMemoryScheduler() override = default;
+ absl::string_view name() const override { return "hlo-memory-scheduler"; }
+
+ StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+ LogicalBuffer::SizeFunction size_function_;
+ MemorySchedulerAlgorithm algorithm_;
+};
+
+// A trivial pass which clears the schedule currently set on the
+// HloModule. After this pass runs, HloModule::has_schedule will return false.
+class HloDescheduler : public HloPassInterface {
+ public:
+ HloDescheduler() = default;
+ ~HloDescheduler() override = default;
+ absl::string_view name() const override { return "hlo-descheduler"; }
+
+ StatusOr<bool> Run(HloModule* module) override;
+};
+
} // namespace xla
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_
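A brief sketch of driving the two passes declared above directly, outside of a pass pipeline; module is an assumed HloModule* and the size function mirrors the one used in the test below:

    // Sketch: attach a schedule to the module, then clear it again.
    xla::HloMemoryScheduler scheduler([](const xla::BufferValue& buffer) {
      return xla::ShapeUtil::ByteSizeOf(buffer.shape());
    });
    bool scheduled = scheduler.Run(module).ValueOrDie();
    // module->has_schedule() is now true.
    xla::HloDescheduler descheduler;
    bool cleared = descheduler.Run(module).ValueOrDie();
    // module->has_schedule() is now false.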
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc
index 6afe51997e..1b9e9bfc77 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include <memory>
#include <string>
@@ -67,22 +67,34 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- TF_ASSERT_OK_AND_ASSIGN(
- HloSchedule schedule,
- ScheduleModule(*module, [](const BufferValue& buffer) {
- return ShapeUtil::ByteSizeOf(buffer.shape());
- }));
+ HloMemoryScheduler scheduler([](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ });
+ ASSERT_FALSE(module->has_schedule());
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, scheduler.Run(module.get()));
+ EXPECT_TRUE(changed);
+ ASSERT_TRUE(module->has_schedule());
+ TF_ASSERT_OK(module->schedule().Verify());
+
// Verify that all instructions are in the sequence.
const std::vector<const HloInstruction*>& sequence =
- schedule.sequence(module->entry_computation()).instructions();
+ module->schedule().sequence(module->entry_computation()).instructions();
EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size());
// The first instruction should be the parameter and the last the root "sub".
EXPECT_EQ(param, sequence.front());
EXPECT_EQ(sub, sequence.back());
- SequentialHloOrdering ordering(schedule);
+ SequentialHloOrdering ordering(module->schedule());
EXPECT_TRUE(ordering.ExecutesBefore(add, negate));
+
+ // Clear the schedule using the descheduling pass.
+ HloDescheduler descheduler;
+ EXPECT_TRUE(module->has_schedule());
+ TF_ASSERT_OK_AND_ASSIGN(bool descheduler_changed,
+ descheduler.Run(module.get()));
+ EXPECT_TRUE(descheduler_changed);
+ EXPECT_FALSE(module->has_schedule());
}
TEST_F(HloSchedulingTest, ListSchedulerHandlesAliasing) {
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index cfe906d9c5..b3949f3a6d 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -60,7 +60,7 @@ Status HloModule::set_schedule(HloSchedule schedule) {
HloComputation* HloModule::AddComputationInternal(
std::unique_ptr<HloComputation> computation, bool is_entry,
- bool uniquify_names) {
+ bool uniquify_identifiers) {
if (is_entry) {
CHECK_EQ(nullptr, entry_computation_);
entry_computation_ = computation.get();
@@ -73,30 +73,36 @@ HloComputation* HloModule::AddComputationInternal(
}
}
- if (uniquify_names) {
+ if (uniquify_identifiers) {
computation->UniquifyName(&computation_name_uniquer_);
for (auto* instruction : computation->instructions()) {
instruction->UniquifyName(&instruction_name_uniquer_);
}
+
+ // Pick unique IDs for each instruction.
+ for (auto* instruction : computation->instructions()) {
+ instruction->SetUniqueId(NewUniqueInstructionId());
+ }
+ // Set unique id to this computation.
+ CHECK_NE(computation->root_instruction()->unique_id(), -1)
+ << "Root has no valid id: " << computation->ToString();
+ computation->SetUniqueId(computation->root_instruction()->unique_id());
} else {
// Don't uniquify the names of the computation or instruction, but we must
// run the names through the uniquifiers to prevent future name collisions
- // for computations and instructions created later.
+ // for computations and instructions created later. Also, set
+ // next_unique_id_ to one greater than the max unique id of any
+ // instruction (or the computation) to avoid ID collisions.
computation_name_uniquer_.GetUniqueName(computation->name());
for (auto* instruction : computation->instructions()) {
instruction_name_uniquer_.GetUniqueName(instruction->name());
+ next_unique_id_ = std::max(next_unique_id_, instruction->unique_id() + 1);
+ }
+ if (next_unique_id_ < computation->unique_id() + 1) {
+ next_unique_id_ = computation->unique_id() + 1;
}
}
- // Pick unique IDs for each instruction.
- for (auto* instruction : computation->instructions()) {
- instruction->SetUniqueId(NewUniqueInstructionId());
- }
- // Set unique id to this computation.
- CHECK_NE(computation->root_instruction()->unique_id(), -1)
- << "Root has no valid id: " << computation->ToString();
- computation->SetUniqueId(computation->root_instruction()->unique_id());
-
computation->set_parent(this);
computations_.push_back(std::move(computation));
return computations_.back().get();
@@ -105,7 +111,7 @@ HloComputation* HloModule::AddComputationInternal(
HloComputation* HloModule::AddEntryComputation(
std::unique_ptr<HloComputation> computation) {
return AddComputationInternal(std::move(computation), /*is_entry=*/true,
- /*uniquify_names=*/true);
+ /*uniquify_identifiers=*/true);
}
Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
@@ -122,7 +128,7 @@ Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
HloComputation* HloModule::AddEmbeddedComputation(
std::unique_ptr<HloComputation> computation) {
return AddComputationInternal(std::move(computation), /*is_entry=*/false,
- /*uniquify_names=*/true);
+ /*uniquify_identifiers=*/true);
}
void HloModule::ReplaceComputations(
@@ -249,6 +255,9 @@ HloModuleProto HloModule::ToProto() const {
/* static */
StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
const HloModuleProto& proto, const HloModuleConfig& module_config) {
+ VLOG(2) << "CreateFromProto()";
+ XLA_VLOG_LINES(2, proto.DebugString());
+
// The ProgramShape in the passed in module config must match the shapes of
// the entry parameters and root.
TF_RET_CHECK(proto.has_program_shape())
@@ -312,22 +321,32 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
// Don't uniquify names because we want names to be stable across
// serialization and deserialization.
module->AddComputationInternal(std::move(computation), is_entry,
- /*uniquify_names=*/false);
+ /*uniquify_identifiers=*/false);
}
TF_RET_CHECK(module->entry_computation_ != nullptr);
- // Because we didn't uniquify the names, double-check that the instruction and
- // computation names are unique from the proto.
+ // Because we didn't uniquify the names or the ids, double-check that the
+ // instruction and computation names and ids are unique from the proto.
tensorflow::gtl::FlatSet<string> computation_names;
tensorflow::gtl::FlatSet<string> instruction_names;
+ tensorflow::gtl::FlatSet<int> computation_ids;
+ tensorflow::gtl::FlatSet<int> instruction_ids;
for (HloComputation* computation : module->computations()) {
TF_RET_CHECK(!ContainsKey(computation_names, computation->name()))
<< "Computation name is not unique: " << computation->name();
computation_names.insert(computation->name());
+
+ TF_RET_CHECK(!ContainsKey(computation_ids, computation->unique_id()))
+ << "Computation id is not unique: " << computation->unique_id();
+ computation_ids.insert(computation->unique_id());
for (HloInstruction* instruction : computation->instructions()) {
TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name()))
<< "Instruction name is not unique: " << instruction->name();
instruction_names.insert(instruction->name());
+
+ TF_RET_CHECK(!ContainsKey(instruction_ids, instruction->unique_id()))
+ << "Instruction id is not unique: " << instruction->unique_id();
+ instruction_ids.insert(instruction->unique_id());
}
}
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index 26fd1b2438..3bc2d13781 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -253,7 +253,7 @@ class HloModule {
private:
HloComputation* AddComputationInternal(
std::unique_ptr<HloComputation> computation, bool is_entry,
- bool uniquify_names);
+ bool uniquify_identifiers);
const string name_;
HloModuleConfig config_;
diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.cc b/tensorflow/compiler/xla/service/hlo_module_dce.cc
index 98d20315e3..f7be5cae22 100644
--- a/tensorflow/compiler/xla/service/hlo_module_dce.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_dce.cc
@@ -36,23 +36,6 @@ namespace xla {
namespace {
-bool HasSendRecv(HloComputation* computation) {
- for (auto* instruction : computation->instructions()) {
- if (instruction->opcode() == HloOpcode::kSend ||
- instruction->opcode() == HloOpcode::kSendDone ||
- instruction->opcode() == HloOpcode::kRecv ||
- instruction->opcode() == HloOpcode::kRecvDone) {
- return true;
- }
- for (auto* sub_computation : instruction->called_computations()) {
- if (HasSendRecv(sub_computation)) {
- return true;
- }
- }
- }
- return false;
-}
-
StatusOr<bool> RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) {
bool changed = false;
for (auto* computation : module->computations()) {
@@ -68,9 +51,10 @@ StatusOr<bool> RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) {
if (!ShapeUtil::IsTuple(xla_while->shape()) ||
while_body_root->opcode() != HloOpcode::kTuple ||
- HasSendRecv(while_body_comp)) {
+ while_body_comp->HasSideEffect() ||
+ xla_while->while_condition()->HasSideEffect()) {
// Only run DCE on tuple-shaped while loops where body root is Tuple,
- // with no send/recv instructions.
+ // with no I/O instructions.
VLOG(1) << "WhileDCE SKIP while: " << xla_while->ToString();
continue;
}
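The new gate relies on the generic HasSideEffect query instead of the removed send/recv walk. A hedged sketch of the same predicate in isolation; the helper name is illustrative and not part of the pass:

    // Sketch: a while loop qualifies for WhileDCE only if it is tuple-shaped,
    // its body root is a tuple, and neither body nor condition has side effects.
    bool IsWhileDceCandidate(const xla::HloInstruction* xla_while) {
      return xla::ShapeUtil::IsTuple(xla_while->shape()) &&
             xla_while->while_body()->root_instruction()->opcode() ==
                 xla::HloOpcode::kTuple &&
             !xla_while->while_body()->HasSideEffect() &&
             !xla_while->while_condition()->HasSideEffect();
    }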
diff --git a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc
index 363862e490..bf66cc6bc3 100644
--- a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc
@@ -367,5 +367,77 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) {
"while.2", 1));
}
+// Tests that a while whose body has outfeed operations is not DCE-ed.
+TEST_F(HloModuleDceTest, WhileWithOutfeed) {
+ auto module = ParseHloString(R"(
+ HloModule OutfeedLoop
+ WhileBody {
+ body_param = (s32[]) parameter(0)
+ token = token[] after-all()
+ constant.2 = s32[] constant(2)
+ outfeed_tuple = (s32[]) outfeed(constant.2, token)
+ get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0
+ constant.1 = s32[] constant(1)
+ add = s32[] add(get-tuple-element.1, constant.1)
+ ROOT tuple = (s32[]) tuple(add)
+ }
+ WhileCondition {
+ cond_param = (s32[]) parameter(0)
+ get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
+ constant.2 = s32[] constant(10)
+ ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
+ }
+ ENTRY SimpleLoop {
+ constant.3 = s32[] constant(0)
+ tuple.1 = (s32[]) tuple(constant.3)
+ while = (s32[]) while(tuple.1), condition=WhileCondition,
+ body=WhileBody
+ ROOT rtuple = () tuple()
+ })")
+ .ValueOrDie();
+
+ HloModuleDCE dce;
+ EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
+ "while", 0));
+}
+
+// Tests that updates to a loop variable that is not referenced outside of a
+// kWhile are not elided within the loop body when the condition computation
+// still uses that variable.
+TEST_F(HloModuleDceTest, WhileWithOnlyLoopVariableBumping) {
+ auto module = ParseHloString(R"(
+ HloModule InfiniteLoop
+ WhileBody {
+ body_param = (s32[], s32[]) parameter(0)
+ get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0
+ get-tuple-element.2 = s32[] get-tuple-element(body_param), index=1
+ constant.1 = s32[] constant(1)
+ add = s32[] add(get-tuple-element.1, constant.1)
+ ROOT tuple = (s32[], s32[]) tuple(add, get-tuple-element.2)
+ }
+ WhileCondition {
+ cond_param = (s32[], s32[]) parameter(0)
+ get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
+ constant.2 = s32[] constant(10)
+ ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
+ }
+ ENTRY SimpleLoop {
+ p0 = (s32[]) parameter(0)
+ get-tuple-element.5 = s32[] get-tuple-element(p0), index=0
+ constant.3 = s32[] constant(0)
+ tuple.1 = (s32[], s32[]) tuple(constant.3, get-tuple-element.5)
+ while = (s32[], s32[]) while(tuple.1), condition=WhileCondition,
+ body=WhileBody
+ ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=1
+ })")
+ .ValueOrDie();
+
+ HloModuleDCE dce;
+ EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
+ "while", 0));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module_group.cc b/tensorflow/compiler/xla/service/hlo_module_group.cc
new file mode 100644
index 0000000000..f9b56ef464
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_module_group.cc
@@ -0,0 +1,91 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_module_group.h"
+
+namespace xla {
+
+HloModuleGroup::HloModuleGroup(absl::string_view name,
+ std::unique_ptr<HloModule> module)
+ : name_(name) {
+ push_back(std::move(module));
+}
+
+HloModuleGroup::HloModuleGroup(absl::string_view name,
+ absl::Span<std::unique_ptr<HloModule>> modules)
+ : name_(name) {
+ for (auto& module : modules) {
+ push_back(std::move(module));
+ }
+}
+
+std::vector<std::unique_ptr<HloModule>> HloModuleGroup::ConsumeModules() {
+ std::vector<std::unique_ptr<HloModule>> ret_modules = std::move(modules_);
+
+ // Clear everything so the object state is in a known (empty) state.
+ modules_.clear();
+ module_ptrs_.clear();
+ return ret_modules;
+}
+
+string HloModuleGroup::ToString() const {
+ std::ostringstream s;
+ s << "HloModuleGroup " << name() << "\n\n";
+ for (const HloModule* module : modules()) {
+ s << module->ToString() << "\n";
+ }
+ return s.str();
+}
+
+HloModuleGroupProto HloModuleGroup::ToProto() const {
+ HloModuleGroupProto proto;
+ proto.set_name(name());
+ for (const HloModule* module : modules()) {
+ *proto.add_hlo_modules() = module->ToProto();
+ }
+ return proto;
+}
+
+/* static */ StatusOr<HloModuleGroup> HloModuleGroup::CreateFromProto(
+ const HloModuleGroupProto& proto,
+ absl::Span<const HloModuleConfig> module_configs) {
+ TF_RET_CHECK(!proto.name().empty()) << "Module group name cannot be empty";
+ TF_RET_CHECK(proto.hlo_modules_size() > 0)
+ << "Module group must have at least one HLO module";
+ TF_RET_CHECK(proto.hlo_modules_size() == module_configs.size());
+
+ std::vector<std::unique_ptr<HloModule>> modules;
+ for (int i = 0; i < proto.hlo_modules_size(); ++i) {
+ const HloModuleProto& module_proto = proto.hlo_modules(i);
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<HloModule> module,
+ HloModule::CreateFromProto(module_proto, module_configs[i]));
+ modules.push_back(std::move(module));
+ }
+
+ return HloModuleGroup(proto.name(), absl::MakeSpan(modules));
+}
+
+void HloModuleGroup::push_back(std::unique_ptr<HloModule> module) {
+ modules_.push_back(std::move(module));
+ module_ptrs_.push_back(modules_.back().get());
+}
+
+std::ostream& operator<<(std::ostream& out, const HloModuleGroup& group) {
+ out << group.ToString();
+ return out;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module_group.h b/tensorflow/compiler/xla/service/hlo_module_group.h
new file mode 100644
index 0000000000..7338be8b9c
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_module_group.h
@@ -0,0 +1,81 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_
+
+#include <iosfwd>
+#include <string>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+
+namespace xla {
+
+// An abstraction representing an ordered set of HLO modules built to run
+// concurrently across different devices.
+class HloModuleGroup {
+ public:
+ // Construct an empty module group.
+ explicit HloModuleGroup(absl::string_view name) : name_(name) {}
+
+ // Construct a module group containing a single module.
+ HloModuleGroup(absl::string_view name, std::unique_ptr<HloModule> module);
+
+ // Construct a module group containing any number of modules.
+ HloModuleGroup(absl::string_view name,
+ absl::Span<std::unique_ptr<HloModule>> modules);
+
+ // Returns the modules contained in the group.
+ const std::vector<HloModule*>& modules() const { return module_ptrs_; }
+
+ // Returns a module at a particular index.
+ HloModule& module(int index) const { return *module_ptrs_.at(index); }
+
+ // Add a module to the back of vector of modules in the group.
+ void push_back(std::unique_ptr<HloModule> module);
+
+ // Moves all modules from the group into the returned vector. After this
+ // method runs, the module group will be empty.
+ std::vector<std::unique_ptr<HloModule>> ConsumeModules();
+
+ string name() const { return name_; }
+ string ToString() const;
+
+ // Serialize the module group to/from a proto.
+ HloModuleGroupProto ToProto() const;
+ static StatusOr<HloModuleGroup> CreateFromProto(
+ const HloModuleGroupProto& proto,
+ absl::Span<const HloModuleConfig> module_configs);
+
+ private:
+ string name_;
+
+ // Vector of modules as std::unique_ptrs.
+ std::vector<std::unique_ptr<HloModule>> modules_;
+
+ // Vector of modules as normal pointers. This vector is kept in sync with
+ // modules_ as modules are added to the group with push_back.
+ std::vector<HloModule*> module_ptrs_;
+};
+
+std::ostream& operator<<(std::ostream& out, const HloModuleGroup& group);
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_
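A short usage sketch of the class declared above; the two std::unique_ptr<HloModule> values are assumed to come from a parser or builder, as in the tests that follow:

    // Sketch: build a group incrementally, serialize it, then take the modules
    // back out, leaving the group empty.
    xla::HloModuleGroup group("example_group");
    group.push_back(std::move(module_a));
    group.push_back(std::move(module_b));
    xla::HloModuleGroupProto proto = group.ToProto();
    std::vector<std::unique_ptr<xla::HloModule>> modules = group.ConsumeModules();
    // group.modules() is now empty; `modules` owns both HloModules.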
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_test.cc b/tensorflow/compiler/xla/service/hlo_module_group_test.cc
new file mode 100644
index 0000000000..ebf790ba6f
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_module_group_test.cc
@@ -0,0 +1,142 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_module_group.h"
+
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+
+namespace {
+
+namespace op = ::xla::testing::opcode_matchers;
+
+class HloModuleGroupTest : public HloTestBase {
+ protected:
+ HloModuleGroupTest() = default;
+};
+
+TEST_F(HloModuleGroupTest, SingleModule) {
+ const string text = R"(
+HloModule simple_module
+
+ENTRY %entry (x: f32[], y: f32[]) -> f32[] {
+ %x = f32[] parameter(0)
+ %y = f32[] parameter(1)
+ ROOT %add = f32[] add(%x, %y)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+ HloModuleGroup group(TestName(), std::move(module));
+
+ EXPECT_EQ(group.modules().size(), 1);
+ EXPECT_THAT(
+ group.module(0).entry_computation()->instructions(),
+ ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add()));
+
+ TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup group_copy,
+ HloModuleGroup::CreateFromProto(
+ group.ToProto(), {group.module(0).config()}));
+ EXPECT_EQ(group_copy.modules().size(), 1);
+ EXPECT_THAT(
+ group_copy.module(0).entry_computation()->instructions(),
+ ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add()));
+
+ std::vector<std::unique_ptr<HloModule>> modules = group.ConsumeModules();
+ EXPECT_EQ(modules.size(), 1);
+ EXPECT_EQ(group.modules().size(), 0);
+}
+
+TEST_F(HloModuleGroupTest, MultipleModules) {
+ const string text_0 = R"(
+HloModule module0
+
+ENTRY %entry (x: f32[], y: f32[]) -> f32[] {
+ %x = f32[] parameter(0)
+ %y = f32[] parameter(1)
+ ROOT %add = f32[] add(%x, %y)
+}
+)";
+ const string text_1 = R"(
+HloModule module1
+
+ENTRY %entry (a: f32[]) -> f32[] {
+ ROOT %a = f32[] parameter(0)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_0,
+ ParseHloString(text_0));
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_1,
+ ParseHloString(text_1));
+ std::vector<std::unique_ptr<HloModule>> modules;
+ modules.push_back(std::move(module_0));
+ modules.push_back(std::move(module_1));
+ HloModuleGroup group(TestName(), absl::MakeSpan(modules));
+ EXPECT_EQ(group.modules().size(), 2);
+ EXPECT_THAT(
+ group.module(0).entry_computation()->instructions(),
+ ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add()));
+ EXPECT_THAT(group.module(1).entry_computation()->instructions(),
+ ::testing::ElementsAre(op::Parameter()));
+
+ TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup group_copy,
+ HloModuleGroup::CreateFromProto(
+ group.ToProto(), {group.module(0).config(),
+ group.module(1).config()}));
+ EXPECT_EQ(group_copy.modules().size(), 2);
+}
+
+TEST_F(HloModuleGroupTest, BuildModuleGroupByPushBack) {
+ const string text_0 = R"(
+HloModule module0
+
+ENTRY %entry (x: f32[], y: f32[]) -> f32[] {
+ %x = f32[] parameter(0)
+ %y = f32[] parameter(1)
+ ROOT %add = f32[] add(%x, %y)
+}
+)";
+ const string text_1 = R"(
+HloModule module1
+
+ENTRY %entry (a: f32[]) -> f32[] {
+ ROOT %a = f32[] parameter(0)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_0,
+ ParseHloString(text_0));
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_1,
+ ParseHloString(text_1));
+ HloModuleGroup group(TestName());
+ group.push_back(std::move(module_0));
+ group.push_back(std::move(module_1));
+
+ EXPECT_EQ(group.modules().size(), 2);
+ EXPECT_THAT(
+ group.module(0).entry_computation()->instructions(),
+ ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add()));
+ EXPECT_THAT(group.module(1).entry_computation()->instructions(),
+ ::testing::ElementsAre(op::Parameter()));
+}
+
+} // namespace
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc
index 400bd4d947..39f38b417a 100644
--- a/tensorflow/compiler/xla/service/hlo_module_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_test.cc
@@ -20,12 +20,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/test.h"
@@ -253,6 +253,99 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
op::Broadcast(), op::Multiply(), op::Add()));
}
+TEST_F(HloModuleTest, ProtoSerializationPreservesIds) {
+ // Verify that serializing and then deserializing an HLO proto preserves the
+ // unique IDs of the instructions and computations.
+ const string text =
+ R"(HloModule ReduceR3ToR2_module
+
+add_F32.v3 {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY ReduceR3ToR2.v3 {
+ input = f32[8,16,256]{2,1,0} parameter(0)
+ constant = f32[] constant(0)
+ ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+
+ // Perform various transformations on the graph:
+ //
+ // * clone the reduction function
+ // * replace use of reduction function with the clone.
+ // * add a random instruction to the entry computation.
+ //
+ // This will create instruction and computation IDs which are interesting:
+ // not consecutive and not densely packed.
+ HloComputation* entry = module->entry_computation();
+ HloInstruction* root = entry->root_instruction();
+ HloComputation* reduction = root->to_apply();
+ HloComputation* reduction_clone =
+ module->AddEmbeddedComputation(reduction->Clone());
+ root->set_to_apply(reduction_clone);
+ TF_ASSERT_OK(module->RemoveEmbeddedComputation(reduction));
+ HloInstruction* negate = entry->AddInstruction(
+ HloInstruction::CreateUnary(root->shape(), HloOpcode::kNegate, root));
+ entry->set_root_instruction(negate);
+
+ // Schedule the transformed module; this verifies that the serialized schedule
+ // is robust against non-consecutive IDs as well (b/114712358).
+ auto size_fn = [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ };
+ HloMemoryScheduler scheduler(size_fn);
+ TF_ASSERT_OK(scheduler.Run(module.get()).status());
+ ASSERT_TRUE(module->has_schedule());
+
+ // Serialize and deserialize, and verify that the instruction and computation
+ // unique ids are the same.
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module_copy,
+ HloModule::CreateFromProto(module->ToProto(), module->config()));
+
+ // The module IDs should *not* be the same because module ids must be globally
+ // unique.
+ EXPECT_NE(module->unique_id(), module_copy->unique_id());
+
+ // Verify that the computations and instructions all have the same unique id.
+ auto computation_copy_it = module_copy->computations().begin();
+ for (const HloComputation* computation_orig : module->computations()) {
+ const HloComputation* computation_copy = *computation_copy_it++;
+ EXPECT_EQ(computation_orig->unique_id(), computation_copy->unique_id())
+ << absl::StrFormat(
+ "ID of original computation %s != ID of deserialized "
+ "computation %s: %d != %d",
+ computation_orig->name(), computation_copy->name(),
+ computation_orig->unique_id(), computation_copy->unique_id());
+
+ auto instruction_copy_it = computation_copy->instructions().begin();
+ for (const HloInstruction* instruction_orig :
+ computation_orig->instructions()) {
+ const HloInstruction* instruction_copy = *instruction_copy_it++;
+ EXPECT_EQ(instruction_orig->unique_id(), instruction_copy->unique_id())
+ << absl::StrFormat(
+ "ID of original instruction %s != ID of deserialized "
+ "instruction %s: %d != %d",
+ instruction_orig->name(), instruction_copy->name(),
+ instruction_orig->unique_id(), instruction_copy->unique_id());
+ }
+ }
+
+ // Verify that the next unique ID which the module would have handed out is
+ // greater than the unique id of any instruction.
+ int next_id = module_copy->NewUniqueInstructionId();
+ for (const HloComputation* computation : module_copy->computations()) {
+ for (const HloInstruction* instruction : computation->instructions()) {
+ EXPECT_GT(next_id, instruction->unique_id());
+ }
+ }
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
index 6b6005e7a5..00970bcda3 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
@@ -24,7 +24,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index c54360b063..11caa89c54 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -105,16 +105,13 @@ class HloParser {
string* root_name);
bool ParseInstruction(HloComputation::Builder* builder, string* root_name);
bool ParseControlPredecessors(HloInstruction* instruction);
- bool ParseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
- bool ParseTupleLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
- bool ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape);
- bool ParseDenseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
- bool ParseSparseLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape);
+ bool ParseLiteral(Literal* literal, const Shape& shape);
+ bool ParseTupleLiteral(Literal* literal, const Shape& shape);
+ bool ParseNonTupleLiteral(Literal* literal, const Shape& shape);
+ bool ParseDenseLiteral(Literal* literal, const Shape& shape);
+ bool ParseSparseLiteral(Literal* literal, const Shape& shape);
template <typename LiteralNativeT>
- bool ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
- const Shape& shape);
+ bool ParseSparseLiteralHelper(Literal* literal, const Shape& shape);
// Sets the sub-value of literal at the given index to the given value. The
// literal's shape must have the default layout.
@@ -577,7 +574,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kConstant: {
- std::unique_ptr<Literal> literal;
+ Literal literal;
if (!ParseToken(TokKind::kLparen,
"expects '(' before constant literal") ||
!ParseLiteral(&literal, shape) ||
@@ -1810,8 +1807,7 @@ bool HloParser::EatShapeAndCheckCompatible(const Shape& shape) {
// literal
// ::= tuple
// ::= non_tuple
-bool HloParser::ParseLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseLiteral(Literal* literal, const Shape& shape) {
return ShapeUtil::IsTuple(shape) ? ParseTupleLiteral(literal, shape)
: ParseNonTupleLiteral(literal, shape);
}
@@ -1821,8 +1817,7 @@ bool HloParser::ParseLiteral(std::unique_ptr<Literal>* literal,
// literal_list
// ::= /*empty*/
// ::= literal (',' literal)*
-bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseTupleLiteral(Literal* literal, const Shape& shape) {
if (!EatShapeAndCheckCompatible(shape)) {
return TokenError(StrCat("expects tuple constant in shape ",
ShapeUtil::HumanString(shape)));
@@ -1830,8 +1825,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal,
if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) {
return false;
}
- std::vector<std::unique_ptr<Literal>> elements(
- ShapeUtil::TupleElementCount(shape));
+ std::vector<Literal> elements(ShapeUtil::TupleElementCount(shape));
if (lexer_.GetKind() == TokKind::kRparen) {
// empty
@@ -1857,8 +1851,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal,
// ::= rank01
// ::= rank2345
// rank2345 ::= shape sparse_or_nested_array
-bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseNonTupleLiteral(Literal* literal, const Shape& shape) {
if (LayoutUtil::IsSparseArray(shape)) {
return ParseSparseLiteral(literal, shape);
}
@@ -1867,8 +1860,7 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
return ParseDenseLiteral(literal, shape);
}
-bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) {
const tensorflow::int64 rank = ShapeUtil::Rank(shape);
if (rank > 1 && !EatShapeAndCheckCompatible(shape)) {
return false;
@@ -1962,7 +1954,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
// TODO(congliu): bool type literals with rank >= 1 are actually
// printed in a compact form instead of "true" or "false". Fix that.
if (!SetValueInLiteral(lexer_.GetKind() == TokKind::kw_true,
- linear_index++, literal->get())) {
+ linear_index++, literal)) {
return false;
}
lexer_.Lex();
@@ -1973,7 +1965,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
return Error(loc, StrCat("expects integer for primitive type: ",
PrimitiveType_Name(shape.element_type())));
}
- if (!SetValueInLiteral(value, linear_index++, literal->get())) {
+ if (!SetValueInLiteral(value, linear_index++, literal)) {
return false;
}
} else if (primitive_util::IsFloatingPointType(shape.element_type())) {
@@ -1984,7 +1976,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
loc, StrCat("expect floating point value for primitive type: ",
PrimitiveType_Name(shape.element_type())));
}
- if (!SetValueInLiteral(value, linear_index++, literal->get())) {
+ if (!SetValueInLiteral(value, linear_index++, literal)) {
return false;
}
} else {
@@ -1996,12 +1988,11 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
} // end of switch
} while (nest_level > 0);
- *literal = (*literal)->Relayout(shape.layout());
+ *literal = literal->Relayout(shape.layout());
return true;
}
-bool HloParser::ParseSparseLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseSparseLiteral(Literal* literal, const Shape& shape) {
if (!EatShapeAndCheckCompatible(shape)) {
return false;
}
@@ -2041,13 +2032,12 @@ bool HloParser::ParseSparseLiteral(std::unique_ptr<Literal>* literal,
}
template <typename LiteralNativeT>
-bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseSparseLiteralHelper(Literal* literal, const Shape& shape) {
std::vector<tensorflow::int64> index;
tensorflow::int64 rank = ShapeUtil::Rank(shape);
- *literal = absl::make_unique<Literal>(shape);
+ *literal = Literal(shape);
if (!ParseToken(TokKind::kLbrace,
"expects '{' at the beginning of a sparse literal")) {
@@ -2121,7 +2111,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
return false;
}
- if ((*literal)->sparse_element_count() + 1 ==
+ if (literal->sparse_element_count() + 1 ==
LayoutUtil::MaxSparseElements(shape.layout())) {
return Error(
lexer_.GetLoc(),
@@ -2129,10 +2119,10 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
ShapeUtil::HumanStringWithLayout(shape)));
}
- (*literal)->AppendSparseElement(index, value);
+ literal->AppendSparseElement(index, value);
}
- (*literal)->SortSparseElements();
+ literal->SortSparseElements();
return true;
}
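
The hlo_parser.cc hunks above all apply the same migration: literal-parsing helpers that used to hand back ownership through a std::unique_ptr<Literal>* out-parameter now fill a caller-owned Literal*. The following minimal sketch illustrates that before/after pattern; the Parse* names and the toy Literal type are illustrative stand-ins, not the real HloParser entry points.

#include <memory>

// Toy stand-in for xla::Literal so the sketch is self-contained.
struct Literal {
  Literal() = default;
  explicit Literal(int num_elements) : num_elements(num_elements) {}
  int num_elements = 0;
};

// Old style: the parser heap-allocates the literal and transfers ownership
// through a pointer-to-unique_ptr out-parameter.
bool ParseLiteralOld(std::unique_ptr<Literal>* literal) {
  *literal = std::make_unique<Literal>(/*num_elements=*/4);
  return true;
}

// New style (as in the diff): the caller owns a Literal value and the parser
// assigns into it, removing the extra heap allocation and indirection.
bool ParseLiteralNew(Literal* literal) {
  *literal = Literal(/*num_elements=*/4);
  return true;
}

int main() {
  std::unique_ptr<Literal> old_result;
  ParseLiteralOld(&old_result);

  Literal new_result;
  ParseLiteralNew(&new_result);
  return old_result->num_elements == new_result.num_elements ? 0 : 1;
}

The same value-semantics change shows up in the HloRunner, interpreter, and indexed-array-analysis hunks below.
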
diff --git a/tensorflow/compiler/xla/service/hlo_reachability_test.cc b/tensorflow/compiler/xla/service/hlo_reachability_test.cc
index 585c95972b..d9848cee0b 100644
--- a/tensorflow/compiler/xla/service/hlo_reachability_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_reachability_test.cc
@@ -20,13 +20,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
namespace xla {
namespace {
-class HloReachabilityTest : public HloTestBase {};
+class HloReachabilityTest : public HloVerifiedTestBase {};
TEST_F(HloReachabilityTest, Reachability) {
// Construct and test a reachability graph of the following form:
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index 0a0a6a323e..bd6dd79b67 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -27,15 +27,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
-#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -1194,51 +1193,12 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
return changed;
}
-StatusOr<bool> HloRematerialization::Run(HloModule* module,
- HloSchedule* schedule,
- int64 memory_limit_bytes,
- RematerializationSizes* sizes,
- CopyInsertion* copy_insertion) {
- // The schedule is constructed entirely by this method.
- TF_RET_CHECK(schedule->empty());
-
+StatusOr<bool> HloRematerialization::Run(HloModule* module) {
VLOG(1) << "HloRematerialization() with memory limit of "
- << HumanReadableNumBytes(memory_limit_bytes);
+ << HumanReadableNumBytes(memory_limit_bytes_);
XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString());
- // Create initial schedule of HLO instructions.
- TF_ASSIGN_OR_RETURN(*schedule,
- ScheduleModule(*module,
- [this](const BufferValue& buffer) {
- return size_function_(buffer.shape());
- },
- scheduler_algorithm_));
- if (copy_insertion) {
- // We run a separate pass of copy elision here because the sequential
- // ordering from the HLO schedule allows for more copies to be eliminated.
- // TODO(b/80249101): Instead of a separate copy elision pass, use the
- // ordering from the HLO schedule directly for copy insertion.
- SequentialHloOrdering ordering(*schedule);
- TF_RETURN_IF_ERROR(
- copy_insertion->RemoveUnnecessaryCopies(ordering, module));
-
- // RemoveUnnecessaryCopies only considers interference when determining
- // whether it is legal to remove a copy. However, copies in the graph may be
- // necessary for other reason such as preventing a constant from being live
- // out of the graph. So run AddSpecialCaseCopies to re-insert these copies.
- // TODO(b/80249101): Break copy insertion into several passes and run each
- // one once in the regular HLO pipeline.
- TF_RETURN_IF_ERROR(copy_insertion->AddSpecialCaseCopies(module));
-
- // The passes above can add and remove copies, update the schedule to
- // account for these transformations. Newly added instructions will be
- // placed ASAP in the schedule.
- TF_RETURN_IF_ERROR(schedule->Update());
-
- TF_DCHECK_OK(copy_insertion->VerifyNoLiveRangeInterference(
- SequentialHloOrdering(*schedule), module));
- }
-
+ TF_RET_CHECK(module->has_schedule());
TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));
// Adjust memory limit to account for the output of the entry
@@ -1254,7 +1214,7 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module,
});
const int64 adjusted_memory_limit_bytes =
- memory_limit_bytes - module_output_size;
+ memory_limit_bytes_ - module_output_size;
VLOG(1) << "Adjusted memory limit accounting for output ("
<< HumanReadableNumBytes(module_output_size)
<< "): " << HumanReadableNumBytes(adjusted_memory_limit_bytes);
@@ -1263,13 +1223,14 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module,
// sequential context.
call_graph_ = CallGraph::Build(module);
TF_RETURN_IF_ERROR(call_graph_->VisitNodes(
- [this, schedule](const CallGraphNode& node) -> Status {
+ [this, module](const CallGraphNode& node) -> Status {
if (node.context() == CallContext::kSequential) {
TF_ASSIGN_OR_RETURN(
computation_peak_memory_[node.computation()],
- ComputePeakMemory(
- node.computation(),
- schedule->sequence(node.computation()).instructions()));
+ ComputePeakMemory(node.computation(),
+ module->schedule()
+ .sequence(node.computation())
+ .instructions()));
}
return Status::OK();
},
@@ -1287,9 +1248,10 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module,
// Subcomputations called by the entry computation will also be
// rematerialized.
- TF_ASSIGN_OR_RETURN(bool changed, RematerializeComputation(
- module->entry_computation(), schedule,
- adjusted_memory_limit_bytes));
+ TF_ASSIGN_OR_RETURN(
+ bool changed,
+ RematerializeComputation(module->entry_computation(), &module->schedule(),
+ adjusted_memory_limit_bytes));
// Rematerialization can introduce dead code. This occurs if all uses of an
// instruction are replaced with rematerializations of the instruction.
@@ -1298,7 +1260,7 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module,
// After DCE, the module sequence may include instructions which no longer
// exist.
- TF_RETURN_IF_ERROR(schedule->Update());
+ TF_RETURN_IF_ERROR(module->schedule().Update());
VLOG(1) << "Rematerialized " << instructions_rematerialized_
<< " instructions in module " << module->name() << "; "
<< net_instructions_added_ << " net instructions added";
@@ -1315,32 +1277,22 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module,
<< HumanReadableNumBytes(reduced_peak_memory) << " ("
<< reduced_peak_memory << " bytes)";
- if (sizes != nullptr) {
- sizes->before_bytes = before_peak_memory;
- sizes->after_bytes = current_peak_memory;
+ if (sizes_ != nullptr) {
+ sizes_->before_bytes = before_peak_memory;
+ sizes_->after_bytes = current_peak_memory;
}
XLA_VLOG_LINES(3, "After HloRematerialization:\n" + module->ToString());
- if (current_peak_memory > memory_limit_bytes) {
+ if (current_peak_memory > memory_limit_bytes_) {
LOG(WARNING) << absl::StrFormat(
"Can't reduce memory use below %s (%d bytes) by rematerialization; "
"only reduced to %s (%d bytes)",
- HumanReadableNumBytes(memory_limit_bytes), memory_limit_bytes,
+ HumanReadableNumBytes(memory_limit_bytes_), memory_limit_bytes_,
HumanReadableNumBytes(current_peak_memory), current_peak_memory);
}
return changed;
}
-/* static */ StatusOr<bool> HloRematerialization::RematerializeAndSchedule(
- const HloRematerialization::ShapeSizeFunction& size_function,
- int64 memory_limit_bytes, HloModule* hlo_module,
- MemorySchedulerAlgorithm scheduler_algorithm, HloSchedule* schedule,
- RematerializationSizes* sizes, CopyInsertion* copy_insertion) {
- HloRematerialization remat(scheduler_algorithm, size_function);
- return remat.Run(hlo_module, schedule, memory_limit_bytes, sizes,
- copy_insertion);
-}
-
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h
index fa0414b472..e2aaf18b3e 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.h
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h
@@ -17,17 +17,23 @@
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
-#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
namespace xla {
-class HloRematerialization {
+// HLO pass which rematerializes instructions to reduce peak memory use, where
+// memory use is defined as the total size of all live HLO instruction
+// values. Parameters and constants are included in memory use estimates.
+//
+// CSE will undo the effects of this optimization and should not be run after
+// this pass. In general, this pass should be run very late, immediately before
+// code generation.
+class HloRematerialization : public HloPassInterface {
public:
using ShapeSizeFunction = std::function<int64(const Shape&)>;
@@ -38,10 +44,7 @@ class HloRematerialization {
int64 after_bytes;
};
- // Rematerialize HLO instructions in the given module to reduce peak memory
- // use below memory_limit_bytes where memory use is defined as the total size
- // of all live HLO instruction values. Parameters and constants are included
- // in memory use estimates. Method parameters:
+ // Constructor parameters:
//
// size_function: Function which returns the size in bytes of the top-level
// buffer of the given shape.
@@ -49,51 +52,27 @@ class HloRematerialization {
// memory_limit_bytes: The threshold number of bytes to reduce memory use to
// via rematerialization.
//
- // hlo_module: HLO module to rematerialize instructions in.
- //
- // schedule: Should point to an empty HloSchedule. Upon return
- // contains the HLO instruction order which was used for
- // rematerialization. This is the order in which HLO instructions should
- // be emitted to minimize memory use.
- //
- // sizes: Optional outparam that indicates the peak memory usage of the HLO
- // module before/after rematerialization.
- //
- // copy_insertion: If non-null, run copy elision after scheduling. This
- // pass is used to eliminate copies that were inserted by copy insertion
- // before HLO scheduling.
- //
- // TODO(b/80249101): Remove the 'run_copy_elision' parameter when copy
- // insertion is integrated with HLO scheduling.
- //
- // Returns whether any instructions were rematerialized. If memory use is
- // already below the given limit then no instructions are rematerialized and
- // false is returned.
- //
- // CSE will undo the effects of this optimization and should not be run after
- // this pass. In general, this pass should be run very late immediately before
- // code generation.
- static StatusOr<bool> RematerializeAndSchedule(
- const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
- HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm,
- HloSchedule* schedule, RematerializationSizes* sizes,
- CopyInsertion* copy_insertion = nullptr);
-
- protected:
- HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm,
- const ShapeSizeFunction& size_function)
- : scheduler_algorithm_(scheduler_algorithm),
- size_function_(size_function) {}
+ // sizes: Pointer to a data structure which records the peak memory usage of
+ // the HLO module before/after rematerialization. Values are set during
+ // Run(). Can be nullptr.
+ HloRematerialization(const ShapeSizeFunction& size_function,
+ int64 memory_limit_bytes, RematerializationSizes* sizes)
+ : size_function_(size_function),
+ memory_limit_bytes_(memory_limit_bytes),
+ sizes_(sizes) {}
~HloRematerialization() {}
+ absl::string_view name() const override { return "rematerialization"; }
+
// Runs rematerialization on the given module. Returns whether the module was
- // changed. memory_limit is the target maximum peak memory usage by the
- // module. schedule should be an empty HloSchedule. Upon return sequence
- // contains the memory-minimizing order in which to emit the HLO instructions.
- StatusOr<bool> Run(HloModule* module, HloSchedule* schedule,
- int64 memory_limit, RematerializationSizes* sizes,
- CopyInsertion* copy_insertion);
+ // changed. Requires that the module has a schedule set
+ // (HloModule::has_schedule() is true) before running. If memory use is
+ // already below the limit specified in the constructor, no instructions are
+ // rematerialized and false is returned.
+ StatusOr<bool> Run(HloModule* module) override;
+ protected:
// Rematerializes instructions within the given computation. 'order' is the
// order in which the computation's instructions will be emitted in the
// backend. Rematerialized instructions will be added to the HLO computation
@@ -121,6 +100,14 @@ class HloRematerialization {
// Function which computes the size of the top-level buffer of a shape.
const ShapeSizeFunction size_function_;
+ // The threshold number of bytes to reduce memory use to via
+ // rematerialization.
+ const int64 memory_limit_bytes_;
+
+ // Pointer to a data structure which records the peak memory usage of the HLO
+ // module before/after rematerialization.
+ RematerializationSizes* sizes_;
+
// Call graph of the hlo_module.
std::unique_ptr<CallGraph> call_graph_;
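
With the interface above, rematerialization becomes an ordinary HLO pass that requires the module to already carry a schedule (the .cc hunk enforces this with TF_RET_CHECK(module->has_schedule())). A hedged usage sketch follows; the constructor and Run signatures come from the hunks above, while the helper name, the 1 GiB limit, the pointer size, and the exact includes are assumptions.

#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_rematerialization.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace xla {

// Hypothetical helper: schedule the module, then rematerialize against a
// placeholder memory limit.
Status ScheduleAndRematerialize(HloModule* module) {
  // Give the module a schedule first; HloRematerialization::Run expects one.
  HloMemoryScheduler scheduler(
      [](const BufferValue& buffer) {
        return ShapeUtil::ByteSizeOf(buffer.shape(), sizeof(void*));
      },
      DefaultMemoryScheduler);
  TF_RETURN_IF_ERROR(scheduler.Run(module).status());

  // Then run rematerialization as a regular pass.
  HloRematerialization::RematerializationSizes sizes;
  HloRematerialization remat(
      [](const Shape& shape) {
        return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
      },
      /*memory_limit_bytes=*/1LL << 30, &sizes);
  TF_RETURN_IF_ERROR(remat.Run(module).status());
  return Status::OK();
}

}  // namespace xla

This mirrors what the updated RunHloRematerialization test helper below does: a scheduling pass first, then the rematerialization pass.
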
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
index 83cb113bfb..f7e82fb1f8 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
@@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -36,7 +36,7 @@ namespace op = xla::testing::opcode_matchers;
using ::testing::_;
-class HloRematerializationTest : public HloTestBase {
+class HloRematerializationTest : public HloVerifiedTestBase {
protected:
// Creates and returns a computation which can benefit from
// rematerialization. The computation looks like:
@@ -142,12 +142,15 @@ class HloRematerializationTest : public HloTestBase {
}
StatusOr<bool> RunHloRematerialization(int64 memory_limit_bytes,
- HloModule* module,
- HloSchedule* schedule) {
+ HloModule* module) {
TF_EXPECT_OK(verifier().Run(module).status());
- return HloRematerialization::RematerializeAndSchedule(
- ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler,
- schedule, /*sizes=*/nullptr);
+ HloMemoryScheduler scheduler(
+ [](const BufferValue& buffer) { return ByteSizeOf(buffer.shape()); },
+ DefaultMemoryScheduler);
+ TF_EXPECT_OK(scheduler.Run(module).status());
+ HloRematerialization remat(ByteSizeOf, memory_limit_bytes,
+ /*sizes=*/nullptr);
+ return remat.Run(module);
}
// Various shapes used in the canned computations.
@@ -170,12 +173,11 @@ TEST_F(HloRematerializationTest, SingleComputation) {
const HloInstruction* concat = slice->operand(0);
const HloInstruction* bcast = concat->operand(0);
- HloSchedule schedule(module.get());
// Computation requires 16KB without rematerialization, but uses only 12KB
// with rematerialization so pick a memory limit between these values (14KB).
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/14 * 1024,
- module.get(), &schedule));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/14 * 1024, module));
EXPECT_TRUE(changed);
// Root should not have changed.
@@ -187,10 +189,12 @@ TEST_F(HloRematerializationTest, SingleComputation) {
// The rematerialized broadcast should be immediately before the concat in the
// sequence.
- EXPECT_EQ(schedule.sequence(computation)
+ EXPECT_EQ(module->schedule()
+ .sequence(computation)
.instructions()[computation->instruction_count() - 2],
concat);
- EXPECT_EQ(schedule.sequence(computation)
+ EXPECT_EQ(module->schedule()
+ .sequence(computation)
.instructions()[computation->instruction_count() - 3],
remat_bcast);
}
@@ -205,10 +209,9 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) {
EXPECT_EQ(computation->instruction_count(), 8);
- HloSchedule schedule(module.get());
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/20 * 1024,
- module.get(), &schedule));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/20 * 1024, module));
// No instructions should have been materialized.
EXPECT_FALSE(changed);
@@ -244,10 +247,9 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) {
// The body computation uses 16KB and the entry computation uses 2KB at the
// while so the peak memory use of the module is 18KB. Set the memory limit a
// bit lower (17KB) to force rematerialization of the entry computation.
- HloSchedule schedule(module.get());
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/17 * 1024,
- module.get(), &schedule));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/17 * 1024, module));
EXPECT_TRUE(changed);
// Only the entry computation should have a rematerialized instruction added.
@@ -278,10 +280,9 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) {
EXPECT_EQ(entry_computation->instruction_count(), 7);
EXPECT_EQ(body_computation->instruction_count(), 8);
- HloSchedule schedule(module.get());
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/15 * 1024,
- module.get(), &schedule));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/15 * 1024, module));
EXPECT_TRUE(changed);
// Both computations should have rematerialized instructions added.
@@ -318,10 +319,9 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) {
// If all computations are maximally rematerialized then peak memory usage is
// ~12K so pick something slightly larger.
- HloSchedule schedule(module.get());
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/13 * 1024,
- module.get(), &schedule));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/13 * 1024, module));
EXPECT_TRUE(changed);
// All computations should have rematerialized instructions added.
@@ -384,14 +384,13 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) {
ASSERT_EQ(count_rngs(entry_computation), 1);
const int64 original_instruction_count =
entry_computation->instruction_count();
- HloSchedule schedule(module.get());
// Pick a memory limit somewhere between 24KB (initial peak memory including
// parameter and output) and 20KB (peak memory possible with
// rematerialization).
TF_ASSERT_OK_AND_ASSIGN(
- bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_),
- module.get(), &schedule));
+ bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), module));
EXPECT_TRUE(changed);
// The rng should not have been rematerialized.
EXPECT_EQ(count_rngs(entry_computation), 1);
@@ -478,13 +477,12 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) {
EXPECT_EQ(add_3->operand(0), bcast);
EXPECT_EQ(add_4->operand(0), bcast);
- HloSchedule schedule(module.get());
// Pick a memory limit somewhere between 24KB (initial peak memory including
// parameter and output) and 20KB (peak memory possible with
// rematerialization).
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/22 * 1024,
- module.get(), &schedule));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/22 * 1024, module));
EXPECT_TRUE(changed);
// The broadcast should have been rematerialized 3 times.
@@ -573,13 +571,12 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
EXPECT_EQ(entry_computation->instruction_count(), 8);
- HloSchedule schedule(module.get());
// Pick a memory limit somewhere between 24KB (initial peak memory including
// parameter and output) and 20KB (peak memory possible with
// rematerialization).
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/22 * 1024,
- module.get(), &schedule));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/22 * 1024, module));
// Rematerialization should only occur if the rematerializable instruction has
// no indirect uses.
if (indirectly_used) {
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc
index 66ac1f66fd..fa7f216321 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.cc
+++ b/tensorflow/compiler/xla/service/hlo_runner.cc
@@ -118,16 +118,16 @@ StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice(
}
StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice(
- const absl::Span<const std::unique_ptr<Literal>> literals) {
+ const absl::Span<const Literal> literals) {
std::vector<const Literal*> literal_pointers;
literal_pointers.reserve(literals.size());
for (const auto& literal : literals) {
- literal_pointers.push_back(literal.get());
+ literal_pointers.push_back(&literal);
}
return TransferLiteralsToDevice(literal_pointers);
}
-StatusOr<std::unique_ptr<Literal>> HloRunner::TransferLiteralFromDevice(
+StatusOr<Literal> HloRunner::TransferLiteralFromDevice(
const ShapedBuffer& buffer) {
TF_ASSIGN_OR_RETURN(
auto stream, backend().BorrowStream(backend().default_stream_executor()));
@@ -135,7 +135,7 @@ StatusOr<std::unique_ptr<Literal>> HloRunner::TransferLiteralFromDevice(
buffer);
}
-StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
+StatusOr<Literal> HloRunner::Execute(
std::unique_ptr<HloModule> module,
const absl::Span<const Literal* const> arguments, bool run_hlo_passes,
ExecutionProfile* profile) {
@@ -150,15 +150,15 @@ StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
return TransferLiteralFromDevice(result);
}
-StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
- std::unique_ptr<HloModule> module,
- const absl::Span<const std::unique_ptr<Literal>> arguments,
- bool run_hlo_passes, ExecutionProfile* profile) {
+StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
+ const absl::Span<const Literal> arguments,
+ bool run_hlo_passes,
+ ExecutionProfile* profile) {
// Construct a vector of plain pointers for the arguments.
std::vector<const Literal*> argument_pointers;
argument_pointers.reserve(arguments.size());
for (const auto& argument : arguments) {
- argument_pointers.push_back(argument.get());
+ argument_pointers.push_back(&argument);
}
return Execute(
/*module=*/std::move(module),
@@ -204,7 +204,7 @@ StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
/*profile=*/profile);
}
-StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
+StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
std::unique_ptr<HloModule> module,
const ReplicatedExecuteOptions& options) {
TF_ASSIGN_OR_RETURN(
@@ -290,9 +290,9 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
VLOG(1) << "Starting outfeed on device " << device;
for (int64 step = 1;
options.infeed_steps < 0 || step <= options.infeed_steps; ++step) {
- auto literal = absl::make_unique<Literal>();
+ Literal literal;
TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromOutfeed(
- executor, options.outfeed_shape, literal.get()));
+ executor, options.outfeed_shape, &literal));
if (options.outfeed_values != nullptr) {
options.outfeed_values->push_back(std::move(literal));
}
@@ -310,10 +310,10 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
argument_buffer_slices));
LOG(INFO) << "Replicated execution terminated";
- std::vector<std::unique_ptr<Literal>> exec_results;
+ std::vector<Literal> exec_results;
for (int64 i = 0; i < options.num_replicas; ++i) {
TF_RETURN_IF_ERROR(streams[i]->BlockHostUntilDone());
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
+ TF_ASSIGN_OR_RETURN(Literal literal,
backend().transfer_manager()->TransferLiteralFromDevice(
streams[i].get(), results[i]));
exec_results.push_back(std::move(literal));
diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h
index 76d8b92bed..2e934bf66a 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.h
+++ b/tensorflow/compiler/xla/service/hlo_runner.h
@@ -72,7 +72,7 @@ class HloRunner {
// A pointer to a vector where the outfeed values will be stored. If
// nullptr, the values will be read and discarded.
- std::vector<std::unique_ptr<Literal>>* outfeed_values = nullptr;
+ std::vector<Literal>* outfeed_values = nullptr;
// Whether the HLO passes should be run on the input module. Usually
// saved modules are coming from after the HLO pass pipeline, so triggering
@@ -106,24 +106,23 @@ class HloRunner {
StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice(
const absl::Span<const Literal* const> literals);
StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice(
- const absl::Span<const std::unique_ptr<Literal>> literals);
- StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDevice(
- const ShapedBuffer& buffer);
+ const absl::Span<const Literal> literals);
+ StatusOr<Literal> TransferLiteralFromDevice(const ShapedBuffer& buffer);
// Executes the given module with given literals as input and returns the
// result as a Literal.
//
// If run_hlo_passes is false, the module will be executed without Hlo
// optimization.
- StatusOr<std::unique_ptr<Literal>> Execute(
- std::unique_ptr<HloModule> module,
- const absl::Span<const Literal* const> arguments,
- bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
+ StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
+ const absl::Span<const Literal* const> arguments,
+ bool run_hlo_passes = true,
+ ExecutionProfile* profile = nullptr);
- StatusOr<std::unique_ptr<Literal>> Execute(
- std::unique_ptr<HloModule> module,
- const absl::Span<const std::unique_ptr<Literal>> arguments,
- bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
+ StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
+ const absl::Span<const Literal> arguments,
+ bool run_hlo_passes = true,
+ ExecutionProfile* profile = nullptr);
// As Execute(), but accepts and returns device buffers instead of host
// buffers.
@@ -140,7 +139,7 @@ class HloRunner {
// Executes the given HLO module on a set of replicas and returns the
// resulting literals as a vector indexed by replica number.
- StatusOr<std::vector<std::unique_ptr<Literal>>> ExecuteReplicated(
+ StatusOr<std::vector<Literal>> ExecuteReplicated(
std::unique_ptr<HloModule> module,
const ReplicatedExecuteOptions& options);
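
After this change callers hand HloRunner plain Literal values and get Literal values back; there is no unique_ptr wrapping on either side. A hedged calling sketch, assuming the usual platform setup; only the Execute signature and the Literal-by-value argument span are taken from the hunks above.

#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace xla {

// Hypothetical helper: run a module on the default platform with one argument.
StatusOr<Literal> RunWithOneArg(std::unique_ptr<HloModule> module) {
  TF_ASSIGN_OR_RETURN(se::Platform* platform,
                      PlatformUtil::GetDefaultPlatform());
  HloRunner runner(platform);

  // Arguments are now owned Literal values rather than unique_ptrs.
  std::vector<Literal> arguments;
  arguments.push_back(LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f}));

  // Execute takes absl::Span<const Literal> and returns StatusOr<Literal>.
  return runner.Execute(std::move(module), arguments);
}

}  // namespace xla
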
diff --git a/tensorflow/compiler/xla/service/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/hlo_schedule_test.cc
index eb52582bb5..1424569ac1 100644
--- a/tensorflow/compiler/xla/service/hlo_schedule_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_schedule_test.cc
@@ -22,10 +22,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
index 1e2b31a1f2..6fd734a2b9 100644
--- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
@@ -24,7 +24,7 @@ namespace {
using ::tensorflow::GraphDef;
-class HloTfGraphBuilderTest : public HloTestBase {
+class HloTfGraphBuilderTest : public HloVerifiedTestBase {
protected:
HloTfGraphBuilderTest() {}
HloTfGraphBuilder generator_;
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 069586a738..50f39cbcb5 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -1123,6 +1123,11 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module));
+ // If the module has a schedule, it must be valid.
+ if (module->has_schedule()) {
+ TF_RETURN_IF_ERROR(module->schedule().Verify());
+ }
+
return false;
}
diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
index 0cac210c24..8f0423bb1c 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
@@ -290,8 +290,8 @@ TEST_F(HloVerifierTest, NegativeInteriorPaddingNotAllowed) {
padding_config.add_dimensions()->set_interior_padding(-1);
builder.AddInstruction(HloInstruction::CreatePad(
ShapeUtil::MakeShape(F32, {100}), param,
- builder.AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::Zero(F32).CloneToUnique())),
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(F32))),
padding_config));
auto module = CreateNewModule();
@@ -314,8 +314,8 @@ TEST_F(HloVerifierTest, PadNegativeInteriorDilationNotAllowed) {
padding_config.add_dimensions()->set_interior_padding(-1);
builder.AddInstruction(HloInstruction::CreatePad(
ShapeUtil::MakeShape(F32, {100}), param,
- builder.AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::Zero(F32).CloneToUnique())),
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(F32).Clone())),
padding_config));
auto module = CreateNewModule();
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
index 37b774b8a5..06f0e1ed25 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
@@ -918,7 +918,7 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
// inner_broadcast_result is the Broadcast'(Const0) bit in
// BinaryOp(Broadcast'(Const0), Const1)
TF_ASSIGN_OR_RETURN(
- std::unique_ptr<Literal> inner_broadcast_result,
+ Literal inner_broadcast_result,
broadcast_const_operand->literal().Broadcast(
scalar_indexed_const->source()->shape(), new_inner_broadcast_dims));
@@ -928,12 +928,12 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
TF_ASSIGN_OR_RETURN(
literal_for_new_source,
TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
- opcode, scalar_indexed_const->literal(), *inner_broadcast_result)));
+ opcode, scalar_indexed_const->literal(), inner_broadcast_result)));
} else {
TF_ASSIGN_OR_RETURN(
literal_for_new_source,
TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
- opcode, *inner_broadcast_result, scalar_indexed_const->literal())));
+ opcode, inner_broadcast_result, scalar_indexed_const->literal())));
}
ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h
index 9746d176cc..df9cbab915 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.h
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h
@@ -347,21 +347,19 @@ class IndexedArrayAnalysis {
}
}
- Literal* TakeOwnership(std::unique_ptr<Literal> literal) {
+ Literal* TakeOwnership(Literal literal) {
owned_literals_.push_back(std::move(literal));
- return owned_literals_.back().get();
+ return &owned_literals_.back();
}
- StatusOr<Literal*> TakeOwnership(
- StatusOr<std::unique_ptr<Literal>> literal_or_error) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
- std::move(literal_or_error));
+ StatusOr<Literal*> TakeOwnership(StatusOr<Literal> literal_or_error) {
+ TF_ASSIGN_OR_RETURN(Literal literal, std::move(literal_or_error));
owned_literals_.push_back(std::move(literal));
- return owned_literals_.back().get();
+ return &owned_literals_.back();
}
std::vector<std::unique_ptr<Array>> owned_tensors_;
- std::vector<std::unique_ptr<Literal>> owned_literals_;
+ std::vector<Literal> owned_literals_;
tensorflow::gtl::FlatMap<const HloInstruction*, Array*> cache_;
};
diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc
index 5695bc2420..7e967f035c 100644
--- a/tensorflow/compiler/xla/service/inliner_test.cc
+++ b/tensorflow/compiler/xla/service/inliner_test.cc
@@ -26,7 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -35,7 +35,7 @@ namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {
-using InlinerTest = HloTestBase;
+using InlinerTest = HloVerifiedTestBase;
// Test that `map` with `max` is transformed to `max`
TEST_F(InlinerTest, MapMax) {
@@ -64,14 +64,14 @@ TEST_F(InlinerTest, MapMax) {
hlo_module->AddEntryComputation(std::move(computation));
Inliner inliner;
- EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
+ EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie());
EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
op::Maximum(lhs, rhs));
// Verify execution on CPU.
- auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+ auto result = ExecuteAndTransfer(hlo_module->Clone(), {});
auto expected = LiteralUtil::CreateR1<float>({4, 3, 3, 4});
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
// Test that `constant` function is changed to `broadcast`.
@@ -98,14 +98,14 @@ TEST_F(InlinerTest, MapConstant) {
hlo_module->AddEntryComputation(std::move(computation));
HloInstruction* root = hlo_module->entry_computation()->root_instruction();
Inliner inliner;
- EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
+ EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie());
root = hlo_module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Broadcast(op::Constant()));
// Verify execution on CPU.
- auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+ auto result = ExecuteAndTransfer(hlo_module->Clone(), {});
auto expected = LiteralUtil::CreateR2<float>({{2, 2, 2, 2}, {2, 2, 2, 2}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
TEST_F(InlinerTest, MapSubtractOppositeOrder) {
@@ -136,14 +136,14 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) {
hlo_module->AddEntryComputation(std::move(computation));
Inliner inliner;
- EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
+ EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie());
EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
op::Subtract(rhs, lhs));
// Verify execution on CPU.
- auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+ auto result = ExecuteAndTransfer(hlo_module->Clone(), {});
auto expected = LiteralUtil::CreateR1<float>({3, 1, -1, -3});
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 8c907eae0c..3fdc2cee9a 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <vector>
#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -295,6 +296,138 @@ InstructionFusion::ComputeGloballyUnfusible(
return do_not_duplicate;
}
+namespace {
+
+// A FusionQueue that uses reverse post order.
+//
+// We want to be able to remove arbitrary instructions from the post order and
+// also compare positions of instructions in the post order. To make this
+// possible, create a vector of instructions in post order and create a map from
+// HloInstruction* to the instruction's index in the vector. An instruction is
+// "removed" from the vector by setting its element to nullptr.
+class ReversePostOrderFusionQueue : public FusionQueue {
+ public:
+ explicit ReversePostOrderFusionQueue(HloComputation* computation) {
+ post_order_ = computation->MakeInstructionPostOrder();
+
+ for (size_t i = 0; i < post_order_.size(); ++i) {
+ InsertOrDie(&post_order_index_, post_order_[i], i);
+ }
+ }
+
+ std::pair<HloInstruction*, std::vector<int64>>
+ DequeueNextInstructionAndOperandsToFuseInOrder() override {
+ // Instructions are "removed" from the post order by nulling out the element
+ // in the vector, so if the pointer is null, continue to the next
+ // instruction in the sort.
+ while (!post_order_.empty() && post_order_.back() == nullptr) {
+ post_order_.pop_back();
+ }
+ if (post_order_.empty()) {
+ return std::pair<HloInstruction*, std::vector<int64>>{nullptr, {}};
+ }
+ // We want to iterate in reverse post order, so remove from the back of the
+ // vector.
+ HloInstruction* instruction = post_order_.back();
+ post_order_.pop_back();
+
+ CHECK(instruction != nullptr);
+ // Remove instruction from the index map to ensure the vector and map stay
+ // consistent.
+ post_order_index_.erase(instruction);
+
+ // Consider each operand of this instruction for fusion into this
+ // instruction. We want to consider the operands in a particular order to
+ // avoid creating duplicate instruction clones in the fusion instruction.
+ // For example, consider the following expression:
+ //
+ // A = ...
+ // B = op(A)
+ // C = op(A, B)
+ //
+ // If we are considering the operands of C for fusion into C, we might
+ // fuse A or B first. If we fuse A first, we get:
+ //
+ // A = ...
+ // B = op(A)
+ // C_fusion = { A' = ...
+ // C' = op(A', B) }
+ //
+ // Where A' and C' are clones of A and C, respectively. Now only B is an
+ // operand of the fusion instruction C_fusion, so then we fuse B:
+ //
+ // A = ...
+ // B = op(A)
+ // C_fusion = { A' = ...
+ // B' = op(A)
+ // C' = op(A', B') }
+ //
+ // Now A is an operand of C_fusion again, so we then fuse A (again!):
+ //
+ // A = ...
+ // B = op(A)
+ // C_fusion = { A' = ...
+ // A" = ..
+ // B' = op(A")
+ // C' = op(A', B') }
+ //
+ // We prevent this duplication by considering the operands in the order
+ // they appear in the queue. In the example, this ensures that B will be
+ // considered before A.
+ //
+ // We store the original indices of the operands to pass to ShouldFuse.
+ std::vector<int64> sorted_operand_numbers;
+ sorted_operand_numbers.reserve(instruction->operands().size());
+ for (int i = 0; i < instruction->operands().size(); ++i) {
+ // This will happen if we have two possible instructions to fuse the
+ // same operand into; once the operand is fused into one instruction,
+ // the other instruction will get a new get-tuple-element as its
+ // operand, which is not in the queue.
+ // TODO(tjoerg): Look into fusing past these multi-output fuse points.
+ if (!ContainsKey(post_order_index_, instruction->mutable_operand(i))) {
+ continue;
+ }
+ sorted_operand_numbers.push_back(i);
+ }
+ std::sort(
+ sorted_operand_numbers.begin(), sorted_operand_numbers.end(),
+ [&](int64 i, int64 j) {
+ // Instructions with higher priority in the queue come first.
+ return (
+ FindOrDie(post_order_index_, instruction->mutable_operand(i)) >
+ FindOrDie(post_order_index_, instruction->mutable_operand(j)));
+ });
+ return std::make_pair(instruction, sorted_operand_numbers);
+ }
+
+ void OnFusingInstruction(HloInstruction* fusion,
+ HloInstruction* original_producer,
+ HloInstruction* original_consumer) override {
+ // Fusing an instruction into a fusion instruction can change the operand
+ // set of the fusion instruction. For simplicity just re-enqueue the
+ // instruction and reconsider it for further fusion in the next iteration.
+ InsertOrDie(&post_order_index_, fusion, post_order_.size());
+ post_order_.push_back(fusion);
+ }
+
+ void RemoveInstruction(HloInstruction* instruction) override {
+ post_order_[FindOrDie(post_order_index_, instruction)] = nullptr;
+ post_order_index_.erase(instruction);
+ }
+
+ private:
+ std::vector<HloInstruction*> post_order_;
+ tensorflow::gtl::FlatMap<HloInstruction*, int> post_order_index_;
+};
+
+} // namespace
+
+std::unique_ptr<FusionQueue> InstructionFusion::GetFusionQueue(
+ HloComputation* computation,
+ const std::function<bool(HloInstruction*)>& skip_producer) {
+ return absl::make_unique<ReversePostOrderFusionQueue>(computation);
+}
+
StatusOr<bool> InstructionFusion::Run(HloModule* module) {
VLOG(2) << "Before instruction fusion:";
XLA_VLOG_LINES(2, module->ToString());
@@ -306,111 +439,31 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
computation_ = computation;
reachability_ = computation_->ComputeReachability();
- // We want to be able to remove arbitrary instructions from the post order
- // and also compare positions of instructions in the post order. To make
- // this possible, create vector of instructions in post order and create a
- // map from HloInstruction* to the instruction's index in the vector. An
- // instruction is "removed" from the vector by setting it's element to
- // nullptr.
- std::vector<HloInstruction*> post_order =
- computation_->MakeInstructionPostOrder();
-
- tensorflow::gtl::FlatMap<HloInstruction*, int> post_order_index;
- for (size_t i = 0; i < post_order.size(); ++i) {
- InsertOrDie(&post_order_index, post_order[i], i);
- }
-
- HloInstructionSet do_not_duplicate = ComputeGloballyUnfusible(post_order);
+ HloInstructionSet do_not_duplicate =
+ ComputeGloballyUnfusible(computation_->MakeInstructionPostOrder());
+ auto fusion_queue =
+ GetFusionQueue(computation_, [&](HloInstruction* producer) {
+ return do_not_duplicate.count(producer) > 0;
+ });
// Instruction fusion effectively fuses edges in the computation graph
// (producer instruction -> consumer instruction) so we iterate over all
// edges. When we fuse an edge, we create a copy of the producer inside the
// fusion instruction.
- while (!post_order.empty()) {
- // We want to iterate in reverse post order, so remove from the back of
- // the vector.
- HloInstruction* instruction = post_order.back();
- post_order.pop_back();
-
- // Instructions are "removed" from the post order by nulling out the
- // element in the vector, so if the pointer is null, continue to the next
- // instruction in the sort.
+ while (true) {
+ auto next_entry =
+ fusion_queue->DequeueNextInstructionAndOperandsToFuseInOrder();
+ auto instruction = next_entry.first;
if (instruction == nullptr) {
- continue;
+ break;
}
- // Remove instruction from the index map to ensure the vector and map stay
- // consistent.
- post_order_index.erase(instruction);
-
if (!instruction->IsFusible() &&
instruction->opcode() != HloOpcode::kFusion) {
continue;
}
- // Consider each operand of this instruction for fusion into this
- // instruction. We want to consider the operands in a particular order to
- // avoid creating duplicate instruction clones in the fusion instruction.
- // For example, consider the following expression:
- //
- // A = ...
- // B = op(A)
- // C = op(A, B)
- //
- // If we are considering the operands of C for fusion into C. We might
- // fuse A or B first. If we fuse A first, we get:
- //
- // A = ...
- // B = op(A)
- // C_fusion = { A' = ...
- // C' = op(A', B) }
- //
- // Where A' and C' are clones of A and C, respectively. Now only B is an
- // operand of the fusion instruction C_fusion, so then we fuse B:
- //
- // A = ...
- // B = op(A)
- // C_fusion = { A' = ...
- // B' = op(A)
- // C' = op(A', B') }
- //
- // Now A is an operand of C_fusion again, so we then fuse A (again!):
- //
- // A = ...
- // B = op(A)
- // C_fusion = { A' = ...
- // A" = ..
- // B' = op(A")
- // C' = op(A', B') }
- //
- // We prevent this duplication by considering the operands in the reverse
- // order they appear in the instruction post order. In the example, this
- // ensures that B will be considered before A.
- //
- // We store the original indices of the operands to pass to ShouldFuse.
- std::vector<int64> sorted_operand_numbers;
- sorted_operand_numbers.reserve(instruction->operands().size());
- for (int i = 0; i < instruction->operands().size(); ++i) {
- // This will happen if we have two possible instructions to fuse the
- // same operand into; once the operand is fused into one instruction,
- // the other instruction will get a new get-tuple-element as its
- // operand, which is not in the post-order index.
- // TODO(tjoerg): Look into fusing past these multi-output fuse points.
- if (post_order_index.find(instruction->mutable_operand(i)) ==
- post_order_index.end()) {
- continue;
- }
- sorted_operand_numbers.push_back(i);
- }
- std::sort(
- sorted_operand_numbers.begin(), sorted_operand_numbers.end(),
- [&](int64 i, int64 j) {
- // Instructions with higher indices in the post order come
- // first.
- return (
- FindOrDie(post_order_index, instruction->mutable_operand(i)) >
- FindOrDie(post_order_index, instruction->mutable_operand(j)));
- });
+ std::vector<int64>& sorted_operand_numbers = next_entry.second;
for (int64 i : sorted_operand_numbers) {
HloInstruction* operand = instruction->mutable_operand(i);
@@ -425,32 +478,31 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
// TODO(tjoerg): Consider making multi-output fusion the default.
if (ShouldFuse(instruction, i) &&
do_not_duplicate.count(operand) == 0) {
+ fusion_queue->PreFusion(operand, instruction);
fusion_instruction = Fuse(operand, instruction);
} else if (ShouldFuseIntoMultiOutput(instruction, i) &&
!MultiOutputFusionCreatesCycle(operand, instruction)) {
+ fusion_queue->PreFusion(operand, instruction);
fusion_instruction = FuseIntoMultiOutput(operand, instruction);
} else {
continue;
}
- // Fusing an instruction into a fusion instruction can change the
- // operand set of the fusion instruction. For simplicity just push the
- // instruction to the top of the post_order and reconsider it for
- // further fusion in the next iteration of the outer loop.
- post_order.push_back(fusion_instruction);
- InsertOrDie(&post_order_index, fusion_instruction,
- post_order.size() - 1);
+ fusion_queue->OnFusingInstruction(fusion_instruction, operand,
+ instruction);
changed = true;
if (operand->user_count() == 0) {
- // Operand is now dead. Remove from post order by setting its
- // location to nullptr.
- post_order[FindOrDie(post_order_index, operand)] = nullptr;
- post_order_index.erase(operand);
-
+ do_not_duplicate.erase(operand);
+ // Operand is now dead. Remove from queue.
+ fusion_queue->RemoveInstruction(operand);
// Remove from computation.
TF_RETURN_IF_ERROR(computation_->RemoveInstruction(operand));
}
+
+ if (fusion_instruction != instruction) {
+ do_not_duplicate.erase(instruction);
+ }
break;
}
}
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h
index 00b658959a..c1fde8ecfc 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.h
+++ b/tensorflow/compiler/xla/service/instruction_fusion.h
@@ -24,6 +24,33 @@ limitations under the License.
namespace xla {
+// A queue interface that allows implementations to choose fusion candidates in
+// custom order.
+class FusionQueue {
+ public:
+ FusionQueue() = default;
+ virtual ~FusionQueue() = default;
+
+ // Dequeues the next fusion candidates: a consumer and the list of producers
+ // as operand indices.
+ virtual std::pair<HloInstruction*, std::vector<int64>>
+ DequeueNextInstructionAndOperandsToFuseInOrder() = 0;
+
+ // A callback passed to the queue implementation right before the producer is
+ // fused into the consumer.
+ virtual void PreFusion(HloInstruction* producer, HloInstruction* consumer) {}
+
+ // A callback passed to the queue implementation right after the fusion is
+ // created. Note that original_producer could have been destroyed.
+ virtual void OnFusingInstruction(HloInstruction* fusion,
+ HloInstruction* original_producer,
+ HloInstruction* original_consumer) {}
+
+ // A callback passed to the queue implementation to notify it that an
+ // instruction is being removed.
+ virtual void RemoveInstruction(HloInstruction* instruction) = 0;
+};
+
// HLO pass which performs instruction fusion. Instructions are fused
// "vertically", meaning producing instructions are fused into their consumers
// with the intent that the loops which compute their values will be fused in
@@ -48,6 +75,13 @@ class InstructionFusion : public HloPassInterface {
static bool IsExpensive(const HloInstruction& instruction);
protected:
+ // Returns a FusionQueue that implements a custom order of instructions being
+ // fused. The default implementation processes consumers in reverse post
+ // order.
+ virtual std::unique_ptr<FusionQueue> GetFusionQueue(
+ HloComputation* computation,
+ const std::function<bool(HloInstruction*)>& skip_producer);
+
// Returns whether the given producer instruction should be fused into the
// given consumer instruction. producer is necessarily an operand of consumer.
// Derived classes should define this method to specify which instructions
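
The FusionQueue hook above lets a backend swap in its own traversal order by overriding InstructionFusion::GetFusionQueue. Below is a hedged sketch of an alternative queue that simply walks the computation in plain (non-reverse) post order; only the FusionQueue virtual methods come from this header, and the class is illustrative rather than anything present in the change.

#include <numeric>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/service/instruction_fusion.h"

namespace xla {

// Walks instructions in plain post order. Unlike the default queue it does not
// re-enqueue fusion instructions or sort operands by post-order position; it
// exists only to show the shape of the interface.
class ForwardPostOrderFusionQueue : public FusionQueue {
 public:
  explicit ForwardPostOrderFusionQueue(HloComputation* computation)
      : post_order_(computation->MakeInstructionPostOrder()) {}

  std::pair<HloInstruction*, std::vector<int64>>
  DequeueNextInstructionAndOperandsToFuseInOrder() override {
    // Skip entries nulled out by RemoveInstruction.
    while (next_ < post_order_.size() && post_order_[next_] == nullptr) {
      ++next_;
    }
    if (next_ == post_order_.size()) {
      return {nullptr, {}};
    }
    HloInstruction* instruction = post_order_[next_++];
    // Offer every operand in its original operand order.
    std::vector<int64> operand_numbers(instruction->operand_count());
    std::iota(operand_numbers.begin(), operand_numbers.end(), 0);
    return {instruction, std::move(operand_numbers)};
  }

  void RemoveInstruction(HloInstruction* instruction) override {
    for (HloInstruction*& entry : post_order_) {
      if (entry == instruction) entry = nullptr;
    }
  }

 private:
  std::vector<HloInstruction*> post_order_;
  size_t next_ = 0;
};

}  // namespace xla

A derived fusion pass would return such a queue from its GetFusionQueue override; the default implementation in instruction_fusion.cc above returns the reverse-post-order queue.
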
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc
index 5dea124768..a06d6113e8 100644
--- a/tensorflow/compiler/xla/service/interpreter/executable.cc
+++ b/tensorflow/compiler/xla/service/interpreter/executable.cc
@@ -73,30 +73,29 @@ StatusOr<ScopedShapedBuffer> InterpreterExecutable::ExecuteOnStream(
// Transform the ShapedBuffer arguments into literals which the evaluator
// consumes.
- std::vector<std::unique_ptr<Literal>> arg_literals;
+ std::vector<Literal> arg_literals;
for (int64 p = 0; p < computation->num_parameters(); ++p) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> arg_literal,
+ TF_ASSIGN_OR_RETURN(Literal arg_literal,
transfer_manager->TransferLiteralFromDevice(
run_options->stream(), *arguments[p]));
arg_literals.push_back(std::move(arg_literal));
}
// Execute the graph using the HloEvaluator.
- std::unique_ptr<Literal> result_literal;
+ Literal result_literal;
{
tensorflow::mutex_lock lock(evaluator_lock_);
- TF_ASSIGN_OR_RETURN(result_literal,
- evaluator_->Evaluate<std::unique_ptr<Literal>>(
- *computation, arg_literals));
+ TF_ASSIGN_OR_RETURN(result_literal, evaluator_->Evaluate<Literal>(
+ *computation, arg_literals));
}
// Transform the result literal back into a ShapedBuffer.
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
transfer_manager->AllocateScopedShapedBuffer(
- result_literal->shape(), run_options->allocator(),
+ result_literal.shape(), run_options->allocator(),
executor->device_ordinal()));
TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
- run_options->stream(), *result_literal, result));
+ run_options->stream(), result_literal, result));
uint64 end_micros = tensorflow::Env::Default()->NowMicros();
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index 69c7e42601..752a61476d 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -35,7 +35,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -49,7 +49,7 @@ namespace {
using ::testing::ElementsAre;
-class LayoutAssignmentTest : public HloTestBase {
+class LayoutAssignmentTest : public HloVerifiedTestBase {
protected:
void AssignLayouts(HloModule* module,
ComputationLayout* entry_computation_layout,
@@ -91,7 +91,7 @@ TEST_F(LayoutAssignmentTest, ComputationLayout) {
*computation_layout.mutable_parameter_layout(0) = shape_layout;
*computation_layout.mutable_parameter_layout(1) = shape_layout;
*computation_layout.mutable_result_layout() = shape_layout;
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(LayoutUtil::Equal(layout, param0->shape().layout()));
EXPECT_TRUE(LayoutUtil::Equal(layout, param1->shape().layout()));
EXPECT_TRUE(LayoutUtil::Equal(layout, add->shape().layout()));
@@ -127,7 +127,7 @@ TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) {
*computation_layout.mutable_parameter_layout(1) = row_major;
*computation_layout.mutable_result_layout() = col_major;
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(LayoutUtil::Equal(col_major_layout, param0->shape().layout()));
EXPECT_TRUE(LayoutUtil::Equal(row_major_layout, param1->shape().layout()));
EXPECT_TRUE(LayoutUtil::Equal(
@@ -145,7 +145,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) {
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major));
auto constant_literal2 = LiteralUtil::CreateR2WithLayout<float>(
{{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major));
- Shape ashape = constant_literal1->shape();
+ Shape ashape = constant_literal1.shape();
auto constant1 = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(constant_literal1)));
@@ -172,7 +172,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) {
ComputationLayout computation_layout(computation->ComputeProgramShape());
*computation_layout.mutable_result_layout() = shape_layout;
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(LayoutUtil::Equal(
layout, fusion->fused_parameter(0)->shape().layout()));
@@ -213,7 +213,7 @@ TEST_F(LayoutAssignmentTest, TupleLayout) {
ComputationLayout computation_layout(
module->entry_computation()->ComputeProgramShape());
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(
LayoutUtil::LayoutsInShapesEqual(constant0->shape(), constant1->shape()));
@@ -243,7 +243,7 @@ TEST_F(LayoutAssignmentTest, TupleSelect) {
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
auto select = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1));
+ tuple0->shape(), HloOpcode::kTupleSelect, pred, tuple0, tuple1));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
@@ -255,7 +255,7 @@ TEST_F(LayoutAssignmentTest, TupleSelect) {
TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape(
result_shape));
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(result_shape, select->shape()));
}
@@ -294,7 +294,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) {
result_shape));
LayoutAssignment layout_assignment(&computation_layout);
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
// Layout assignment should have deep copied the result of the computation to
// address the layout conflict. This results in several Tuple() and
@@ -310,7 +310,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) {
EXPECT_TRUE(
AlgebraicSimplifier(/*is_layout_sensitive=*/true,
[](const Shape&, const Shape&) { return false; })
- .Run(module.get())
+ .Run(module)
.ValueOrDie());
HloInstruction* root = module->entry_computation()->root_instruction();
// Verify layout of the root and the root's operands.
@@ -352,7 +352,7 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) {
*computation_layout.mutable_parameter_layout(0) =
ShapeLayout(ashape_with_layout);
*computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout);
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
auto log_minor_to_major =
AsInt64Slice(log->shape().layout().minor_to_major());
@@ -393,7 +393,7 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) {
*computation_layout.mutable_parameter_layout(0) =
ShapeLayout(ashape_with_layout);
*computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout);
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(
LayoutUtil::Equal(ashape_with_layout.layout(), log->shape().layout()));
@@ -432,7 +432,7 @@ TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) {
ShapeLayout(input_shape_with_layout);
*computation_layout.mutable_result_layout() =
ShapeLayout(output_shape_with_layout);
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_THAT(broadcast->shape().layout().minor_to_major(),
ElementsAre(0, 1, 2));
@@ -457,13 +457,13 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) {
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, f32_4, "param"));
auto broadcast = builder.AddInstruction(
- HloInstruction::CreateBroadcast(f32_34, param, {3}));
+ HloInstruction::CreateBroadcast(f32_34, param, {1}));
auto transpose = builder.AddInstruction(
HloInstruction::CreateTranspose(f32_43, broadcast, {1, 0}));
auto tanh = builder.AddInstruction(
HloInstruction::CreateUnary(f32_34, HloOpcode::kTanh, broadcast));
auto broadcast2 = builder.AddInstruction(
- HloInstruction::CreateBroadcast(f32_234, tanh, {2}));
+ HloInstruction::CreateBroadcast(f32_234, tanh, {1, 2}));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({transpose, broadcast2}));
auto module = CreateNewModule();
@@ -485,7 +485,7 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) {
*computation_layout.mutable_result_layout() =
ShapeLayout(ShapeUtil::MakeTupleShape(
{transpose_shape_with_layout, broadcast2_shape_with_layout}));
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1));
EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(1, 0));
@@ -551,7 +551,7 @@ TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) {
*computation_layout.mutable_parameter_layout(1) =
ShapeLayout(param1_shape_with_layout);
OperandsMustBeTheSameLayoutAssignment layout_assignment(&computation_layout);
- EXPECT_IS_OK(layout_assignment.Run(module.get()).status());
+ EXPECT_IS_OK(layout_assignment.Run(module).status());
EXPECT_EQ(HloOpcode::kCopy, concatenate->operand(0)->opcode());
EXPECT_THAT(concatenate->operand(0)->shape().layout().minor_to_major(),
@@ -575,7 +575,7 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastFromOperand) {
HloComputation* computation =
module->AddEntryComputation(builder.Build(transpose));
ComputationLayout computation_layout(computation->ComputeProgramShape());
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(),
transpose->shape(), {2, 3, 0, 1}));
}
@@ -593,7 +593,7 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) {
HloComputation* computation =
module->AddEntryComputation(builder.Build(transpose));
ComputationLayout computation_layout(computation->ComputeProgramShape());
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(),
transpose->shape(), {2, 3, 0, 1}));
}
@@ -659,18 +659,18 @@ TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) {
}
)";
- auto module = ParseHloString(module_str).ValueOrDie();
+ ParseAndVerifyModule(module_str);
- module =
+ std::unique_ptr<HloModule> compiled_module =
backend()
.compiler()
- ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+ ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.ConsumeValueOrDie();
EXPECT_EQ(Status::OK(), backend()
.compiler()
- ->RunBackend(std::move(module),
+ ->RunBackend(std::move(compiled_module),
backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.status());
@@ -699,9 +699,9 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) {
}
)";
- auto module = ParseHloString(module_str).ValueOrDie();
+ ParseAndVerifyModule(module_str);
ComputationLayout computation_layout(
- module->entry_computation()->ComputeProgramShape());
+ module().entry_computation()->ComputeProgramShape());
Shape param_shape = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2}),
ShapeUtil::MakeTupleShape({
@@ -713,19 +713,19 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) {
param_shape));
computation_layout.mutable_result_layout()->ResetLayout(
LayoutUtil::MakeLayout({2, 1, 0}));
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(&module(), &computation_layout);
- EXPECT_THAT(LayoutOf(module.get(), "gte0"), ElementsAre(0, 1, 2));
- EXPECT_THAT(LayoutOf(module.get(), "gte1a"), ElementsAre(1, 2, 0));
- EXPECT_THAT(LayoutOf(module.get(), "gte1b"), ElementsAre(2, 0, 1));
- EXPECT_THAT(LayoutOf(module.get(), "fresult"), ElementsAre(2, 1, 0));
- EXPECT_THAT(FindInstruction(module.get(), "gte1")
+ EXPECT_THAT(LayoutOf(&module(), "gte0"), ElementsAre(0, 1, 2));
+ EXPECT_THAT(LayoutOf(&module(), "gte1a"), ElementsAre(1, 2, 0));
+ EXPECT_THAT(LayoutOf(&module(), "gte1b"), ElementsAre(2, 0, 1));
+ EXPECT_THAT(LayoutOf(&module(), "fresult"), ElementsAre(2, 1, 0));
+ EXPECT_THAT(FindInstruction(&module(), "gte1")
->shape()
.tuple_shapes(0)
.layout()
.minor_to_major(),
ElementsAre(1, 2, 0));
- EXPECT_THAT(FindInstruction(module.get(), "gte1")
+ EXPECT_THAT(FindInstruction(&module(), "gte1")
->shape()
.tuple_shapes(1)
.layout()
@@ -785,7 +785,7 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) {
HloComputation* computation = module->AddEntryComputation(builder.Build());
ComputationLayout computation_layout(computation->ComputeProgramShape());
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
const HloInstruction* true_root = true_computation->root_instruction();
const HloInstruction* false_root = false_computation->root_instruction();
@@ -812,7 +812,7 @@ TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) {
ComputationLayout computation_layout(
module->entry_computation()->ComputeProgramShape());
LayoutAssignment layout_assignment(&computation_layout);
- Status error_status = layout_assignment.Run(module.get()).status();
+ Status error_status = layout_assignment.Run(module).status();
EXPECT_FALSE(error_status.ok());
EXPECT_THAT(
error_status.error_message(),
@@ -839,9 +839,9 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
}
)";
- auto module = ParseHloString(module_str).ValueOrDie();
+ ParseAndVerifyModule(module_str);
ComputationLayout computation_layout(
- module->entry_computation()->ComputeProgramShape());
+ module().entry_computation()->ComputeProgramShape());
Shape param_shape = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})});
TF_ASSERT_OK(
@@ -851,14 +851,13 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
LayoutUtil::MakeLayout({1, 0}));
ChannelLayoutConstraints channel_constraints;
- AssignLayouts(module.get(), &computation_layout, &channel_constraints);
+ AssignLayouts(&module(), &computation_layout, &channel_constraints);
- EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1));
- EXPECT_THAT(LayoutOf(module.get(), "root"), ElementsAre(1, 0));
- EXPECT_TRUE(
- ShapeUtil::Equal(ShapeUtil::GetSubshape(
- FindInstruction(module.get(), "send")->shape(), {0}),
- ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})));
+ EXPECT_THAT(LayoutOf(&module(), "gte"), ElementsAre(0, 1));
+ EXPECT_THAT(LayoutOf(&module(), "root"), ElementsAre(1, 0));
+ EXPECT_TRUE(ShapeUtil::Equal(
+ ShapeUtil::GetSubshape(FindInstruction(&module(), "send")->shape(), {0}),
+ ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})));
}
TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) {
@@ -873,11 +872,11 @@ TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) {
}
)";
- auto module = ParseHloString(module_str).ValueOrDie();
+ ParseAndVerifyModule(module_str);
auto compiled_module =
backend()
.compiler()
- ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+ ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.ConsumeValueOrDie();
HloInstruction* root =
@@ -901,11 +900,11 @@ TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) {
}
)";
- auto module = ParseHloString(module_str).ValueOrDie();
+ ParseAndVerifyModule(module_str);
auto compiled_module =
backend()
.compiler()
- ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+ ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.ConsumeValueOrDie();
HloInstruction* root =
@@ -932,11 +931,11 @@ TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) {
}
)";
- auto module = ParseHloString(module_str).ValueOrDie();
+ ParseAndVerifyModule(module_str);
auto compiled_module =
backend()
.compiler()
- ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+ ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.ConsumeValueOrDie();
HloInstruction* root =
@@ -963,11 +962,11 @@ TEST_F(LayoutAssignmentTest,
}
)";
- auto module = ParseHloString(module_str).ValueOrDie();
+ ParseAndVerifyModule(module_str);
auto compiled_module =
backend()
.compiler()
- ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+ ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.ConsumeValueOrDie();
HloInstruction* root =
@@ -985,11 +984,11 @@ TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) {
}
)";
- auto module = ParseHloString(module_str).ValueOrDie();
+ ParseAndVerifyModule(module_str);
auto compiled_module =
backend()
.compiler()
- ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+ ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.ConsumeValueOrDie();
HloInstruction* root =
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index f0e2566a3f..b27a92f2a0 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -68,9 +68,9 @@ Status RecordArguments(const absl::Span<const ShapedBuffer* const> arguments,
module->clear_arguments();
for (const ShapedBuffer* argument : arguments) {
TF_ASSIGN_OR_RETURN(
- std::unique_ptr<Literal> literal,
+ Literal literal,
transfer_manager->TransferLiteralFromDevice(stream, *argument));
- *module->add_arguments() = literal->ToProto();
+ *module->add_arguments() = literal.ToProto();
}
return Status::OK();
}
@@ -80,9 +80,9 @@ Status RecordResult(const ShapedBuffer& result, se::Stream* stream,
TransferManager* transfer_manager, HloSnapshot* module) {
module->clear_result();
TF_ASSIGN_OR_RETURN(
- std::unique_ptr<Literal> literal,
+ Literal literal,
transfer_manager->TransferLiteralFromDevice(stream, result));
- *module->mutable_result() = literal->ToProto();
+ *module->mutable_result() = literal.ToProto();
return Status::OK();
}
@@ -812,7 +812,7 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
HloModule::CreateFromProto(module_proto, *module_config));
- TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module));
+ TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*module));
TF_ASSIGN_OR_RETURN(
module, backend->compiler()->RunHloPasses(std::move(module), executor,
@@ -928,16 +928,15 @@ Status Service::TransferToClient(const TransferToClientRequest* arg,
shaped_buffer->device_ordinal()));
TF_ASSIGN_OR_RETURN(
- std::unique_ptr<Literal> result_literal,
+ Literal result_literal,
execute_backend_->transfer_manager()->TransferLiteralFromDevice(
stream.get(), *shaped_buffer));
- if (LayoutUtil::LayoutsInShapesEqual(*return_shape,
- result_literal->shape())) {
- *result->mutable_literal() = result_literal->ToProto();
+ if (LayoutUtil::LayoutsInShapesEqual(*return_shape, result_literal.shape())) {
+ *result->mutable_literal() = result_literal.ToProto();
} else {
*result->mutable_literal() =
- result_literal->Relayout(*return_shape)->ToProto();
+ result_literal.Relayout(*return_shape).ToProto();
}
return Status::OK();
}
@@ -959,9 +958,9 @@ std::unique_ptr<ShapedBuffer> CloneShapedBufferOnDevice(
Status Service::TransferToServer(const TransferToServerRequest* arg,
TransferToServerResponse* result) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
+ TF_ASSIGN_OR_RETURN(Literal literal,
Literal::CreateFromProto(arg->literal()));
- const Shape& shape = literal->shape();
+ const Shape& shape = literal.shape();
std::vector<se::StreamExecutor*> replicas;
if (arg->has_device_handle()) {
@@ -983,7 +982,7 @@ Status Service::TransferToServer(const TransferToServerRequest* arg,
TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(executor));
TF_RETURN_IF_ERROR(
execute_backend_->transfer_manager()->TransferLiteralToDevice(
- stream.get(), *literal, shaped_buffer));
+ stream.get(), literal, shaped_buffer));
replicated_buffers.emplace_back(std::move(shaped_buffer));
}
TF_ASSIGN_OR_RETURN(*result->mutable_data(),
@@ -1018,10 +1017,10 @@ Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
executor = replicas[arg->replica_id()];
}
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
+ TF_ASSIGN_OR_RETURN(Literal literal,
Literal::CreateFromProto(arg->literal()));
- return execute_backend_->transfer_manager()->TransferLiteralToInfeed(
- executor, *literal);
+ return execute_backend_->transfer_manager()->TransferLiteralToInfeed(executor,
+ literal);
}
Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
@@ -1049,8 +1048,8 @@ Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
TF_RETURN_IF_ERROR(
execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
- executor, arg->shape_with_layout(), *literal));
- *result->mutable_literal() = literal->ToProto();
+ executor, arg->shape_with_layout(), literal));
+ *result->mutable_literal() = literal.ToProto();
return Status::OK();
}
@@ -1085,18 +1084,17 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
HloModule::CreateFromProto(arg->computation(), config));
HloEvaluator evaluator;
- TF_ASSIGN_OR_RETURN(auto result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, /*arg_literals=*/{}));
+ TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate<Literal>(
+ *module, /*arg_literals=*/{}));
// Since the result layout has no effect on the Evaluator's results, explicitly
// relayout here.
//
// TODO(b/77824332): Make HloEvaluator take care of the re-layout.
if (arg->has_output_layout()) {
- result_literal = result_literal->Relayout(arg->output_layout());
+ result_literal = result_literal.Relayout(arg->output_layout());
}
- *result->mutable_literal() = result_literal->ToProto();
+ *result->mutable_literal() = result_literal.ToProto();
return Status::OK();
}
@@ -1162,7 +1160,7 @@ StatusOr<std::vector<se::StreamExecutor*>> Service::Replicas(
return replicas;
}
-Status Service::MaybeDumpHloModule(const HloModule& module) const {
+Status Service::MaybeDumpUnoptimizedHloModule(const HloModule& module) const {
const string xla_dump_unoptimized_hlo_proto_to =
module.config().debug_options().xla_dump_unoptimized_hlo_proto_to();
if (xla_dump_unoptimized_hlo_proto_to.empty()) {
@@ -1170,7 +1168,8 @@ Status Service::MaybeDumpHloModule(const HloModule& module) const {
}
HloProto proto = MakeHloProto(module);
return protobuf_util::DumpProtoToDirectory(
- proto, xla_dump_unoptimized_hlo_proto_to, module.name());
+ proto, xla_dump_unoptimized_hlo_proto_to,
+ StrCat(module.name(), ".unoptimized"));
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index 44c5248b15..1f62fad4c8 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -271,7 +271,9 @@ class Service : public ServiceInterface {
StatusOr<std::vector<se::StreamExecutor*>> Replicas(
const Backend& backend, const DeviceHandle& device_handle) const;
- Status MaybeDumpHloModule(const HloModule& module) const;
+ // Dumps the given (unoptimized) module if the corresponding DebugOptions
+ // field has been set.
+ Status MaybeDumpUnoptimizedHloModule(const HloModule& module) const;
// Returns the device handle that represents the replicated device for a
// single computation that is not model-parallelized.
diff --git a/tensorflow/compiler/xla/service/source_map_util.cc b/tensorflow/compiler/xla/service/source_map_util.cc
deleted file mode 100644
index dd53c7531b..0000000000
--- a/tensorflow/compiler/xla/service/source_map_util.cc
+++ /dev/null
@@ -1,66 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/source_map_util.h"
-
-#include "absl/strings/str_format.h"
-#include "tensorflow/compiler/xla/util.h"
-
-namespace xla {
-namespace source_map_util {
-namespace {
-
-Status InvalidParameterArgumentV(const OpMetadata& op_metadata,
- const char* format, va_list args) {
- string message;
- tensorflow::strings::Appendv(&message, format, args);
- if (!op_metadata.source_file().empty()) {
- absl::StrAppendFormat(&message, " (%s:%d)", op_metadata.source_file(),
- op_metadata.source_line());
- }
- return InvalidArgument("%s", message);
-}
-
-} // namespace
-
-Status InvalidParameterArgument(const OpMetadata& op_metadata,
- const char* format, ...) {
- va_list args;
- va_start(args, format);
- Status result = InvalidParameterArgumentV(op_metadata, format, args);
- va_end(args);
- return result;
-}
-
-Status InvalidParameterArgument(Executable* executable, int parameter_number,
- const char* format, ...) {
- va_list args;
- va_start(args, format);
- if (executable != nullptr && executable->has_module()) {
- const HloModule& module = executable->module();
- const HloComputation& computation = *module.entry_computation();
- HloInstruction* param = computation.parameter_instruction(parameter_number);
- const OpMetadata& metadata = param->metadata();
- Status result = InvalidParameterArgumentV(metadata, format, args);
- va_end(args);
- return result;
- }
- Status result = InvalidArgumentV(format, args);
- va_end(args);
- return result;
-}
-
-} // namespace source_map_util
-} // namespace xla
diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc
index b8d2d546e5..a21e586efa 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/transfer_manager.cc
@@ -42,9 +42,9 @@ TransferManager::GetPlatformTransferManagers() {
return r;
}
-StatusOr<std::unique_ptr<Literal>> TransferManager::TransferLiteralFromDevice(
+StatusOr<Literal> TransferManager::TransferLiteralFromDevice(
se::Stream* stream, const ShapedBuffer& device_buffer) {
- StatusOr<std::unique_ptr<Literal>> ret;
+ StatusOr<Literal> ret;
se::Stream* substream = stream->GetOrCreateSubStream();
substream->ThenWaitFor(stream);
@@ -63,7 +63,7 @@ StatusOr<std::unique_ptr<Literal>> TransferManager::TransferLiteralFromDevice(
if (!s.ok()) {
return s;
}
- return absl::make_unique<Literal>(std::move(literal));
+ return std::move(literal);
}
Status TransferManager::TransferLiteralFromDevice(
@@ -99,10 +99,10 @@ Status TransferManager::TransferLiteralToDevice(
return substream->BlockHostUntilDone();
}
-StatusOr<std::unique_ptr<Literal>> TransferManager::TransferArrayFromDevice(
+StatusOr<Literal> TransferManager::TransferArrayFromDevice(
se::Stream* stream, const Shape& shape,
const se::DeviceMemoryBase& source) {
- StatusOr<std::unique_ptr<Literal>> ret;
+ StatusOr<Literal> ret;
// Implement the synchronous version by waiting on the asynchronous version.
// Use a substream so that if we are called from a HostCallback we don't
// deadlock.
@@ -122,7 +122,7 @@ StatusOr<std::unique_ptr<Literal>> TransferManager::TransferArrayFromDevice(
if (!s.ok()) {
return s;
}
- return absl::make_unique<Literal>(std::move(literal));
+ return std::move(literal);
}
Status TransferManager::TransferArrayToDevice(
diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h
index 21725946b3..f952e64af2 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.h
+++ b/tensorflow/compiler/xla/service/transfer_manager.h
@@ -57,7 +57,7 @@ class TransferManager {
// without waiting for any other operation on a stream to complete.
//
// This function should be avoided in favor of the asynchronous version below.
- virtual StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDevice(
+ virtual StatusOr<Literal> TransferLiteralFromDevice(
se::Stream* stream, const ShapedBuffer& device_buffer);
virtual Status TransferLiteralFromDevice(
se::Stream* stream, const ShapedBuffer& device_buffer,
@@ -113,9 +113,9 @@ class TransferManager {
Status TransferArrayToDeviceAsync(se::Stream* stream,
const LiteralSlice& literal,
const se::DeviceMemoryBase& dest);
- StatusOr<std::unique_ptr<Literal>> TransferArrayFromDevice(
- se::Stream* stream, const Shape& shape,
- const se::DeviceMemoryBase& source);
+ StatusOr<Literal> TransferArrayFromDevice(se::Stream* stream,
+ const Shape& shape,
+ const se::DeviceMemoryBase& source);
// Transfers the given literal into the Infeed interface of the device,
// using the given executor.
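As a usage note (a hedged sketch, not code from this change): callers of the synchronous API now receive the Literal by value inside the StatusOr, so TF_ASSIGN_OR_RETURN binds a Literal directly. The helper name ReadShapeFromDevice and its includes are invented for this example.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

// Hypothetical helper: reads a device buffer back to the host and returns its
// shape, using the value-returning TransferLiteralFromDevice overload.
StatusOr<Shape> ReadShapeFromDevice(TransferManager* transfer_manager,
                                    se::Stream* stream,
                                    const ShapedBuffer& device_buffer) {
  TF_ASSIGN_OR_RETURN(Literal literal,
                      transfer_manager->TransferLiteralFromDevice(
                          stream, device_buffer));
  return literal.shape();  // Literal is a value now; member access uses '.'.
}

}  // namespace xla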
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
index 2b2a2eb42a..e9a07b14ed 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -555,10 +555,10 @@ TEST_F(TuplePointsToAnalysisTest, PointsToTupleConstantElements) {
// Construct a tuple constant and kCopy it. Verify the points-to set of the
// copy correctly points into the nested elements of the constant.
auto builder = HloComputation::Builder(TestName());
- auto tuple_constant = builder.AddInstruction(
- HloInstruction::CreateConstant(LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}).get(),
- LiteralUtil::CreateR1<float>({2.0, 42}).get()})));
+ Literal elements[] = {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}),
+ LiteralUtil::CreateR1<float>({2.0, 42})};
+ auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::MakeTuple({&elements[0], &elements[1]})));
auto copy = builder.AddInstruction(HloInstruction::CreateUnary(
tuple_constant->shape(), HloOpcode::kCopy, tuple_constant));
diff --git a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc
index 39b693872d..516754e211 100644
--- a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc
@@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -34,7 +34,7 @@ namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {
-class TupleSimplifierTest : public HloTestBase {
+class TupleSimplifierTest : public HloVerifiedTestBase {
protected:
void Run(HloModule* module, bool change_expected) {
TupleSimplifier simplifier;
@@ -68,7 +68,7 @@ TEST_F(TupleSimplifierTest, TupleOfParameters) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- Run(module.get(), /*change_expected=*/false);
+ Run(module, /*change_expected=*/false);
}
TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) {
@@ -81,7 +81,7 @@ TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- Run(module.get(), /*change_expected=*/false);
+ Run(module, /*change_expected=*/false);
}
TEST_F(TupleSimplifierTest, GteOfTuple) {
@@ -103,7 +103,7 @@ TEST_F(TupleSimplifierTest, GteOfTuple) {
EXPECT_THAT(computation->root_instruction(), gte);
- Run(module.get(), /*change_expected=*/true);
+ Run(module, /*change_expected=*/true);
EXPECT_THAT(computation->root_instruction(), param1);
}
@@ -131,7 +131,7 @@ TEST_F(TupleSimplifierTest, GteOfTupleChain) {
EXPECT_THAT(computation->root_instruction(),
op::Negate(op::GetTupleElement(op::Tuple())));
- Run(module.get(), /*change_expected=*/true);
+ Run(module, /*change_expected=*/true);
EXPECT_THAT(computation->root_instruction(), op::Negate(op::Parameter()));
}
@@ -162,7 +162,7 @@ TEST_F(TupleSimplifierTest, NestedGteOfTuples) {
EXPECT_THAT(computation->root_instruction(), element);
- Run(module.get(), /*change_expected=*/true);
+ Run(module, /*change_expected=*/true);
EXPECT_THAT(computation->root_instruction(), param);
}
@@ -187,7 +187,7 @@ TEST_F(TupleSimplifierTest, TupleOfGteInstructions) {
EXPECT_THAT(computation->root_instruction(), tuple);
- Run(module.get(), /*change_expected=*/true);
+ Run(module, /*change_expected=*/true);
EXPECT_THAT(computation->root_instruction(), tuple_param);
}
@@ -212,7 +212,7 @@ TEST_F(TupleSimplifierTest, IncompatibleTuples) {
EXPECT_THAT(computation->root_instruction(), tuple);
- Run(module.get(), /*change_expected=*/false);
+ Run(module, /*change_expected=*/false);
EXPECT_THAT(computation->root_instruction(), tuple);
}
@@ -281,7 +281,7 @@ TEST_F(TupleSimplifierTest, CanExcludeEntryComputation) {
entry = module->AddEntryComputation(builder.Build());
}
- Run(module.get(), /*change_expected=*/true, /*exclude_entry=*/ true);
+ Run(module, /*change_expected=*/true, /*exclude_entry=*/true);
EXPECT_THAT(c0->root_instruction(), p0);
EXPECT_THAT(c1->root_instruction(), p1);
diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc
index c3c2603c7e..541b117e02 100644
--- a/tensorflow/compiler/xla/service/while_loop_analysis.cc
+++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc
@@ -183,8 +183,7 @@ optional<int64> ComputeWhileLoopTripCount(HloInstruction* while_op,
HloEvaluator evaluator(/*max_loop_iterations=*/0);
auto* while_init = while_op->mutable_operand(0);
auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx);
- StatusOr<std::unique_ptr<Literal>> indvar_init_result =
- evaluator.Evaluate(indvar_init);
+ StatusOr<Literal> indvar_init_result = evaluator.Evaluate(indvar_init);
if (!indvar_init_result.ok()) {
VLOG(2) << "Couldn't evaluate induction variable init: "
<< indvar_init_result.status();
@@ -197,31 +196,27 @@ optional<int64> ComputeWhileLoopTripCount(HloInstruction* while_op,
auto* while_body_indvar = NonConstantOperand(while_body_indvar_update);
// The initial value of the induction variable.
- std::unique_ptr<Literal> indvar_iter_val =
- std::move(indvar_init_result).ValueOrDie();
+ Literal indvar_iter_val = std::move(indvar_init_result).ValueOrDie();
for (int64 trip_count = 0; trip_count != max_value_returned + 1;
++trip_count) {
auto* while_cond = while_op->while_condition();
auto* while_cond_root = while_cond->root_instruction();
auto* while_cond_indvar = NonConstantOperand(while_cond_root);
- StatusOr<std::unique_ptr<Literal>> result =
- evaluator.EvaluateWithSubstitutions(
- while_cond_root, {{while_cond_indvar, indvar_iter_val.get()}});
+ StatusOr<Literal> result = evaluator.EvaluateWithSubstitutions(
+ while_cond_root, {{while_cond_indvar, &indvar_iter_val}});
if (!result.ok()) {
VLOG(2) << "Couldn't evaluate while cond: " << result.status();
return nullopt;
}
- if (result.ValueOrDie()->data<bool>() == absl::Span<const bool>{false}) {
+ if (result.ValueOrDie().data<bool>() == absl::Span<const bool>{false}) {
VLOG(2) << "Loop has static trip count of " << trip_count;
return trip_count;
}
// Calculate the value of the induction variable after one iteration of the
// loop, and check whether the while condition is true with this new value.
- StatusOr<std::unique_ptr<Literal>> indvar_next_result =
- evaluator.EvaluateWithSubstitutions(
- while_body_indvar_update,
- {{while_body_indvar, indvar_iter_val.get()}});
+ StatusOr<Literal> indvar_next_result = evaluator.EvaluateWithSubstitutions(
+ while_body_indvar_update, {{while_body_indvar, &indvar_iter_val}});
if (!indvar_next_result.ok()) {
VLOG(2) << "Couldn't evaluate induction variable update: "
<< indvar_next_result.status();
diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h
index 52c895e8d4..df610102b4 100644
--- a/tensorflow/compiler/xla/shape_tree.h
+++ b/tensorflow/compiler/xla/shape_tree.h
@@ -224,14 +224,13 @@ class ShapeTree {
// REQUIRES: index must exist in the ShapeTree.
iterator find(ShapeIndexView index) {
Node* element = Lookup(index);
- return iterator(&nodes_, typename std::vector<Node>::iterator(element),
- /*iterate_leaves_only=*/false);
+ auto element_iter = nodes_.begin() + (element - &nodes_[0]);
+ return iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false);
}
const_iterator find(ShapeIndexView index) const {
Node* element = Lookup(index);
- return iterator(&nodes_,
- typename std::vector<Node>::const_iterator(element),
- /*iterate_leaves_only=*/false);
+ auto element_iter = nodes_.cbegin() + (element - &nodes_[0]);
+ return const_iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false);
}
// Returns the number of leaf nodes in the tree.
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 9772c06bce..96c80fd577 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -441,6 +441,19 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
return count;
}
+/* static */ bool ShapeUtil::HasPrimitiveType(const Shape& shape,
+ PrimitiveType primitive_type) {
+ if (shape.element_type() == primitive_type) {
+ return true;
+ }
+ for (const Shape& element_shape : shape.tuple_shapes()) {
+ if (HasPrimitiveType(element_shape, primitive_type)) {
+ return true;
+ }
+ }
+ return false;
+}
+
/* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) {
return ShapeUtil::IsArray(shape) && ElementsIn(shape) == 0;
}
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 8234fcdd3f..623ae39de8 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -180,6 +180,10 @@ class ShapeUtil {
// As ElementsIn(), but recurses through tuples.
static int64 ElementsInRecursive(const Shape& shape);
+ // Returns true if the shape has the primitive type; recurses through tuples.
+ static bool HasPrimitiveType(const Shape& shape,
+ PrimitiveType primitive_type);
+
// Returns true if 'shape' is an array with zero elements.
static bool IsZeroElementArray(const Shape& shape);
diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc
index 6ca4085aaf..c622ecdca1 100644
--- a/tensorflow/compiler/xla/shape_util_test.cc
+++ b/tensorflow/compiler/xla/shape_util_test.cc
@@ -445,6 +445,22 @@ TEST(ShapeUtilTest, ElementsIn) {
EXPECT_EQ(221, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {13, 17})));
}
+TEST(ShapeUtilTest, HasPrimitiveType) {
+ EXPECT_TRUE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {}), S32));
+ EXPECT_FALSE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {}), S16));
+ EXPECT_TRUE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {0}), S32));
+ EXPECT_FALSE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeTupleShape({}), S32));
+ EXPECT_TRUE(ShapeUtil::HasPrimitiveType(
+ ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})}),
+ S32));
+ EXPECT_TRUE(ShapeUtil::HasPrimitiveType(
+ ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(S32, {}),
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S16, {})})}),
+ S16));
+}
+
TEST(ShapeUtilTest, IsZeroElementArray) {
EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {})));
EXPECT_TRUE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {0})));
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index d0bda45cf8..30e3077edb 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -647,6 +647,7 @@ xla_test(
],
shard_count = 48,
tags = [
+ "broken",
"manual",
"notap",
],
diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
index 0bf4556b43..c257566fb2 100644
--- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
@@ -41,7 +41,6 @@ limitations under the License.
namespace xla {
namespace {
-
class ArrayElementwiseOpTest : public ClientLibraryTestBase {
public:
ErrorSpec error_spec_{0.0001, 0.0001};
@@ -227,10 +226,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) {
0x8000000000000000LL,
0x8000000000000000LL,
1};
- std::unique_ptr<Literal> lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
- auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param");
+ Literal lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
+ auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");
std::unique_ptr<GlobalData> lhs_data =
- client_->TransferToServer(*lhs_literal).ConsumeValueOrDie();
+ client_->TransferToServer(lhs_literal).ConsumeValueOrDie();
std::vector<uint64> rhs{1,
0x7FFFFFFFFFFFFFFLL,
@@ -241,10 +240,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) {
0,
1,
0x8000000000000000LL};
- std::unique_ptr<Literal> rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
- auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param");
+ Literal rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
+ auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");
std::unique_ptr<GlobalData> rhs_data =
- client_->TransferToServer(*rhs_literal).ConsumeValueOrDie();
+ client_->TransferToServer(rhs_literal).ConsumeValueOrDie();
Add(lhs_param, rhs_param);
@@ -267,10 +266,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) {
1,
0,
-1};
- std::unique_ptr<Literal> lhs_literal = LiteralUtil::CreateR1<int64>({lhs});
- auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param");
+ Literal lhs_literal = LiteralUtil::CreateR1<int64>({lhs});
+ auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");
std::unique_ptr<GlobalData> lhs_data =
- client_->TransferToServer(*lhs_literal).ConsumeValueOrDie();
+ client_->TransferToServer(lhs_literal).ConsumeValueOrDie();
std::vector<int64> rhs{-1,
0,
@@ -280,10 +279,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) {
0x7FFFFFFFFFFFFFFLL,
0x7FFFFFFFFFFFFFFFLL,
0x7FFFFFFFFFFFFFFFLL};
- std::unique_ptr<Literal> rhs_literal = LiteralUtil::CreateR1<int64>({rhs});
- auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param");
+ Literal rhs_literal = LiteralUtil::CreateR1<int64>({rhs});
+ auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");
std::unique_ptr<GlobalData> rhs_data =
- client_->TransferToServer(*rhs_literal).ConsumeValueOrDie();
+ client_->TransferToServer(rhs_literal).ConsumeValueOrDie();
Sub(lhs_param, rhs_param);
@@ -299,16 +298,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, CmpTwoConstantU64s) {
XlaBuilder b(TestName());
std::vector<uint64> lhs{static_cast<uint64>(0x8000000000000000ULL)};
- std::unique_ptr<Literal> lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
- auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param");
+ Literal lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
+ auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");
std::vector<uint64> rhs{static_cast<uint64>(0x7FFFFFFFFFFFFFFFULL)};
- std::unique_ptr<Literal> rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
- auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param");
+ Literal rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
+ auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");
Lt(lhs_param, rhs_param);
- ComputeAndCompare(&b, {std::move(*lhs_literal), std::move(*rhs_literal)});
+ ComputeAndCompare(&b, {std::move(lhs_literal), std::move(rhs_literal)});
}
TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
@@ -321,16 +320,16 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
b_values.push_back(2 * i / static_cast<float>(count + 2));
}
- std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({a_values});
+ Literal a_literal = LiteralUtil::CreateR1<float>({a_values});
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
auto a_constant = ConstantR1<float>(&builder, a_values);
- auto a_param = Parameter(&builder, 0, a_literal->shape(), "a_param");
+ auto a_param = Parameter(&builder, 0, a_literal.shape(), "a_param");
- std::unique_ptr<Literal> b_literal = LiteralUtil::CreateR1<float>({b_values});
+ Literal b_literal = LiteralUtil::CreateR1<float>({b_values});
std::unique_ptr<GlobalData> b_data =
- client_->TransferToServer(*b_literal).ConsumeValueOrDie();
- auto b_constant = Parameter(&builder, 1, a_literal->shape(), "b_param");
+ client_->TransferToServer(b_literal).ConsumeValueOrDie();
+ auto b_constant = Parameter(&builder, 1, a_literal.shape(), "b_param");
auto b_param = ConstantR1<float>(&builder, b_values);
auto sum1 = Add(a_constant, b_constant);
@@ -1422,12 +1421,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) {
std::vector<float> values = {1.0f, 2.0f, 3.2f, -4.0f};
std::vector<float> exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
- std::unique_ptr<Literal> param_literal = LiteralUtil::CreateR1<float>(values);
+ Literal param_literal = LiteralUtil::CreateR1<float>(values);
std::unique_ptr<GlobalData> param_data =
- client_->TransferToServer(*param_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param_literal).ConsumeValueOrDie();
auto sum = ConstantR0<float>(&b, 0.0f);
- auto param = Parameter(&b, 0, param_literal->shape(), "param");
+ auto param = Parameter(&b, 0, param_literal.shape(), "param");
for (float exponent : exponents) {
sum = Add(sum, Pow(param, ConstantR0<float>(&b, exponent)));
}
@@ -1450,14 +1449,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) {
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+ Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
+ Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
- auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
+ auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
Pow(Exp(param0), param1);
std::vector<float> expected(values0.size());
@@ -1475,14 +1474,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) {
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, 4.0f, 0.5f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+ Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
+ Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
- auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
+ auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
Log(Pow(param0, param1));
std::vector<float> expected(values0.size());
@@ -1500,14 +1499,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) {
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+ Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
+ Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
- auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
+ auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
Mul(Exp(param0), Exp(param1));
std::vector<float> expected(values0.size());
@@ -1525,14 +1524,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) {
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+ Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
+ Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
- auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
+ auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
Div(param0, Exp(param1));
std::vector<float> expected(values0.size());
@@ -1551,20 +1550,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) {
std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+ Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+ Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
+ Literal literal2 = LiteralUtil::CreateR1<float>(values2);
std::unique_ptr<GlobalData> data2 =
- client_->TransferToServer(*literal2).ConsumeValueOrDie();
- auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
- auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
+ client_->TransferToServer(literal2).ConsumeValueOrDie();
+ auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
+ auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
Div(Div(param0, param1), param2);
std::vector<float> expected(values0.size());
@@ -1583,21 +1582,21 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) {
std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+ Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+ Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
+ Literal literal2 = LiteralUtil::CreateR1<float>(values2);
std::unique_ptr<GlobalData> data2 =
- client_->TransferToServer(*literal2).ConsumeValueOrDie();
+ client_->TransferToServer(literal2).ConsumeValueOrDie();
- auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
- auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
+ auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
+ auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
Div(param0, Div(param1, param2));
std::vector<float> expected(values0.size());
@@ -1616,21 +1615,21 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) {
std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, 1.0f, 0.5f};
std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 9.5f, -11.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+ Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+ Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
+ Literal literal2 = LiteralUtil::CreateR1<float>(values2);
std::unique_ptr<GlobalData> data2 =
- client_->TransferToServer(*literal2).ConsumeValueOrDie();
+ client_->TransferToServer(literal2).ConsumeValueOrDie();
- auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
- auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
+ auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
+ auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
Div(param0, Pow(param1, param2));
std::vector<float> expected(values0.size());
@@ -1650,26 +1649,26 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) {
std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
std::vector<float> values3 = {2.1f, 3.1f, 9.9f, -4.5f, -11.0f, -21.5f};
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+ Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+ Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
+ Literal literal2 = LiteralUtil::CreateR1<float>(values2);
std::unique_ptr<GlobalData> data2 =
- client_->TransferToServer(*literal2).ConsumeValueOrDie();
+ client_->TransferToServer(literal2).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal3 = LiteralUtil::CreateR1<float>(values3);
+ Literal literal3 = LiteralUtil::CreateR1<float>(values3);
std::unique_ptr<GlobalData> data3 =
- client_->TransferToServer(*literal3).ConsumeValueOrDie();
+ client_->TransferToServer(literal3).ConsumeValueOrDie();
- auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
- auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
- auto param3 = Parameter(&b, 3, literal3->shape(), "param2");
+ auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
+ auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
+ auto param3 = Parameter(&b, 3, literal3.shape(), "param3");
Div(Div(param0, param1), Div(param2, param3));
std::vector<float> expected(values0.size());
@@ -2096,18 +2095,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, ClampU32ScalarVector) {
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal =
+ Literal param1_literal =
LiteralUtil::CreateR1<float>({7.2f, 2.3f, 3.4f, 5.6f});
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Add(p0, p1);
ComputeAndCompareR1<float>(&builder, {8.3f, 4.5f, 6.7f, 11.1f},
@@ -2118,18 +2117,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) {
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal =
+ Literal param1_literal =
LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Add(p0, p1);
Array3D<float> expected(0, 7, 0);
@@ -2140,13 +2139,13 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) {
XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
auto a = ConstantR1<float>(&builder, {1.1f, 2.2f, 3.3f, 4.4f});
- auto p = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto p = Parameter(&builder, 0, param0_literal.shape(), "param0");
Add(a, p);
ComputeAndCompareR1<float>(&builder, {2.2f, 4.4f, 6.6f, 9.9f},
@@ -2206,9 +2205,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) {
0.08, -1.24, -0.92, 0.49, 1.17, -0.45, -1.31, -1.44, -0.13, -1.31,
-0.79, 1.41, 1.21, 1.05});
TF_ASSERT_OK_AND_ASSIGN(auto input_data,
- client_->TransferToServer(*input_literal));
+ client_->TransferToServer(input_literal));
- auto input = Parameter(&builder, 0, input_literal->shape(), "input");
+ auto input = Parameter(&builder, 0, input_literal.shape(), "input");
Tanh(input);
ComputeAndCompareR1<float>(
@@ -2239,7 +2238,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) {
// Just to help make sense of the scales here -- exp(89) saturates float32 and
// exp(-10) is smaller than our error spec.
- std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR1<float>(
+ Literal input_literal = LiteralUtil::CreateR1<float>(
{1.02, -0.32, 0.85, 0.9, 1.23, -0.91, -0.49, 0.8, -1.31,
-1.44, -0.13, -1.31, -0.79, 1.41, 1.21, 1.05, -195.6, -194.5,
-193.4, -192.3, -191.2, -190.1, -189.0, -187.9, -19.6, -18.5, -17.4,
@@ -2252,16 +2251,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) {
78.3, 79.4, 80.5, 81.6, 82.7, 83.8, 84.9, 85.2, 86.3,
86.4, 86.5, 87.6, 87.7, 87.8, 87.9});
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
- client_->TransferToServer(*input_literal));
+ client_->TransferToServer(input_literal));
- auto input = Parameter(&builder, 0, input_literal->shape(), "input");
+ auto input = Parameter(&builder, 0, input_literal.shape(), "input");
Exp(input);
std::vector<float> expected_result;
- int64 input_size = input_literal->shape().dimensions(0);
+ int64 input_size = input_literal.shape().dimensions(0);
expected_result.reserve(input_size);
for (int64 i = 0; i < input_size; i++) {
- expected_result.push_back(std::exp(input_literal->Get<float>({i})));
+ expected_result.push_back(std::exp(input_literal.Get<float>({i})));
}
ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
@@ -2273,7 +2272,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) {
// implementation on XLA CPU.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR1<float>(
+ Literal input_literal = LiteralUtil::CreateR1<float>(
{-1.29, -1.41, -1.25, -13.5, -11.7, -17.9, -198,
-167, 1.29, 1.41, 1.25, 13.5, 11.7, 17.9,
198, 167, 1.27e+03, 1.33e+03, 1.74e+03, 1.6e+04, 1.84e+04,
@@ -2290,16 +2289,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) {
1.7e+31, 1.44e+31, 1.1e+31, 1.4e+32, 1.67e+32, 1.96e+33, 1.11e+33,
1.19e+33, 1.61e+34, 1.05e+34, 1.88e+34, 1.67e+35, 1.7e+35});
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
- client_->TransferToServer(*input_literal));
+ client_->TransferToServer(input_literal));
- auto input = Parameter(&builder, 0, input_literal->shape(), "input");
+ auto input = Parameter(&builder, 0, input_literal.shape(), "input");
Log(input);
std::vector<float> expected_result;
- int64 input_size = input_literal->shape().dimensions(0);
+ int64 input_size = input_literal.shape().dimensions(0);
expected_result.reserve(input_size);
for (int64 i = 0; i < input_size; i++) {
- expected_result.push_back(std::log(input_literal->Get<float>({i})));
+ expected_result.push_back(std::log(input_literal.Get<float>({i})));
}
ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
@@ -2465,10 +2464,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) {
auto cmp_dim_1 = Eq(v, m, /*broadcast_dimensions=*/{0});
Tuple(&builder, {cmp_dim_0, cmp_dim_1});
- auto expected = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<bool>({{true, true}, {true, false}}).get(),
- LiteralUtil::CreateR2<bool>({{true, false}, {false, false}}).get()});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<bool>({{true, true}, {true, false}}),
+ LiteralUtil::CreateR2<bool>({{true, false}, {false, false}})});
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) {
@@ -2821,10 +2820,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) {
std::iota(r1.begin(), r1.end(), 1.0);
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- r4, LayoutUtil::MakeLayout({0, 1, 2, 3}));
- auto a = ConstantLiteral(&builder, *a_literal);
+ Literal a_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ r4, LayoutUtil::MakeLayout({0, 1, 2, 3}));
+ auto a = ConstantLiteral(&builder, a_literal);
auto b = ConstantR1<float>(&builder, r1);
Add(a, b, {1});
@@ -2886,11 +2884,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) {
XlaBuilder builder(TestName());
auto x_literal = LiteralUtil::CreateR1<float>({1, 2, 3});
auto y_literal = LiteralUtil::CreateR1<float>({4, 5});
- auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
- auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
+ auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
+ auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
- auto x = Parameter(&builder, 0, x_literal->shape(), "x");
- auto y = Parameter(&builder, 1, y_literal->shape(), "y");
+ auto x = Parameter(&builder, 0, x_literal.shape(), "x");
+ auto y = Parameter(&builder, 1, y_literal.shape(), "y");
auto slice = Slice(x, {1}, {2}, {1});
Sub(slice, y);
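
Note on the pattern in the hunks above (and in most of the remaining test diffs): xla::Literal moves from std::unique_ptr ownership to plain value semantics, so `*literal` becomes `literal`, `literal->shape()` becomes `literal.shape()`, and `client_->TransferToServer` takes the literal by const reference. The sketch below is a minimal, hedged illustration of that call-site change only; FakeLiteral, MakeUnique, and MakeValue are hypothetical stand-ins, not XLA types.

```cpp
// Stand-in sketch (not XLA code): shows why "*lit" / "lit->shape()" call sites
// collapse to "lit" / "lit.shape()" once the factory returns a value.
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct FakeLiteral {
  std::vector<float> values;
  std::string shape() const {
    return "f32[" + std::to_string(values.size()) + "]";
  }
};

// Old style: factories hand back std::unique_ptr, so callers dereference.
std::unique_ptr<FakeLiteral> MakeUnique(std::vector<float> v) {
  return std::make_unique<FakeLiteral>(FakeLiteral{std::move(v)});
}

// New style: the literal is returned by value and moved into place.
FakeLiteral MakeValue(std::vector<float> v) {
  return FakeLiteral{std::move(v)};
}

// Consumers take a const reference rather than a dereferenced pointer.
void UseLiteral(const FakeLiteral& literal) {
  std::cout << literal.shape() << "\n";
}

int main() {
  auto old_style = MakeUnique({1.f, 2.f, 3.f});
  UseLiteral(*old_style);              // before: dereference the unique_ptr
  std::cout << old_style->shape() << "\n";

  FakeLiteral new_style = MakeValue({1.f, 2.f, 3.f});
  UseLiteral(new_style);               // after: pass the value directly
  std::cout << new_style.shape() << "\n";
}
```

Behavior is identical in both styles; the value form removes one level of indirection and relies on move construction instead of a separate heap allocation for every literal.
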
diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
index ac90a3adb6..bc2ba151a3 100644
--- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc
+++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
@@ -63,7 +63,7 @@ class BatchNormalizationTest
{5.0f, 4.4f}, // p2
});
input_array_.FillWithPZ(pz);
- input_literal_ = std::move(*LiteralUtil::CreateR4FromArray4D(input_array_));
+ input_literal_ = LiteralUtil::CreateR4FromArray4D(input_array_);
CHECK_EQ(kSamples, input_array_.planes());
CHECK_EQ(kZ, input_array_.depth());
CHECK_EQ(kY, input_array_.height());
@@ -242,14 +242,13 @@ XLA_TEST_P(BatchNormalizationTest, BasicTraining) {
BatchNormTraining(operand, scale, offset,
/*epsilon=*/0.001, kFeatureIndex);
- auto expected = LiteralUtil::MakeTuple(
+ auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4<float>({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}},
- {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}})
- .get(),
- LiteralUtil::CreateR1<float>({4, 5}).get(),
- LiteralUtil::CreateR1<float>({5, 5}).get()});
+ {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}),
+ LiteralUtil::CreateR1<float>({4, 5}),
+ LiteralUtil::CreateR1<float>({5, 5})});
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}
XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) {
@@ -267,14 +266,13 @@ XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) {
BatchNormTraining(operand, scale, offset,
/*epsilon=*/0.001, kFeatureIndex);
- auto expected = LiteralUtil::MakeTuple(
+ auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4<float>({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}},
- {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}})
- .get(),
- LiteralUtil::CreateR1<float>({4, 5}).get(),
- LiteralUtil::CreateR1<float>({5, 5}).get()});
+ {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}),
+ LiteralUtil::CreateR1<float>({4, 5}),
+ LiteralUtil::CreateR1<float>({5, 5})});
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}
XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) {
@@ -298,13 +296,12 @@ XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) {
BatchNormTraining(h0, h1, h2,
/*epsilon=*/1, kFeatureIndex);
- auto expected = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f))
- .get(),
- LiteralUtil::CreateR1<float>(std::vector<float>(260, 1.0f)).get(),
- LiteralUtil::CreateR1<float>(std::vector<float>(260, 0.0f)).get()});
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f)),
+ LiteralUtil::CreateR1<float>(std::vector<float>(260, 1.0f)),
+ LiteralUtil::CreateR1<float>(std::vector<float>(260, 0.0f))});
- ComputeAndCompareTuple(&builder, *expected,
+ ComputeAndCompareTuple(&builder, expected,
{operand.get(), scale.get(), offset.get()},
ErrorSpec(0.1));
}
@@ -331,14 +328,13 @@ XLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) {
BatchNormTraining(h0, h1, h2,
/*epsilon=*/-100, kFeatureIndex);
- auto expected = LiteralUtil::MakeTuple(
+ auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR3FromArray3D<float>(
- {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}})
- .get(),
- LiteralUtil::CreateR1<float>(std::vector<float>(1, 15.0f)).get(),
- LiteralUtil::CreateR1<float>(std::vector<float>(1, 125.0f)).get()});
+ {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}),
+ LiteralUtil::CreateR1<float>(std::vector<float>(1, 15.0f)),
+ LiteralUtil::CreateR1<float>(std::vector<float>(1, 125.0f))});
- ComputeAndCompareTuple(&builder, *expected,
+ ComputeAndCompareTuple(&builder, expected,
{operand.get(), scale.get(), offset.get()},
ErrorSpec(0.1));
}
@@ -363,14 +359,13 @@ XLA_TEST_P(BatchNormalizationTest, BatchNormGradBasic) {
BatchNormGrad(operand, scale, mean, var, grad_output,
/*epsilon=*/0.0, kFeatureIndex);
- auto expected = LiteralUtil::MakeTuple(
+ auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4<float>({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}},
- {{{1.f}, {1.f}}, {{3.f}, {3.f}}}})
- .get(),
- LiteralUtil::CreateR1<float>({0, 0}).get(),
- LiteralUtil::CreateR1<float>({16, 20}).get()});
+ {{{1.f}, {1.f}}, {{3.f}, {3.f}}}}),
+ LiteralUtil::CreateR1<float>({0, 0}),
+ LiteralUtil::CreateR1<float>({16, 20})});
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}
struct BatchNormTestParam {
@@ -522,22 +517,22 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) {
auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);
auto input_activations =
- Parameter(&builder, 0, input_literal->shape(), "input");
+ Parameter(&builder, 0, input_literal.shape(), "input");
auto scale_activations =
- Parameter(&builder, 1, scale_literal->shape(), "offset");
+ Parameter(&builder, 1, scale_literal.shape(), "offset");
auto offset_activations =
- Parameter(&builder, 2, offset_literal->shape(), "scale");
+ Parameter(&builder, 2, offset_literal.shape(), "scale");
- auto expected = LiteralUtil::MakeTuple(
- {expected_normalized.get(), LiteralUtil::CreateR1<float>(mean).get(),
- LiteralUtil::CreateR1<float>(var).get()});
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {expected_normalized, LiteralUtil::CreateR1<float>(mean),
+ LiteralUtil::CreateR1<float>(var)});
std::unique_ptr<GlobalData> input_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> scale_data =
- client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
+ client_->TransferToServer(scale_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> offset_data =
- client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
+ client_->TransferToServer(offset_literal).ConsumeValueOrDie();
BatchNormTraining(input_activations, scale_activations, offset_activations,
epsilon, feature_index);
@@ -547,7 +542,7 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) {
// testcase.
execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
ComputeAndCompareTuple(
- &builder, *expected,
+ &builder, expected,
{input_data.get(), scale_data.get(), offset_data.get()},
ErrorSpec(0.01, 1));
}
@@ -622,27 +617,27 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) {
auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);
auto input_activations =
- Parameter(&builder, 0, input_literal->shape(), "input");
+ Parameter(&builder, 0, input_literal.shape(), "input");
auto scale_activations =
- Parameter(&builder, 1, scale_literal->shape(), "offset");
+ Parameter(&builder, 1, scale_literal.shape(), "offset");
auto offset_activations =
- Parameter(&builder, 2, offset_literal->shape(), "scale");
- auto mean_activations = Parameter(&builder, 3, mean_literal->shape(), "mean");
+ Parameter(&builder, 2, offset_literal.shape(), "scale");
+ auto mean_activations = Parameter(&builder, 3, mean_literal.shape(), "mean");
auto variance_activations =
- Parameter(&builder, 4, var_literal->shape(), "variance");
+ Parameter(&builder, 4, var_literal.shape(), "variance");
Array4D<float> expected = normalized;
std::unique_ptr<GlobalData> input_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> scale_data =
- client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
+ client_->TransferToServer(scale_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> offset_data =
- client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
+ client_->TransferToServer(offset_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> mean_data =
- client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
+ client_->TransferToServer(mean_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> variance_data =
- client_->TransferToServer(*var_literal).ConsumeValueOrDie();
+ client_->TransferToServer(var_literal).ConsumeValueOrDie();
BatchNormInference(input_activations, scale_activations, offset_activations,
mean_activations, variance_activations, epsilon,
@@ -811,40 +806,37 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) {
auto grad_output_literal =
LiteralUtil::CreateR4FromArray4D<float>(grad_output_array);
- auto input_parameter =
- Parameter(&builder, 0, input_literal->shape(), "input");
- auto scale_parameter =
- Parameter(&builder, 1, scale_literal->shape(), "scale");
- auto mean_parameter = Parameter(&builder, 2, mean_literal->shape(), "mean");
- auto var_parameter = Parameter(&builder, 3, var_literal->shape(), "variance");
+ auto input_parameter = Parameter(&builder, 0, input_literal.shape(), "input");
+ auto scale_parameter = Parameter(&builder, 1, scale_literal.shape(), "scale");
+ auto mean_parameter = Parameter(&builder, 2, mean_literal.shape(), "mean");
+ auto var_parameter = Parameter(&builder, 3, var_literal.shape(), "variance");
auto grad_output_parameter =
- Parameter(&builder, 4, grad_output_literal->shape(), "grad_output");
+ Parameter(&builder, 4, grad_output_literal.shape(), "grad_output");
std::unique_ptr<GlobalData> input_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> scale_data =
- client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
+ client_->TransferToServer(scale_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> mean_data =
- client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
+ client_->TransferToServer(mean_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> var_data =
- client_->TransferToServer(*var_literal).ConsumeValueOrDie();
+ client_->TransferToServer(var_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> grad_output_data =
- client_->TransferToServer(*grad_output_literal).ConsumeValueOrDie();
+ client_->TransferToServer(grad_output_literal).ConsumeValueOrDie();
BatchNormGrad(input_parameter, scale_parameter, mean_parameter, var_parameter,
grad_output_parameter, epsilon, feature_index);
- auto expected =
- LiteralUtil::MakeTuple({expected_grad_activation.get(),
- LiteralUtil::CreateR1<float>(grad_scale).get(),
- LiteralUtil::CreateR1<float>(grad_offset).get()});
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {expected_grad_activation, LiteralUtil::CreateR1<float>(grad_scale),
+ LiteralUtil::CreateR1<float>(grad_offset)});
// Run all HLO passes during this test. In particular, ClientLibraryTestBase
// disables constant folding, but we want it enabled for our zero-sized tensor
// testcase.
execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
- ComputeAndCompareTuple(&builder, *expected,
+ ComputeAndCompareTuple(&builder, expected,
{input_data.get(), scale_data.get(), mean_data.get(),
var_data.get(), grad_output_data.get()},
ErrorSpec(0.01, 1));
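
A second recurring change in this file: LiteralUtil::MakeTuple, which took raw Literal* elements obtained via `.get()`, is replaced by LiteralUtil::MakeTupleFromSlices, which accepts the element literals themselves. The sketch below mimics that API shape with hypothetical helpers (MakeTupleFromPointers, MakeTupleFromSlicesLike, FakeLiteral); it is not XLA's implementation, only an illustration of why the `.get()` calls disappear at the call sites.

```cpp
// Stand-in sketch (not XLA code): pointer-based vs value-based tuple helpers.
#include <iostream>
#include <memory>
#include <vector>

struct FakeLiteral {
  std::vector<float> values;
};

// Old-style helper: takes raw pointers, so every element needs ".get()".
FakeLiteral MakeTupleFromPointers(std::initializer_list<const FakeLiteral*> elements) {
  FakeLiteral tuple;
  for (const FakeLiteral* e : elements) {
    tuple.values.insert(tuple.values.end(), e->values.begin(), e->values.end());
  }
  return tuple;
}

// New-style helper: takes the elements by value, so call sites can build the
// element literals inline, as in the MakeTupleFromSlices calls above.
FakeLiteral MakeTupleFromSlicesLike(std::vector<FakeLiteral> elements) {
  FakeLiteral tuple;
  for (FakeLiteral& e : elements) {
    tuple.values.insert(tuple.values.end(), e.values.begin(), e.values.end());
  }
  return tuple;
}

int main() {
  // Before: keep unique_ptrs alive and pass .get() pointers.
  auto a = std::make_unique<FakeLiteral>(FakeLiteral{{1.f, 2.f}});
  auto b = std::make_unique<FakeLiteral>(FakeLiteral{{3.f}});
  FakeLiteral old_tuple = MakeTupleFromPointers({a.get(), b.get()});

  // After: hand the freshly created elements over by value.
  FakeLiteral new_tuple =
      MakeTupleFromSlicesLike({FakeLiteral{{1.f, 2.f}}, FakeLiteral{{3.f}}});

  std::cout << old_tuple.values.size() << " " << new_tuple.values.size() << "\n";
}
```
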
diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc
index 65589b0d6a..e9728e636f 100644
--- a/tensorflow/compiler/xla/tests/bfloat16_test.cc
+++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc
@@ -95,22 +95,19 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) {
BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex);
- auto expected = LiteralUtil::MakeTuple(
+ auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4<bfloat16>(
{{{{static_cast<bfloat16>(-1.6875f)},
{static_cast<bfloat16>(-2.04f)}},
{{static_cast<bfloat16>(0.105f)}, {static_cast<bfloat16>(0.66f)}}},
{{{static_cast<bfloat16>(1.89f)}, {static_cast<bfloat16>(3.35f)}},
- {{static_cast<bfloat16>(3.7f)}, {static_cast<bfloat16>(6.04f)}}}})
- .get(),
+ {{static_cast<bfloat16>(3.7f)}, {static_cast<bfloat16>(6.04f)}}}}),
LiteralUtil::CreateR1<bfloat16>(
- {static_cast<bfloat16>(4), static_cast<bfloat16>(5)})
- .get(),
+ {static_cast<bfloat16>(4), static_cast<bfloat16>(5)}),
LiteralUtil::CreateR1<bfloat16>(
- {static_cast<bfloat16>(5), static_cast<bfloat16>(5)})
- .get()});
+ {static_cast<bfloat16>(5), static_cast<bfloat16>(5)})});
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01, 0.02));
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01, 0.02));
}
XLA_TEST_F(Bfloat16Test, BatchNormGrad) {
@@ -139,21 +136,18 @@ XLA_TEST_F(Bfloat16Test, BatchNormGrad) {
BatchNormGrad(operand, scale, mean, var, grad_output,
/*epsilon=*/0.0, kFeatureIndex);
- auto expected = LiteralUtil::MakeTuple(
+ auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4<bfloat16>(
{{{{static_cast<bfloat16>(-3.f)}, {static_cast<bfloat16>(-3.f)}},
{{static_cast<bfloat16>(-1.f)}, {static_cast<bfloat16>(-1.f)}}},
{{{static_cast<bfloat16>(1.f)}, {static_cast<bfloat16>(1.f)}},
- {{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(3.f)}}}})
- .get(),
+ {{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(3.f)}}}}),
LiteralUtil::CreateR1<bfloat16>(
- {static_cast<bfloat16>(0), static_cast<bfloat16>(0)})
- .get(),
+ {static_cast<bfloat16>(0), static_cast<bfloat16>(0)}),
LiteralUtil::CreateR1<bfloat16>(
- {static_cast<bfloat16>(16), static_cast<bfloat16>(20)})
- .get()});
+ {static_cast<bfloat16>(16), static_cast<bfloat16>(20)})});
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01));
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01));
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
index fe4267c73b..dde19fb65d 100644
--- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
@@ -60,10 +60,10 @@ class BroadcastSimpleTest : public ClientLibraryTestBase {
float end, int seed) {
*r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
r3_array->FillRandom(start, end, seed);
- auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array)->Relayout(
+ auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array).Relayout(
LayoutUtil::MakeLayout(minor_to_major));
std::unique_ptr<GlobalData> r3_global_data =
- client_->TransferToServer(*r3_data).ConsumeValueOrDie();
+ client_->TransferToServer(r3_data).ConsumeValueOrDie();
return r3_global_data;
}
@@ -74,10 +74,10 @@ class BroadcastSimpleTest : public ClientLibraryTestBase {
float end, int seed) {
*r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
r2_array->FillRandom(start, end, seed);
- auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array)->Relayout(
+ auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array).Relayout(
LayoutUtil::MakeLayout(minor_to_major));
std::unique_ptr<GlobalData> r2_global_data =
- client_->TransferToServer(*r2_data).ConsumeValueOrDie();
+ client_->TransferToServer(r2_data).ConsumeValueOrDie();
return r2_global_data;
}
@@ -293,7 +293,7 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) {
XlaBuilder b(TestName());
Add(ConstantR2<float>(&b, {{1.0, 5.0}}),
- ConstantLiteral(&b, *LiteralUtil::CreateR3<float>(
+ ConstantLiteral(&b, LiteralUtil::CreateR3<float>(
{{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
/*broadcast_dimensions=*/{1, 2});
@@ -301,7 +301,7 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) {
LiteralUtil::CreateR3<float>({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}},
{{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
struct R3ImplicitBroadcastSpec {
@@ -370,8 +370,7 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) {
}
auto expected = LiteralUtil::CreateR3FromArray3D(expected_array);
ComputeAndCompareLiteral(
- &builder, *expected,
- {r3_implicit_global_data.get(), r3_global_data.get()},
+ &builder, expected, {r3_implicit_global_data.get(), r3_global_data.get()},
ErrorSpec(1e-7, 1e-7));
}
@@ -395,89 +394,89 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) {
auto expected =
LiteralUtil::CreateR3<float>({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}});
- ComputeAndCompareLiteral(&b, *expected, {r3.get(), r1.get()},
+ ComputeAndCompareLiteral(&b, expected, {r3.get(), r1.get()},
ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}}}));
+ auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1, 2}}}));
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1}, {2}}}));
+ auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1}, {2}}}));
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) {
XlaBuilder b(TestName());
auto r1 =
- ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}}));
+ ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}}));
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) {
XlaBuilder b(TestName());
auto r1 =
- ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}}, {{3, 4}}}));
+ ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1, 2}}, {{3, 4}}}));
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) {
XlaBuilder b(TestName());
auto r1 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1}, {2}}, {{3}, {4}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1}, {2}}, {{3}, {4}}}));
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1}}}));
+ auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1}}}));
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
struct R2ImplicitBroadcastSpec {
@@ -618,7 +617,7 @@ XLA_TEST_P(BroadcastR2ImplicitTest, Doit) {
auto expected = LiteralUtil::CreateR2FromArray2D(expected_array);
ComputeAndCompareLiteral(
- &builder, *expected,
+ &builder, expected,
{r2_implicit_global_data1.get(), r2_global_data.get(),
r2_implicit_global_data2.get()},
ErrorSpec(1e-6, 1e-6));
@@ -630,65 +629,63 @@ INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances,
XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}}));
- auto r2 =
- ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
+ auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1, 2}}));
+ auto r2 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
Add(r2, r1);
auto expected = LiteralUtil::CreateR2<float>({{2, 4}, {4, 6}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1}, {2}}));
- auto r2 =
- ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
+ auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1}, {2}}));
+ auto r2 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
Add(r2, r1);
auto expected = LiteralUtil::CreateR2<float>({{2, 3}, {5, 6}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) {
XlaBuilder b(TestName());
auto r1 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1, {0});
auto expected = LiteralUtil::CreateR3<float>(
{{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) {
XlaBuilder b(TestName());
auto r1 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r1, r3, {1});
auto expected = LiteralUtil::CreateR3<float>(
{{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) {
XlaBuilder b(TestName());
auto r1 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r1, r3, {2});
auto expected = LiteralUtil::CreateR3<float>(
{{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
@@ -697,7 +694,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
auto r1_1 = ConstantR1<float>(&b, {100, 200});
auto r1_2 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
for (int i = 0; i < 3; ++i) {
r3 = Add(r1_0, r3, {0});
r3 = Add(r3, r1_1, {1});
@@ -709,7 +706,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
{{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}},
{{-6 * 2110 - 10, -6 * 2120 - 12}, {-6 * 2210 - 14, -6 * 2220 - 16}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) {
@@ -730,7 +727,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) {
{{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}},
{{-3 * 2110 - 3, -3 * 2120 - 3}, {-3 * 2210 - 3, -3 * 2220 - 3}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) {
@@ -739,7 +736,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) {
XlaBuilder b(TestName());
Add(ConstantR2<float>(&b, {{1.0, 5.0}, {1.0, 5.0}}),
- ConstantLiteral(&b, *LiteralUtil::CreateR3<float>(
+ ConstantLiteral(&b, LiteralUtil::CreateR3<float>(
{{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
/*broadcast_dimensions=*/{1, 2});
diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc
index 74d4d2eb10..9966e4606e 100644
--- a/tensorflow/compiler/xla/tests/broadcast_test.cc
+++ b/tensorflow/compiler/xla/tests/broadcast_test.cc
@@ -46,8 +46,8 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR0<float>(42.0),
- *result, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(LiteralUtil::CreateR0<float>(42.0), result,
+ error_spec_));
}
XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
@@ -63,7 +63,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>({{42.0, 42.0}, {42.0, 42.0}}), *result,
+ LiteralUtil::CreateR2<float>({{42.0, 42.0}, {42.0, 42.0}}), result,
error_spec_));
}
@@ -86,12 +86,12 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}),
- LiteralSlice(*result, {0}), error_spec_));
+ LiteralUtil::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}),
+ LiteralSlice(result, {0}), error_spec_));
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}),
- LiteralSlice(*result, {1}), error_spec_));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}),
+ LiteralSlice(result, {1}), error_spec_));
}
XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
@@ -107,7 +107,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}), *result,
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}), result,
error_spec_));
}
@@ -126,7 +126,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}}), *result,
+ LiteralUtil::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}}), result,
error_spec_));
}
@@ -143,9 +143,9 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
- {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}),
- *result, error_spec_));
+ LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
+ {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}),
+ result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
@@ -166,9 +166,8 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
Array2D<float> pz({{1, 2}, {1, 2}});
expected.FillWithPZ(pz);
- EXPECT_TRUE(
- LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
- *result, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
@@ -197,9 +196,8 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
}
expected.FillWithYX(yx);
- EXPECT_TRUE(
- LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
- *result, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
@@ -220,8 +218,8 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(r4_array),
- *result, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(LiteralUtil::CreateR4FromArray4D(r4_array),
+ result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
@@ -240,9 +238,8 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
Array4D<float> expected(64, 64, 3, 3);
expected.Fill(1.0f);
- EXPECT_TRUE(
- LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
- *result, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
@@ -263,9 +260,8 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
Array4D<float> expected(3, 3, 2, 2);
expected.FillWithYX(to_broadcast);
- EXPECT_TRUE(
- LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
- *result, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
@@ -295,9 +291,8 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- EXPECT_TRUE(
- LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
- *result, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc
index b1d18210ea..8b31e53707 100644
--- a/tensorflow/compiler/xla/tests/call_test.cc
+++ b/tensorflow/compiler/xla/tests/call_test.cc
@@ -77,8 +77,7 @@ class CallOpTest : public ClientLibraryTestBase {
XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) {
XlaBuilder builder(TestName());
XlaComputation callee = CreateR0F32IdentityComputation();
- auto constant =
- ConstantLiteral(&builder, *LiteralUtil::CreateR0<float>(42.0));
+ auto constant = ConstantLiteral(&builder, LiteralUtil::CreateR0<float>(42.0));
Call(&builder, callee, {constant});
ComputeAndCompareR0<float>(&builder, 42.0, {}, ErrorSpec(0.01f));
@@ -87,8 +86,8 @@ XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) {
XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) {
XlaBuilder builder(TestName());
XlaComputation callee = CreateR1S0F32AdditionComputation();
- auto x = ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({}));
- auto y = ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({}));
+ auto x = ConstantLiteral(&builder, LiteralUtil::CreateR1<float>({}));
+ auto y = ConstantLiteral(&builder, LiteralUtil::CreateR1<float>({}));
Call(&builder, callee, {x, y});
ComputeAndCompareR1<float>(&builder, {}, {}, ErrorSpec(0.01f));
@@ -98,9 +97,9 @@ XLA_TEST_F(CallOpTest, CallR1S2F32AddArray) {
XlaBuilder builder(TestName());
XlaComputation callee = CreateR1S2F32AdditionComputation();
auto x =
- ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({1.0f, 2.0f}));
+ ConstantLiteral(&builder, LiteralUtil::CreateR1<float>({1.0f, 2.0f}));
auto y =
- ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({2.0f, 3.0f}));
+ ConstantLiteral(&builder, LiteralUtil::CreateR1<float>({2.0f, 3.0f}));
Call(&builder, callee, {x, y});
ComputeAndCompareR1<float>(&builder, {3.0f, 5.0f}, {}, ErrorSpec(0.01f));
@@ -133,7 +132,7 @@ XLA_TEST_F(CallOpTest, CallTreeTwoDeepBranchFactorThree) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> start,
- client_->TransferToServer(*LiteralUtil::CreateR0<float>(1.0f)));
+ client_->TransferToServer(LiteralUtil::CreateR0<float>(1.0f)));
ComputeAndCompareR0<float>(&builder3, 10.0f, {start.get()}, ErrorSpec(0.0f));
}
@@ -141,10 +140,10 @@ XLA_TEST_F(CallOpTest, CallR0F32Tuple) {
XlaBuilder builder(TestName());
XlaComputation callee = CreateR0F32TupleComputation();
auto elem = LiteralUtil::CreateR0<float>(42.0);
- auto tuple = LiteralUtil::MakeTuple({elem.get()});
- Call(&builder, callee, {ConstantLiteral(&builder, *elem)});
+ auto tuple = LiteralUtil::MakeTuple({&elem});
+ Call(&builder, callee, {ConstantLiteral(&builder, elem)});
- ComputeAndCompareTuple(&builder, *tuple, {}, ErrorSpec(0.01f));
+ ComputeAndCompareTuple(&builder, tuple, {}, ErrorSpec(0.01f));
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
index a4eb57fc7b..2f1510ff69 100644
--- a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
+++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
@@ -38,14 +38,14 @@ TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) {
XlaBuilder builder("add_two_params");
auto param_literal = LiteralUtil::CreateR1<float>({1.1f, 2.2f});
- auto p0 = Parameter(&builder, 0, param_literal->shape(), "param0");
- auto p1 = Parameter(&builder, 1, param_literal->shape(), "param1");
+ auto p0 = Parameter(&builder, 0, param_literal.shape(), "param0");
+ auto p1 = Parameter(&builder, 1, param_literal.shape(), "param1");
Add(p0, p1);
auto param0_data =
- client_->TransferToServer(*param_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param_literal).ConsumeValueOrDie();
auto param1_data =
- client_->TransferToServer(*param_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param_literal).ConsumeValueOrDie();
auto computation_status = builder.Build();
ASSERT_IS_OK(computation_status.status());
@@ -86,12 +86,12 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) {
auto computation = computation_status.ConsumeValueOrDie();
auto f32_literal = LiteralUtil::CreateR0<float>(1.1f);
- auto f32_data = client_->TransferToServer(*f32_literal).ConsumeValueOrDie();
+ auto f32_data = client_->TransferToServer(f32_literal).ConsumeValueOrDie();
auto f32_4_literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f});
auto f32_4_data =
- client_->TransferToServer(*f32_4_literal).ConsumeValueOrDie();
+ client_->TransferToServer(f32_4_literal).ConsumeValueOrDie();
auto u8_4_literal = LiteralUtil::CreateR1U8("hola");
- auto u8_4_data = client_->TransferToServer(*u8_4_literal).ConsumeValueOrDie();
+ auto u8_4_data = client_->TransferToServer(u8_4_literal).ConsumeValueOrDie();
// Match
auto status = client_->Execute(
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc
index 8a236db0ff..fbdf0fcb65 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.cc
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc
@@ -101,7 +101,7 @@ StatusOr<std::unique_ptr<GlobalData>> ClientLibraryTestBase::Execute(
return client_->Execute(computation, arguments, &execution_options_);
}
-StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
+StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransfer(
const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout) {
ExecutionOptions execution_options = execution_options_;
@@ -113,7 +113,7 @@ StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
&execution_options);
}
-StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
+StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransfer(
XlaBuilder* builder, absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout) {
// Build the computation, as a convenience.
@@ -121,8 +121,7 @@ StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
return ExecuteAndTransfer(computation, arguments, shape_with_output_layout);
}
-StatusOr<std::unique_ptr<Literal>>
-ClientLibraryTestBase::ExecuteAndTransferReference(
+StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransferReference(
const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout) {
ExecutionOptions execution_options = execution_options_;
@@ -148,15 +147,15 @@ string ClientLibraryTestBase::ExecuteToString(
if (!result.ok()) {
return result.status().ToString();
} else {
- return result.ValueOrDie()->ToString();
+ return result.ValueOrDie().ToString();
}
}
void ClientLibraryTestBase::ComputeAndCompareR1(
XlaBuilder* builder, const tensorflow::core::Bitmap& expected,
absl::Span<GlobalData* const> arguments) {
- std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR1(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ Literal expected_literal = LiteralUtil::CreateR1(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}
@@ -182,7 +181,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
const string& error_message)>& verify_output) {
// Try with no layout requirement.
TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments));
- verify_output(*actual, "");
+ verify_output(actual, "");
// Try with all output layouts.
std::vector<int64> minor_to_major(ShapeUtil::Rank(expected.shape()));
@@ -193,7 +192,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
AsInt64Slice(expected.shape().dimensions()), minor_to_major);
TF_ASSIGN_OR_RETURN(auto actual,
ExecuteAndTransfer(computation, arguments, &layout));
- verify_output(*actual,
+ verify_output(actual,
absl::StrCat("Test with output layout: ",
ShapeUtil::HumanStringWithLayout(layout)));
} while (std::next_permutation(minor_to_major.begin(), minor_to_major.end()));
@@ -218,9 +217,9 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
TF_ASSIGN_OR_RETURN(auto literal,
client_->Transfer(*arguments[index], nullptr));
// Skip tuples because they don't have a rank.
- if (ShapeUtil::IsTuple(literal->shape())) {
+ if (ShapeUtil::IsTuple(literal.shape())) {
layout_strings.push_back(
- ShapeUtil::HumanStringWithLayout(literal->shape()));
+ ShapeUtil::HumanStringWithLayout(literal.shape()));
arguments_with_layout.push_back(arguments[index]);
TF_RETURN_IF_ERROR(choose(index + 1));
arguments_with_layout.pop_back();
@@ -228,15 +227,15 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
return Status::OK();
}
- std::vector<int64> minor_to_major(ShapeUtil::Rank(literal->shape()));
+ std::vector<int64> minor_to_major(ShapeUtil::Rank(literal.shape()));
std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
do {
auto literal_relayout =
- literal->Relayout(LayoutUtil::MakeLayout(minor_to_major));
+ literal.Relayout(LayoutUtil::MakeLayout(minor_to_major));
layout_strings.push_back(
- ShapeUtil::HumanStringWithLayout(literal_relayout->shape()));
+ ShapeUtil::HumanStringWithLayout(literal_relayout.shape()));
TF_ASSIGN_OR_RETURN(auto data,
- client_->TransferToServer(*literal_relayout));
+ client_->TransferToServer(literal_relayout));
arguments_with_layout.push_back(data.get());
TF_RETURN_IF_ERROR(choose(index + 1));
arguments_with_layout.pop_back();
@@ -256,7 +255,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
for (const auto& str : layout_strings) {
absl::StrAppend(&error_message, str, " ");
}
- verify_output(*actual, error_message);
+ verify_output(actual, error_message);
return Status::OK();
};
@@ -290,11 +289,11 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
// We allow using a float expected literal for a bfloat16 output. In this
// case, we need to convert the expected literal to bfloat16.
const Literal* expected_ptr = &expected;
- std::unique_ptr<Literal> converted_expected;
+ Literal converted_expected;
Shape layout_shape;
if (use_bfloat16_) {
converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
- expected_ptr = converted_expected.get();
+ expected_ptr = &converted_expected;
if (shape_with_layout != nullptr) {
layout_shape = *shape_with_layout;
ShapeUtil::ForEachMutableSubshape(
@@ -319,7 +318,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
}
TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
shape_with_layout));
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, *actual));
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, actual));
return Status::OK();
}
@@ -346,11 +345,11 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
// We allow using a float expected literal for a bfloat16 output. In this
// case, we need to convert the expected literal to bfloat16.
const Literal* expected_ptr = &expected;
- std::unique_ptr<Literal> converted_expected;
+ Literal converted_expected;
Shape layout_shape;
if (use_bfloat16_) {
converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
- expected_ptr = converted_expected.get();
+ expected_ptr = &converted_expected;
if (shape_with_layout != nullptr) {
layout_shape = *shape_with_layout;
ShapeUtil::ForEachMutableSubshape(
@@ -376,7 +375,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
}
TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
shape_with_layout));
- EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, *actual, error));
+ EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, error));
return Status::OK();
}
@@ -391,12 +390,12 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8(
auto actual = actual_status.ConsumeValueOrDie();
// Turn the expected value into a literal.
- std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR1U8(expected);
+ Literal expected_literal = LiteralUtil::CreateR1U8(expected);
- VLOG(1) << "expected: " << expected_literal->ToString();
- VLOG(1) << "actual: " << actual->ToString();
+ VLOG(1) << "expected: " << expected_literal.ToString();
+ VLOG(1) << "actual: " << actual.ToString();
- EXPECT_EQ(expected, actual->GetR1U8AsString());
+ EXPECT_EQ(expected, actual.GetR1U8AsString());
}
void ClientLibraryTestBase::ComputeAndCompareTuple(
@@ -408,7 +407,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple(
return;
}
auto actual = actual_status.ConsumeValueOrDie();
- EXPECT_TRUE(LiteralTestUtil::Equal(expected, *actual));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
void ClientLibraryTestBase::ComputeAndCompareTuple(
@@ -420,7 +419,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple(
return;
}
auto actual = actual_status.ConsumeValueOrDie();
- EXPECT_TRUE(LiteralTestUtil::Near(expected, *actual, error));
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error));
}
void ClientLibraryTestBase::ComputeAndCompare(
@@ -430,9 +429,9 @@ void ClientLibraryTestBase::ComputeAndCompare(
if (!status_or_data.ok()) {
return;
}
- std::unique_ptr<Literal> reference, result;
+ Literal reference, result;
std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
- EXPECT_TRUE(LiteralTestUtil::Equal(*reference, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(reference, result));
}
void ClientLibraryTestBase::ComputeAndCompare(
@@ -442,12 +441,12 @@ void ClientLibraryTestBase::ComputeAndCompare(
if (!status_or_data.ok()) {
return;
}
- std::unique_ptr<Literal> reference, result;
+ Literal reference, result;
std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
- EXPECT_TRUE(LiteralTestUtil::Near(*reference, *result, error));
+ EXPECT_TRUE(LiteralTestUtil::Near(reference, result, error));
}
-StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
+StatusOr<std::pair<Literal, Literal>>
ClientLibraryTestBase::ComputeValueAndReference(
XlaBuilder* builder, absl::Span<const Literal> arguments) {
// Transfer the arguments to the executor service. We put the unique_ptr's
@@ -569,8 +568,8 @@ XlaOp ClientLibraryTestBase::AddParam(const Literal& argument,
XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal,
XlaBuilder* builder) {
return ConstantLiteral(builder, use_bfloat16_
- ? *LiteralUtil::ConvertF32ToBF16(literal)
- : literal);
+ ? LiteralUtil::ConvertF32ToBF16(literal)
+ : LiteralSlice(literal));
}
std::unique_ptr<GlobalData>
@@ -600,7 +599,7 @@ Shape ClientLibraryTestBase::MaybeConvertShapeToBfloat16(const Shape& shape) {
Literal ClientLibraryTestBase::MaybeConvertLiteralToBfloat16(
const Literal& literal) {
if (use_bfloat16_) {
- return std::move(*LiteralUtil::ConvertF32ToBF16(literal));
+ return LiteralUtil::ConvertF32ToBF16(literal);
}
return literal.Clone();
}
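
The client_library_test_base.cc changes above also shift the StatusOr payload from std::unique_ptr<Literal> to Literal, which is why call sites such as `result.ValueOrDie()->ToString()` become `result.ValueOrDie().ToString()`. The sketch below shows the same call-site effect using absl::StatusOr and a hypothetical FakeLiteral as stand-ins; the in-tree StatusOr used by these tests has additional accessors (ValueOrDie, ConsumeValueOrDie), but the indirection being removed is the same.

```cpp
// Stand-in sketch: StatusOr payload as pointer vs as value.
#include <iostream>
#include <memory>
#include <string>

#include "absl/status/statusor.h"

struct FakeLiteral {
  std::string ToString() const { return "f32[2] {1, 2}"; }
};

// Before: the payload is a pointer, so success paths dereference twice.
absl::StatusOr<std::unique_ptr<FakeLiteral>> ExecuteOld() {
  return std::make_unique<FakeLiteral>();
}

// After: the payload is the literal itself; one less level of indirection.
absl::StatusOr<FakeLiteral> ExecuteNew() { return FakeLiteral{}; }

int main() {
  auto old_result = ExecuteOld();
  if (old_result.ok()) {
    std::cout << old_result.value()->ToString() << "\n";  // value() then ->
  }

  auto new_result = ExecuteNew();
  if (new_result.ok()) {
    std::cout << new_result.value().ToString() << "\n";   // value() then .
  }
}
```
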
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index 22dfdfb0e4..9d32f4f517 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -95,11 +95,11 @@ class ClientLibraryTestBase : public ::testing::Test {
StatusOr<std::unique_ptr<GlobalData>> Execute(
XlaBuilder* builder, absl::Span<GlobalData* const> arguments);
- StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
+ StatusOr<Literal> ExecuteAndTransfer(
XlaBuilder* builder, absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout = nullptr);
- StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
+ StatusOr<Literal> ExecuteAndTransfer(
const XlaComputation& computation,
absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout = nullptr);
@@ -107,7 +107,7 @@ class ClientLibraryTestBase : public ::testing::Test {
// This executes the computation via the reference client (which connects to
// an interpreter backend). The result is used as the expected value of the
// computation.
- StatusOr<std::unique_ptr<Literal>> ExecuteAndTransferReference(
+ StatusOr<Literal> ExecuteAndTransferReference(
const XlaComputation& computation,
absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout = nullptr);
@@ -282,7 +282,7 @@ class ClientLibraryTestBase : public ::testing::Test {
template <class T>
XlaOp AddParam(const Array<T>& argument, XlaBuilder* builder) {
- return AddParam(*LiteralUtil::CreateFromArray(argument), builder);
+ return AddParam(LiteralUtil::CreateFromArray(argument), builder);
}
// Creates a constant instruction with the given literal. When the
@@ -297,14 +297,14 @@ class ClientLibraryTestBase : public ::testing::Test {
template <typename NativeT>
XlaOp CreateConstantFromArray(const Array<NativeT>& array,
XlaBuilder* builder) {
- return CreateConstantFromLiteral(*LiteralUtil::CreateFromArray(array),
+ return CreateConstantFromLiteral(LiteralUtil::CreateFromArray(array),
builder);
}
// Same as CreateConstantFromArray, but for scalars.
template <typename NativeT>
XlaOp CreateConstantFromScalar(NativeT value, XlaBuilder* builder) {
- return CreateConstantFromLiteral(*LiteralUtil::CreateR0<NativeT>(value),
+ return CreateConstantFromLiteral(LiteralUtil::CreateR0<NativeT>(value),
builder);
}
@@ -375,9 +375,8 @@ class ClientLibraryTestBase : public ::testing::Test {
// Executes the computation and calculates the expected reference value using
// the reference client. Returns two literals in the order of (expected,
// actual).
- StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
- ComputeValueAndReference(XlaBuilder* builder,
- absl::Span<const Literal> arguments);
+ StatusOr<std::pair<Literal, Literal>> ComputeValueAndReference(
+ XlaBuilder* builder, absl::Span<const Literal> arguments);
Client* client_;
Client* ref_client_; // To compute reference result.
@@ -412,9 +411,8 @@ template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR0(
XlaBuilder* builder, NativeT expected,
absl::Span<GlobalData* const> arguments) {
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateR0<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ Literal expected_literal = LiteralUtil::CreateR0<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}
@@ -428,9 +426,8 @@ void ClientLibraryTestBase::ComputeAndCompareR0(
std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateR0<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ Literal expected_literal = LiteralUtil::CreateR0<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments, error);
}
@@ -438,9 +435,8 @@ template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR1(
XlaBuilder* builder, absl::Span<const NativeT> expected,
absl::Span<GlobalData* const> arguments) {
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateR1<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ Literal expected_literal = LiteralUtil::CreateR1<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}
@@ -454,9 +450,8 @@ void ClientLibraryTestBase::ComputeAndCompareR1(
std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateR1<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ Literal expected_literal = LiteralUtil::CreateR1<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments, error);
}
@@ -464,9 +459,9 @@ template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR2(
XlaBuilder* builder, const Array2D<NativeT>& expected,
absl::Span<GlobalData* const> arguments) {
- std::unique_ptr<Literal> expected_literal =
+ Literal expected_literal =
LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}
@@ -480,9 +475,9 @@ void ClientLibraryTestBase::ComputeAndCompareR2(
std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
- std::unique_ptr<Literal> expected_literal =
+ Literal expected_literal =
LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments, error);
}
@@ -490,9 +485,9 @@ template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR3(
XlaBuilder* builder, const Array3D<NativeT>& expected,
absl::Span<GlobalData* const> arguments) {
- std::unique_ptr<Literal> expected_literal =
+ Literal expected_literal =
LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}
@@ -506,9 +501,9 @@ void ClientLibraryTestBase::ComputeAndCompareR3(
std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
- std::unique_ptr<Literal> expected_literal =
+ Literal expected_literal =
LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments, error);
}
@@ -516,9 +511,9 @@ template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR4(
XlaBuilder* builder, const Array4D<NativeT>& expected,
absl::Span<GlobalData* const> arguments) {
- std::unique_ptr<Literal> expected_literal =
+ Literal expected_literal =
LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}
@@ -532,9 +527,9 @@ void ClientLibraryTestBase::ComputeAndCompareR4(
std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
- std::unique_ptr<Literal> expected_literal =
+ Literal expected_literal =
LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments, error);
}
@@ -542,13 +537,13 @@ template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter(
NativeT value, int64 parameter_number, const string& name,
XlaBuilder* builder, XlaOp* data_handle) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR0(value);
- if (use_bfloat16_ && literal->shape().element_type() == F32) {
- literal = LiteralUtil::ConvertF32ToBF16(*literal);
+ Literal literal = LiteralUtil::CreateR0(value);
+ if (use_bfloat16_ && literal.shape().element_type() == F32) {
+ literal = LiteralUtil::ConvertF32ToBF16(literal);
}
std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*literal).ConsumeValueOrDie();
- *data_handle = Parameter(builder, parameter_number, literal->shape(), name);
+ client_->TransferToServer(literal).ConsumeValueOrDie();
+ *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
}
@@ -556,13 +551,13 @@ template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR1Parameter(
absl::Span<const NativeT> values, int64 parameter_number,
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1(values);
- if (use_bfloat16_ && literal->shape().element_type() == F32) {
- literal = LiteralUtil::ConvertF32ToBF16(*literal);
+ Literal literal = LiteralUtil::CreateR1(values);
+ if (use_bfloat16_ && literal.shape().element_type() == F32) {
+ literal = LiteralUtil::ConvertF32ToBF16(literal);
}
std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*literal).ConsumeValueOrDie();
- *data_handle = Parameter(builder, parameter_number, literal->shape(), name);
+ client_->TransferToServer(literal).ConsumeValueOrDie();
+ *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
}
@@ -570,13 +565,13 @@ template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR2Parameter(
const Array2D<NativeT>& array_2d, int64 parameter_number,
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR2FromArray2D(array_2d);
- if (use_bfloat16_ && literal->shape().element_type() == F32) {
- literal = LiteralUtil::ConvertF32ToBF16(*literal);
+ Literal literal = LiteralUtil::CreateR2FromArray2D(array_2d);
+ if (use_bfloat16_ && literal.shape().element_type() == F32) {
+ literal = LiteralUtil::ConvertF32ToBF16(literal);
}
std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*literal).ConsumeValueOrDie();
- *data_handle = Parameter(builder, parameter_number, literal->shape(), name);
+ client_->TransferToServer(literal).ConsumeValueOrDie();
+ *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
}
@@ -584,13 +579,13 @@ template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR3Parameter(
const Array3D<NativeT>& array_3d, int64 parameter_number,
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR3FromArray3D(array_3d);
- if (use_bfloat16_ && literal->shape().element_type() == F32) {
- literal = LiteralUtil::ConvertF32ToBF16(*literal);
+ Literal literal = LiteralUtil::CreateR3FromArray3D(array_3d);
+ if (use_bfloat16_ && literal.shape().element_type() == F32) {
+ literal = LiteralUtil::ConvertF32ToBF16(literal);
}
std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*literal).ConsumeValueOrDie();
- *data_handle = Parameter(builder, parameter_number, literal->shape(), name);
+ client_->TransferToServer(literal).ConsumeValueOrDie();
+ *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
}
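Taken together, the client_library_test_base hunks converge on a single idiom: Literal is now a value type, so factory results are held directly instead of through std::unique_ptr. A minimal sketch of the post-change parameter-creation pattern, assuming the ClientLibraryTestBase fixture above (client_, use_bfloat16_) and an XlaBuilder* named b in scope; the concrete value and names are illustrative, not part of the patch:

    // CreateR0 returns a Literal by value; no '*' dereference anywhere.
    Literal literal = LiteralUtil::CreateR0<float>(42.0f);
    if (use_bfloat16_ && literal.shape().element_type() == F32) {
      literal = LiteralUtil::ConvertF32ToBF16(literal);  // value in, value out
    }
    // TransferToServer takes the literal directly and still yields
    // StatusOr<std::unique_ptr<GlobalData>>.
    std::unique_ptr<GlobalData> data =
        client_->TransferToServer(literal).ConsumeValueOrDie();
    XlaOp p0 = Parameter(b, /*parameter_number=*/0, literal.shape(), "p0");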
diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc
index c898dacf48..6f2ca84bb6 100644
--- a/tensorflow/compiler/xla/tests/client_test.cc
+++ b/tensorflow/compiler/xla/tests/client_test.cc
@@ -55,16 +55,15 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) {
std::unique_ptr<GlobalData> data,
client_->Execute(computation, {}, &execution_options));
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateR2WithLayout<int32>(
- {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout));
+ Literal expected_literal = LiteralUtil::CreateR2WithLayout<int32>(
+ {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout));
TF_ASSERT_OK_AND_ASSIGN(
- auto computed, client_->Transfer(*data, &expected_literal->shape()));
+ auto computed, client_->Transfer(*data, &expected_literal.shape()));
ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts(
- expected_literal->shape(), computed->shape()));
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
+ expected_literal.shape(), computed.shape()));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed));
}
}
}
@@ -91,19 +90,19 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) {
auto result,
client_->ExecuteAndTransfer(computation, {}, &execution_options));
LiteralTestUtil::ExpectR2Equal<int32>({{1, 2}, {3, 4}},
- LiteralSlice(*result, {0}));
+ LiteralSlice(result, {0}));
LiteralTestUtil::ExpectR2Equal<int32>({{10, 20}, {30, 40}},
- LiteralSlice(*result, {1}));
+ LiteralSlice(result, {1}));
- EXPECT_TRUE(ShapeUtil::IsTuple(result->shape()));
- EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape()));
+ EXPECT_TRUE(ShapeUtil::IsTuple(result.shape()));
+ EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.shape()));
EXPECT_TRUE(ShapeUtil::Equal(
- ShapeUtil::GetTupleElementShape(result->shape(), 0),
+ ShapeUtil::GetTupleElementShape(result.shape(), 0),
ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
/*minor_to_major=*/{0, 1})));
EXPECT_TRUE(ShapeUtil::Equal(
- ShapeUtil::GetTupleElementShape(result->shape(), 1),
+ ShapeUtil::GetTupleElementShape(result.shape(), 1),
ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
/*minor_to_major=*/{1, 0})));
}
@@ -114,7 +113,7 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> const_arg,
client_->TransferToServer(
- *LiteralUtil::CreateR2<int32>({{5, 6}, {7, 8}})));
+ LiteralUtil::CreateR2<int32>({{5, 6}, {7, 8}})));
XlaBuilder b(TestName() + ".add");
Add(Parameter(&b, 0, shape, "param_0"),
@@ -140,9 +139,9 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) {
TF_ASSERT_OK_AND_ASSIGN(
auto result_literal,
- client_->Transfer(*results[0], &expected_result->shape()));
+ client_->Transfer(*results[0], &expected_result.shape()));
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_result, *result_literal));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected_result, result_literal));
}
} // namespace
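The client_test changes show the result side of the same migration: ExecuteAndTransfer and Client::Transfer now return a Literal value inside StatusOr, so results are used directly rather than through *result. A rough sketch, reusing the computation and execution_options set up earlier in that test:

    TF_ASSERT_OK_AND_ASSIGN(
        auto result,
        client_->ExecuteAndTransfer(computation, /*arguments=*/{},
                                    &execution_options));
    // 'result' is a Literal, not std::unique_ptr<Literal>: members are reached
    // with '.', and tuple elements are viewed via LiteralSlice(result, {i}).
    EXPECT_TRUE(ShapeUtil::IsTuple(result.shape()));
    LiteralTestUtil::ExpectR2Equal<int32>({{1, 2}, {3, 4}},
                                          LiteralSlice(result, {0}));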
diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc
index 03d5696499..6ef7ca035f 100644
--- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc
+++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc
@@ -42,14 +42,14 @@ class CompilationCacheTest : public ClientLibraryTestBase {
absl::Span<GlobalData* const> arguments,
float expected_result, bool expect_cache_hit) {
ExecutionProfile execution_profile;
- std::unique_ptr<Literal> result =
+ Literal result =
client_
->ExecuteAndTransfer(computation, arguments,
/*execution_options=*/&execution_options_,
&execution_profile)
.ConsumeValueOrDie();
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR0<float>(expected_result), *result, error_spec_));
+ LiteralUtil::CreateR0<float>(expected_result), result, error_spec_));
EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
}
@@ -63,10 +63,9 @@ class CompilationCacheTest : public ClientLibraryTestBase {
->Execute(computation, arguments,
&execution_options_, &execution_profile)
.ConsumeValueOrDie();
- std::unique_ptr<Literal> result =
- client_->Transfer(*data_handle).ConsumeValueOrDie();
+ Literal result = client_->Transfer(*data_handle).ConsumeValueOrDie();
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>(expected_result), *result, error_spec_));
+ LiteralUtil::CreateR2<float>(expected_result), result, error_spec_));
EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
}
@@ -88,13 +87,13 @@ XLA_TEST_F(CompilationCacheTest, DISABLED_ComputationCalledMultipleTimes) {
XLA_TEST_F(CompilationCacheTest,
DISABLED_ComputationCalledWithDifferentParameters) {
std::unique_ptr<GlobalData> data_42 =
- client_->TransferToServer(*LiteralUtil::CreateR0<float>(42.0f))
+ client_->TransferToServer(LiteralUtil::CreateR0<float>(42.0f))
.ConsumeValueOrDie();
std::unique_ptr<GlobalData> data_123 =
- client_->TransferToServer(*LiteralUtil::CreateR0<float>(123.0f))
+ client_->TransferToServer(LiteralUtil::CreateR0<float>(123.0f))
.ConsumeValueOrDie();
std::unique_ptr<GlobalData> data_456 =
- client_->TransferToServer(*LiteralUtil::CreateR0<float>(456.0f))
+ client_->TransferToServer(LiteralUtil::CreateR0<float>(456.0f))
.ConsumeValueOrDie();
XlaBuilder builder(TestName());
@@ -145,12 +144,12 @@ XLA_TEST_F(CompilationCacheTest, DISABLED_DifferentParameterLayouts) {
auto rowmaj_array = LiteralUtil::CreateR2WithLayout(
{{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({1, 0}));
auto rowmaj_handle =
- client_->TransferToServer(*rowmaj_array).ConsumeValueOrDie();
+ client_->TransferToServer(rowmaj_array).ConsumeValueOrDie();
auto colmaj_array = LiteralUtil::CreateR2WithLayout(
{{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}));
auto colmaj_handle =
- client_->TransferToServer(*colmaj_array).ConsumeValueOrDie();
+ client_->TransferToServer(colmaj_array).ConsumeValueOrDie();
XlaBuilder builder(TestName());
Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0");
diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc
index 8226b6de3f..3b0414a604 100644
--- a/tensorflow/compiler/xla/tests/compute_constant_test.cc
+++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc
@@ -69,9 +69,9 @@ class ComputeConstantTest : public ::testing::Test {
LOG(FATAL) << "invalid client_type value";
}
- StatusOr<std::unique_ptr<Literal>> ComputeConstantLiteral(
- Client* client, const XlaOp& operand, XlaBuilder* builder,
- Layout* output_layout = nullptr) {
+ StatusOr<Literal> ComputeConstantLiteral(Client* client, const XlaOp& operand,
+ XlaBuilder* builder,
+ Layout* output_layout = nullptr) {
TF_ASSIGN_OR_RETURN(auto subgraph, builder->BuildConstantSubGraph(operand));
TF_ASSIGN_OR_RETURN(auto computed,
client->ComputeConstant(subgraph, output_layout));
@@ -83,7 +83,7 @@ class ComputeConstantTest : public ::testing::Test {
XlaBuilder* builder) {
TF_ASSIGN_OR_RETURN(auto literal, ComputeConstantLiteral(client, operand,
builder, nullptr));
- return literal->Get<Scalar>({});
+ return literal.Get<Scalar>({});
}
bool IsConstant(const XlaOp& operand, XlaBuilder* builder) {
@@ -206,9 +206,8 @@ TEST_F(ComputeConstantTest, NonScalarAdd) {
TF_ASSERT_OK_AND_ASSIGN(auto computed,
ComputeConstantLiteral(client, computation, &b));
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateR1<int32>({4, 6});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
+ Literal expected_literal = LiteralUtil::CreateR1<int32>({4, 6});
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed));
}
}
@@ -221,8 +220,8 @@ TEST_F(ComputeConstantTest, IntegerDivide) {
TF_ASSERT_OK_AND_ASSIGN(auto computed,
ComputeConstantLiteral(client, computation, &b));
- std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR0<int32>(5);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
+ Literal expected_literal = LiteralUtil::CreateR0<int32>(5);
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed));
}
}
@@ -241,12 +240,11 @@ XLA_TEST_F(ComputeConstantTest, Layout) {
ConstantR2<int32>(&b, {{10, 20}, {30, 40}})),
&b, &layout_proto));
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateR2WithLayout<int32>(
- {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout));
+ Literal expected_literal = LiteralUtil::CreateR2WithLayout<int32>(
+ {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout));
ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts(
- expected_literal->shape(), computed->shape()));
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
+ expected_literal.shape(), computed.shape()));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed));
}
}
}
diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc
index be017477d8..9811a015e9 100644
--- a/tensorflow/compiler/xla/tests/concat_test.cc
+++ b/tensorflow/compiler/xla/tests/concat_test.cc
@@ -536,8 +536,8 @@ XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) {
auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
auto x_literal = LiteralUtil::CreateR0<float>(2.f);
auto y_literal = LiteralUtil::CreateR0<float>(3.f);
- auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
- auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
+ auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
+ auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
XlaBuilder builder(TestName());
auto x = Parameter(&builder, 0, f32_scalar, "x");
@@ -559,12 +559,12 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) {
auto x_literal = LiteralUtil::CreateR1<float>({2.0f, 3.0f, 5.0f, 6.0f});
auto y_literal = LiteralUtil::CreateR0<float>(1.5f);
auto z_literal = LiteralUtil::CreateR0<float>(5.5f);
- auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
- auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
- auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie();
+ auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
+ auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
+ auto z_data = client_->TransferToServer(z_literal).ConsumeValueOrDie();
XlaBuilder builder(TestName());
- auto x = Parameter(&builder, 0, x_literal->shape(), "x");
+ auto x = Parameter(&builder, 0, x_literal.shape(), "x");
auto y = Parameter(&builder, 1, f32_scalar, "y");
auto z = Parameter(&builder, 2, f32_scalar, "z");
auto bcast = Broadcast(y, {5});
@@ -587,12 +587,12 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) {
auto x_literal = LiteralUtil::CreateR3FromArray3D<float>(x3d);
auto y_literal = LiteralUtil::CreateR0<float>(1.5f);
auto z_literal = LiteralUtil::CreateR0<float>(5.5f);
- auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
- auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
- auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie();
+ auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
+ auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
+ auto z_data = client_->TransferToServer(z_literal).ConsumeValueOrDie();
XlaBuilder builder(TestName());
- auto x = Parameter(&builder, 0, x_literal->shape(), "x");
+ auto x = Parameter(&builder, 0, x_literal.shape(), "x");
auto y = Parameter(&builder, 1, f32_scalar, "y");
   auto z = Parameter(&builder, 2, f32_scalar, "z");
auto y_bcast = Broadcast(y, {1, 5, 7});
diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc
index 25d10ab00a..32cac499c7 100644
--- a/tensorflow/compiler/xla/tests/conditional_test.cc
+++ b/tensorflow/compiler/xla/tests/conditional_test.cc
@@ -359,8 +359,8 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) {
ComputeAndCompareTuple(
&builder,
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(12.0f).get(),
- LiteralUtil::CreateR0<float>(25.0f).get()}),
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<float>(12.0f),
+ LiteralUtil::CreateR0<float>(25.0f)}),
{pred_arg.get()}, error_spec_);
}
@@ -375,12 +375,11 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) {
Conditional(pred, operands, CreateR1TupleCeilComputation(), operands,
CreateR1TupleFloorComputation());
- ComputeAndCompareTuple(
- &builder,
- *LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR1<float>({13.0f, 16.0f}).get(),
- LiteralUtil::CreateR1<float>({26.0f, 30.0f}).get()}),
- {pred_arg.get()}, error_spec_);
+ ComputeAndCompareTuple(&builder,
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>({13.0f, 16.0f}),
+ LiteralUtil::CreateR1<float>({26.0f, 30.0f})}),
+ {pred_arg.get()}, error_spec_);
}
// Test true and false computations that return a tuple of a predicate, a
@@ -415,13 +414,12 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) {
Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), operands,
false_builder_result.ConsumeValueOrDie());
- ComputeAndCompareTuple(
- &builder,
- *LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<bool>(true).get(),
- LiteralUtil::CreateR0<float>(12.2f).get(),
- LiteralUtil::CreateR1<float>({12.8f, 14.6f}).get()}),
- {pred_arg.get()}, error_spec_);
+ ComputeAndCompareTuple(&builder,
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<bool>(true),
+ LiteralUtil::CreateR0<float>(12.2f),
+ LiteralUtil::CreateR1<float>({12.8f, 14.6f})}),
+ {pred_arg.get()}, error_spec_);
}
// Test true and false computations that return a nested tuple.
@@ -463,15 +461,13 @@ XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) {
ComputeAndCompareTuple(
&builder,
- *LiteralUtil::MakeTuple(
- {LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(46.6f).get(),
- LiteralUtil::CreateR1<float>({54.4f, 58.4f}).get()})
- .get(),
- LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR1<float>({62.1f, 67.4f}).get(),
- LiteralUtil::CreateR0<float>(9.3f).get()})
- .get()}),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(46.6f),
+ LiteralUtil::CreateR1<float>({54.4f, 58.4f})}),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>({62.1f, 67.4f}),
+ LiteralUtil::CreateR0<float>(9.3f)})}),
{pred_arg.get()}, error_spec_);
}
@@ -633,8 +629,8 @@ XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) {
ComputeAndCompareTuple(
&builder,
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(a).get(),
- LiteralUtil::CreateR0<float>(b).get()}),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(a), LiteralUtil::CreateR0<float>(b)}),
{x_arg.get(), y_arg.get()}, error_spec_);
};
@@ -669,10 +665,10 @@ XLA_TEST_F(ConditionalOpTest, DuplicateElementsConditional) {
{
// Pred is true case.
std::vector<Literal> args;
- args.push_back(std::move(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int32>(123).get(),
- LiteralUtil::CreateR0<int32>(-42).get()})));
- args.push_back(std::move(*LiteralUtil::CreateR0<bool>(true)));
+ args.push_back(
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<int32>(123),
+ LiteralUtil::CreateR0<int32>(-42)}));
+ args.push_back(LiteralUtil::CreateR0<bool>(true));
XlaBuilder builder(TestName() + ".main");
auto p = Parameter(&builder, 0, tuple2, "p0");
auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1");
@@ -682,10 +678,10 @@ XLA_TEST_F(ConditionalOpTest, DuplicateElementsConditional) {
{
// Pred is false case.
std::vector<Literal> args;
- args.push_back(std::move(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int32>(123).get(),
- LiteralUtil::CreateR0<int32>(-42).get()})));
- args.push_back(std::move(*LiteralUtil::CreateR0<bool>(false)));
+ args.push_back(
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<int32>(123),
+ LiteralUtil::CreateR0<int32>(-42)}));
+ args.push_back(LiteralUtil::CreateR0<bool>(false));
XlaBuilder builder(TestName() + ".main");
auto p = Parameter(&builder, 0, tuple2, "p0");
auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1");
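The conditional_test hunks carry the companion change for tuple literals: LiteralUtil::MakeTuple wanted raw Literal pointers (hence the .get() calls being deleted), whereas MakeTupleFromSlices accepts the element literals themselves and returns a Literal. A short sketch of the new form with illustrative values; pred_arg and error_spec_ are assumed from the surrounding test:

    Literal expected = LiteralUtil::MakeTupleFromSlices(
        {LiteralUtil::CreateR0<float>(12.0f),
         LiteralUtil::CreateR0<float>(25.0f)});
    // Nested tuples compose the same way, slice by slice.
    Literal nested = LiteralUtil::MakeTupleFromSlices(
        {LiteralUtil::MakeTupleFromSlices(
             {LiteralUtil::CreateR0<float>(46.6f),
              LiteralUtil::CreateR1<float>({54.4f, 58.4f})}),
         LiteralUtil::CreateR0<float>(9.3f)});
    ComputeAndCompareTuple(&builder, expected, {pred_arg.get()}, error_spec_);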
diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc
index 4937574831..72ff1e74a4 100644
--- a/tensorflow/compiler/xla/tests/constants_test.cc
+++ b/tensorflow/compiler/xla/tests/constants_test.cc
@@ -110,7 +110,7 @@ TEST_F(ConstantsTest, Small_2x2) {
TEST_F(ConstantsTest, Empty_3x0x2) {
XlaBuilder builder(TestName());
- ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D<float>(
+ ConstantLiteral(&builder, LiteralUtil::CreateR3FromArray3D<float>(
Array3D<float>(3, 0, 2)));
ComputeAndCompareR3<float>(&builder, Array3D<float>(3, 0, 2), {});
@@ -126,7 +126,7 @@ TEST_F(ConstantsTest, Small_2x2x2) {
{{5.f, 6.f}, // y0
{7.f, 8.f}}, // y1
});
- ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D<float>(array3d));
+ ConstantLiteral(&builder, LiteralUtil::CreateR3FromArray3D<float>(array3d));
ComputeAndCompareR3<float>(&builder, array3d, {});
}
@@ -140,12 +140,11 @@ TEST_F(ConstantsTest, Small_3x2x1x1) {
{5.0f, 4.4f}, // p2
});
input_array.FillWithPZ(pz);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4D(input_array);
+ Literal input_literal = LiteralUtil::CreateR4FromArray4D(input_array);
{
XlaBuilder builder(TestName());
- ConstantLiteral(&builder, *input_literal);
+ ConstantLiteral(&builder, input_literal);
ComputeAndCompareR4<float>(&builder, input_array, {}, error_spec_);
}
@@ -159,23 +158,21 @@ TEST_F(ConstantsTest, Small_3x2x1x1) {
// TODO(b/29263943): Support tuple constants.
TEST_F(ConstantsTest, DISABLED_TupleConstant) {
XlaBuilder builder(TestName());
- ConstantLiteral(&builder,
- *LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}).get(),
- LiteralUtil::CreateR1<float>({2.0, 42}).get()}));
+ ConstantLiteral(&builder, LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}),
+ LiteralUtil::CreateR1<float>({2.0, 42})}));
- std::unique_ptr<Literal> result =
- ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie();
+ Literal result = ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie();
LiteralTestUtil::ExpectR2Near<float>({{1.0}, {2.0}},
- LiteralSlice(*result, {0}), error_spec_);
- LiteralTestUtil::ExpectR1Near<float>({2.0, 42.0}, LiteralSlice(*result, {1}),
+ LiteralSlice(result, {0}), error_spec_);
+ LiteralTestUtil::ExpectR1Near<float>({2.0, 42.0}, LiteralSlice(result, {1}),
error_spec_);
}
TEST_F(ConstantsTest, Token) {
XlaBuilder builder(TestName());
- ConstantLiteral(&builder, *LiteralUtil::CreateToken());
+ ConstantLiteral(&builder, LiteralUtil::CreateToken());
// TODO(b/80000000): tokens cannot be returned from computations.
Tuple(&builder, {});
TF_ASSERT_OK(Execute(&builder, {}).status());
diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc
index 7a203d6873..5f063e6784 100644
--- a/tensorflow/compiler/xla/tests/convert_test.cc
+++ b/tensorflow/compiler/xla/tests/convert_test.cc
@@ -210,10 +210,10 @@ XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) {
static_cast<int64>(0x8000008000000000LL),
static_cast<int64>(0x8000010000000000LL),
};
- std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<int64>({arg});
- auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
+ Literal arg_literal = LiteralUtil::CreateR1<int64>({arg});
+ auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
- client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
+ client_->TransferToServer(arg_literal).ConsumeValueOrDie();
ConvertElementType(arg_param, F32);
@@ -229,10 +229,10 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1F32) {
std::vector<uint32> arg{0, 1, 0x1000, 0x7fffffff,
0x80000000, 0x80000001, 0x80000002, 0x80000003,
0x80000080, 0x80000081, 0x80000082, 0xFFFFFFFF};
- std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<uint32>({arg});
- auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
+ Literal arg_literal = LiteralUtil::CreateR1<uint32>({arg});
+ auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
- client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
+ client_->TransferToServer(arg_literal).ConsumeValueOrDie();
ConvertElementType(arg_param, F32);
@@ -247,10 +247,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) {
XlaBuilder builder(TestName());
std::vector<float> arg{0.0f, 1.0f, 16777216.0f,
16777218.0f, 2147483647.0f, 4294967040.0f};
- std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<float>({arg});
- auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
+ Literal arg_literal = LiteralUtil::CreateR1<float>({arg});
+ auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
- client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
+ client_->TransferToServer(arg_literal).ConsumeValueOrDie();
ConvertElementType(arg_param, U32);
@@ -264,10 +264,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) {
XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) {
XlaBuilder builder(TestName());
std::vector<uint32> arg{0, 1, 0x1000, 0x7fffffff, 0x80000082, 0xFFFFFFFF};
- std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<uint32>({arg});
- auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
+ Literal arg_literal = LiteralUtil::CreateR1<uint32>({arg});
+ auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
- client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
+ client_->TransferToServer(arg_literal).ConsumeValueOrDie();
ConvertElementType(arg_param, S64);
@@ -281,10 +281,10 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) {
XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) {
XlaBuilder builder(TestName());
std::vector<int32> arg{0, 1, 0x1000, -1, -0x1000};
- std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<int32>({arg});
- auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
+ Literal arg_literal = LiteralUtil::CreateR1<int32>({arg});
+ auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
- client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
+ client_->TransferToServer(arg_literal).ConsumeValueOrDie();
ConvertElementType(arg_param, S64);
@@ -318,10 +318,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) {
9223370937343148032.f,
-9223371487098961920.f,
-9223370937343148032.f};
- std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<float>({arg});
- auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
+ Literal arg_literal = LiteralUtil::CreateR1<float>({arg});
+ auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
- client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
+ client_->TransferToServer(arg_literal).ConsumeValueOrDie();
ConvertElementType(arg_param, S64);
@@ -456,7 +456,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> dot_lhs_handle,
- client_->TransferToServer(*LiteralUtil::CreateR1<half>(input)));
+ client_->TransferToServer(LiteralUtil::CreateR1<half>(input)));
XlaBuilder builder(TestName());
ConvertElementType(
@@ -476,7 +476,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> dot_lhs_handle,
- client_->TransferToServer(*LiteralUtil::CreateR1<float>(input)));
+ client_->TransferToServer(LiteralUtil::CreateR1<float>(input)));
XlaBuilder builder(TestName());
ConvertElementType(
diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
index 38b6da4fa9..fd98bf29b8 100644
--- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
@@ -93,8 +93,7 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest,
auto weight_array = absl::make_unique<Array4D<float>>(4, 3, 1, 1);
weight_array->FillWithMultiples(0.2);
auto weight_data =
- client_
- ->TransferToServer(*LiteralUtil::CreateR4FromArray4D(*weight_array))
+ client_->TransferToServer(LiteralUtil::CreateR4FromArray4D(*weight_array))
.ConsumeValueOrDie();
XlaBuilder builder(TestName());
diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc
index d2c6478b02..070b092d18 100644
--- a/tensorflow/compiler/xla/tests/convolution_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_test.cc
@@ -123,8 +123,8 @@ class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest {
}));
ComputeAndCompare(&builder,
- {std::move(*LiteralUtil::CreateFromArray(input_data)),
- std::move(*LiteralUtil::CreateFromArray(filter_data))},
+ {LiteralUtil::CreateFromArray(input_data),
+ LiteralUtil::CreateFromArray(filter_data)},
error_spec_);
}
};
@@ -157,8 +157,8 @@ class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest {
{7.0f, 8.0f},
}));
ComputeAndCompare(&builder,
- {std::move(*LiteralUtil::CreateFromArray(input_data)),
- std::move(*LiteralUtil::CreateFromArray(filter_data))},
+ {LiteralUtil::CreateFromArray(input_data),
+ LiteralUtil::CreateFromArray(filter_data)},
error_spec_);
}
};
@@ -192,8 +192,8 @@ class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest {
}));
ComputeAndCompare(&builder,
- {std::move(*LiteralUtil::CreateFromArray(input_data)),
- std::move(*LiteralUtil::CreateFromArray(filter_data))},
+ {LiteralUtil::CreateFromArray(input_data),
+ LiteralUtil::CreateFromArray(filter_data)},
error_spec_);
}
};
@@ -224,8 +224,8 @@ class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest {
{{5.0f, 6.0f, 7.0f}, {8.0f, 9.0f, 10.0f}, {11.0f, 12.0f, 13.0f}}));
// clang-format on
ComputeAndCompare(&builder,
- {std::move(*LiteralUtil::CreateFromArray(input_data)),
- std::move(*LiteralUtil::CreateFromArray(filter_data))},
+ {LiteralUtil::CreateFromArray(input_data),
+ LiteralUtil::CreateFromArray(filter_data)},
error_spec_);
}
};
@@ -249,10 +249,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) {
Array3D<float> expected({{{510, 610, 710, 810}}});
auto input_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<float>(&builder, expected,
@@ -284,10 +284,10 @@ class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest {
Array3D<T> expected({{{570.0f, 670.0f, 770.0f}}});
auto input_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<T>(&builder, expected,
@@ -319,10 +319,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) {
Array3D<float> expected({{{190, 320, 230, 380, 270, 440, 310, 500}}});
auto input_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<float>(&builder, expected,
@@ -350,10 +350,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) {
Array3D<float> expected({{{510, 0, 610, 0, 710, 0, 810}}});
auto input_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<float>(&builder, expected,
@@ -386,10 +386,10 @@ class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest {
{{{0.0f, 260.0f, 510.0f, 610.0f, 710.0f, 810.0f, 350.0f, 0.0f}}});
auto input_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<T>(&builder, expected,
@@ -435,23 +435,23 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) {
std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape));
iota(input_elems.begin(), input_elems.end(), 1.0f);
auto input_r1 = LiteralUtil::CreateR1<float>(input_elems);
- auto input_r5 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+ auto input_r5 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape));
iota(filter_elems.begin(), filter_elems.end(), 1.0f);
auto filter_r1 = LiteralUtil::CreateR1<float>(filter_elems);
- auto filter_r5 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+ auto filter_r5 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
auto expected_r1 = LiteralUtil::CreateR1<float>(
{19554, 19962, 20370, 22110, 22590, 23070, 34890, 35730, 36570, 37446,
38358, 39270, 50226, 51498, 52770, 52782, 54126, 55470});
- auto expected_r5 = expected_r1->Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie();
+ auto expected_r5 = expected_r1.Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie();
- auto input_literal = client_->TransferToServer(*input_r5).ConsumeValueOrDie();
+ auto input_literal = client_->TransferToServer(input_r5).ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*filter_r5).ConsumeValueOrDie();
+ client_->TransferToServer(filter_r5).ConsumeValueOrDie();
- ComputeAndCompareLiteral(&builder, *expected_r5,
+ ComputeAndCompareLiteral(&builder, expected_r5,
{input_literal.get(), filter_literal.get()},
error_spec_);
}
@@ -498,23 +498,23 @@ class Convolve2D_1x3x3x5_3x3x5x3_Valid : public ConvolutionTest {
std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
iota_int_init_value(input_elems, 1);
auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
- auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+ auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
iota_int_init_value(filter_elems, 1);
auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
- auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+ auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
auto expected_r1 = LiteralUtil::CreateR1<T>(
{static_cast<T>(92115), static_cast<T>(93150), static_cast<T>(94185)});
- auto expected_r4 = expected_r1->Reshape({1, 1, 1, 3}).ConsumeValueOrDie();
+ auto expected_r4 = expected_r1.Reshape({1, 1, 1, 3}).ConsumeValueOrDie();
auto input_literal =
- client_->TransferToServer(*input_r4).ConsumeValueOrDie();
+ client_->TransferToServer(input_r4).ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*filter_r4).ConsumeValueOrDie();
+ client_->TransferToServer(filter_r4).ConsumeValueOrDie();
- ComputeAndCompareLiteral(&builder, *expected_r4,
+ ComputeAndCompareLiteral(&builder, expected_r4,
{input_literal.get(), filter_literal.get()},
error_spec_);
}
@@ -558,12 +558,12 @@ class Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid : public ConvolutionTest {
std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
iota_int_init_value(input_elems, 1);
auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
- auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+ auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
iota_int_init_value(filter_elems, 1);
auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
- auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+ auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
auto expected_r1 = LiteralUtil::CreateR1<T>(
{static_cast<T>(16029), static_cast<T>(16218), static_cast<T>(16407),
@@ -571,14 +571,14 @@ class Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid : public ConvolutionTest {
static_cast<T>(18369), static_cast<T>(18576), static_cast<T>(18783),
static_cast<T>(19620), static_cast<T>(19836), static_cast<T>(20052),
static_cast<T>(20925), static_cast<T>(21150), static_cast<T>(21375)});
- auto expected_r4 = expected_r1->Reshape({1, 1, 1, 15}).ConsumeValueOrDie();
+ auto expected_r4 = expected_r1.Reshape({1, 1, 1, 15}).ConsumeValueOrDie();
auto input_literal =
- client_->TransferToServer(*input_r4).ConsumeValueOrDie();
+ client_->TransferToServer(input_r4).ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*filter_r4).ConsumeValueOrDie();
+ client_->TransferToServer(filter_r4).ConsumeValueOrDie();
- ComputeAndCompareLiteral(&builder, *expected_r4,
+ ComputeAndCompareLiteral(&builder, expected_r4,
{input_literal.get(), filter_literal.get()},
error_spec_);
}
@@ -624,26 +624,26 @@ class Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid : public ConvolutionTest {
std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
iota_int_init_value(input_elems, 1);
auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
- auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+ auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
iota_int_init_value(filter_elems, 1);
auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
- auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+ auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
auto expected_r1 = LiteralUtil::CreateR1<T>(
{static_cast<T>(5076), static_cast<T>(5160), static_cast<T>(5244),
static_cast<T>(5328), static_cast<T>(6164), static_cast<T>(6264),
static_cast<T>(6364), static_cast<T>(6464), static_cast<T>(7380),
static_cast<T>(7496), static_cast<T>(7612), static_cast<T>(7728)});
- auto expected_r4 = expected_r1->Reshape({1, 1, 1, 12}).ConsumeValueOrDie();
+ auto expected_r4 = expected_r1.Reshape({1, 1, 1, 12}).ConsumeValueOrDie();
auto input_literal =
- client_->TransferToServer(*input_r4).ConsumeValueOrDie();
+ client_->TransferToServer(input_r4).ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*filter_r4).ConsumeValueOrDie();
+ client_->TransferToServer(filter_r4).ConsumeValueOrDie();
- ComputeAndCompareLiteral(&builder, *expected_r4,
+ ComputeAndCompareLiteral(&builder, expected_r4,
{input_literal.get(), filter_literal.get()},
error_spec_);
}
@@ -692,8 +692,8 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization,
expected_result.Fill(0);
ComputeAndCompare(&builder,
- {std::move(*LiteralUtil::CreateFromArray(param0)),
- std::move(*LiteralUtil::CreateFromArray(param1))},
+ {LiteralUtil::CreateFromArray(param0),
+ LiteralUtil::CreateFromArray(param1)},
error_spec_);
}
@@ -749,26 +749,25 @@ class Convolve1D1WindowTestBase
std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
static_cast<T>(1.0f));
auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
- auto input_r3 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+ auto input_r3 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
static_cast<T>(1.0f));
auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
- auto filter_r3 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+ auto filter_r3 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
std::vector<T> expect_elems(batch * output_feature * num_windows,
static_cast<T>(window_size * input_feature));
auto expected_r1 = LiteralUtil::CreateR1<T>(expect_elems);
- auto expected_r3 =
- expected_r1->Reshape({batch, num_windows, output_feature})
- .ConsumeValueOrDie();
+ auto expected_r3 = expected_r1.Reshape({batch, num_windows, output_feature})
+ .ConsumeValueOrDie();
auto input_literal =
- client_->TransferToServer(*input_r3).ConsumeValueOrDie();
+ client_->TransferToServer(input_r3).ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*filter_r3).ConsumeValueOrDie();
- ComputeAndCompareLiteral(&builder, *expected_r3,
+ client_->TransferToServer(filter_r3).ConsumeValueOrDie();
+ ComputeAndCompareLiteral(&builder, expected_r3,
{input_literal.get(), filter_literal.get()},
error_spec_);
}
@@ -868,8 +867,8 @@ XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) {
}));
ComputeAndCompare(&builder,
- {std::move(*LiteralUtil::CreateFromArray(input_data)),
- std::move(*LiteralUtil::CreateFromArray(filter_data))},
+ {LiteralUtil::CreateFromArray(input_data),
+ LiteralUtil::CreateFromArray(filter_data)},
error_spec_);
}
@@ -891,9 +890,44 @@ XLA_TEST_F(ConvolutionTest, NoCudnnAlgorithmPicker) {
Array4D<float> filter_data(1, 1, 1, 2);
filter_data.FillIota(10);
- ComputeAndCompare(&builder,
- {std::move(*LiteralUtil::CreateFromArray(input_data)),
- std::move(*LiteralUtil::CreateFromArray(filter_data))});
+ ComputeAndCompare(&builder, {LiteralUtil::CreateFromArray(input_data),
+ LiteralUtil::CreateFromArray(filter_data)});
+}
+
+XLA_TEST_F(ConvolutionTest, ConvolveF32BackwardInputGroupedConvolution) {
+ XlaBuilder builder(TestName());
+ Shape input_shape = ShapeUtil::MakeShape(F32, {1, 64, 100, 100});
+ Array4D<float> input_data(1, 64, 100, 100);
+ input_data.FillRandom(/*value=*/0.023, 0.001, /*seed=*/45321);
+ Shape filter_shape = ShapeUtil::MakeShape(F32, {7, 7, 1, 64});
+ Array4D<float> filter_data(7, 7, 1, 64);
+  filter_data.FillRandom(/*value=*/0.023, 0.001, /*seed=*/45320);
+ auto input = Parameter(&builder, 0, input_shape, "input");
+ auto filter = ConstantR4FromArray4D(&builder, filter_data);
+
+ // Specify bf01_01io->bf01 as dimension numbers.
+ ConvolutionDimensionNumbers dnums;
+ // Input
+ dnums.set_input_feature_dimension(1);
+ dnums.set_input_batch_dimension(0);
+ dnums.add_input_spatial_dimensions(2);
+ dnums.add_input_spatial_dimensions(3);
+ // Kernel
+ dnums.set_kernel_input_feature_dimension(2);
+ dnums.set_kernel_output_feature_dimension(3);
+ dnums.add_kernel_spatial_dimensions(0);
+ dnums.add_kernel_spatial_dimensions(1);
+ // Output
+ dnums.set_output_batch_dimension(0);
+ dnums.set_output_feature_dimension(1);
+ dnums.add_output_spatial_dimensions(2);
+ dnums.add_output_spatial_dimensions(3);
+ ConvGeneral(input, filter, /*window_strides=*/{1, 1},
+ /*padding=*/{{3, 3}, {3, 3}}, /*dimension_numbers=*/dnums,
+ /*feature_group_count=*/64);
+
+ ComputeAndCompare(&builder, {LiteralUtil::CreateFromArray(input_data)},
+ error_spec_);
}
class ConvolutionHloTest : public HloTestBase {};
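The convolution tests touch one more corner of the migration: Literal::Reshape is now called on a value ('.' rather than '->') and produces a Literal through its StatusOr, which can then be uploaded or compared directly. A condensed sketch, assuming a ClientLibraryTestBase-style fixture with client_ in scope and illustrative data:

    Literal input_r1 = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f});
    // Reshape returns StatusOr<Literal>; ConsumeValueOrDie() hands back the value.
    Literal input_r4 = input_r1.Reshape({1, 1, 2, 2}).ConsumeValueOrDie();
    auto input_handle =
        client_->TransferToServer(input_r4).ConsumeValueOrDie();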
diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc
index 6784c16715..ba3e9c436e 100644
--- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc
@@ -1335,23 +1335,23 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) {
auto gradients_flat = LiteralUtil::CreateR1<float>({1});
auto gradients_literal =
- gradients_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();
- auto gradients = ConstantLiteral(&builder, *gradients_literal);
+ gradients_flat.Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();
+ auto gradients = ConstantLiteral(&builder, gradients_literal);
auto weights_flat = LiteralUtil::CreateR1<float>({1, 10, 100});
auto weights_literal =
- weights_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
- auto weights = ConstantLiteral(&builder, *weights_literal);
+ weights_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
+ auto weights = ConstantLiteral(&builder, weights_literal);
auto expected_flat = LiteralUtil::CreateR1<float>({10});
auto expected_literal =
- expected_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();
+ expected_flat.Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();
auto mirrored_weights = Rev(weights, {2, 3, 4});
ConvWithGeneralPadding(gradients, mirrored_weights,
/*window_strides=*/{1, 1, 1},
/*padding=*/{{0, 0}, {0, 0}, {1, 1}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_);
+ ComputeAndCompareLiteral(&builder, expected_literal, {}, error_spec_);
}
XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) {
@@ -1359,17 +1359,17 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) {
auto activations_flat = LiteralUtil::CreateR1<float>({1, 2, 3, 4});
auto activations_literal =
- activations_flat->Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie();
- auto activations = ConstantLiteral(&builder, *activations_literal);
+ activations_flat.Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie();
+ auto activations = ConstantLiteral(&builder, activations_literal);
auto gradients_flat = LiteralUtil::CreateR1<float>({100, 10, 1});
auto gradients_literal =
- gradients_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
- auto gradients = ConstantLiteral(&builder, *gradients_literal);
+ gradients_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
+ auto gradients = ConstantLiteral(&builder, gradients_literal);
auto expected_flat = LiteralUtil::CreateR1<float>({13, 24, 130});
auto expected_literal =
- expected_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
+ expected_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
auto forward_conv =
ConvGeneralDilated(activations, gradients,
@@ -1379,7 +1379,7 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) {
XlaBuilder::CreateDefaultConvDimensionNumbers(
/*num_spatial_dims=*/3));
Transpose(forward_conv, {0, 1, 2, 3, 4});
- ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_);
+ ComputeAndCompareLiteral(&builder, expected_literal, {}, error_spec_);
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc
index 526626c1dd..1407e68d9a 100644
--- a/tensorflow/compiler/xla/tests/copy_test.cc
+++ b/tensorflow/compiler/xla/tests/copy_test.cc
@@ -40,16 +40,16 @@ class CopyOpTest : public HloTestBase {
protected:
void TestCopyOp(const Literal& literal) {
auto builder = HloComputation::Builder(TestName());
- auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(literal.CloneToUnique()));
+ auto constant =
+ builder.AddInstruction(HloInstruction::CreateConstant(literal.Clone()));
builder.AddInstruction(HloInstruction::CreateUnary(
constant->shape(), HloOpcode::kCopy, constant));
auto computation = builder.Build();
auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
- std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
- EXPECT_TRUE(LiteralTestUtil::Equal(literal, *result));
+ Literal result = ExecuteAndTransfer(std::move(module), {});
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));
}
void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3);
@@ -58,31 +58,30 @@ class CopyOpTest : public HloTestBase {
};
XLA_TEST_F(CopyOpTest, CopyR0Bool) {
- TestCopyOp(*LiteralUtil::CreateR0<bool>(true));
+ TestCopyOp(LiteralUtil::CreateR0<bool>(true));
}
XLA_TEST_F(CopyOpTest, CopyR1S0U32) {
- TestCopyOp(*LiteralUtil::CreateR1<uint32>({}));
+ TestCopyOp(LiteralUtil::CreateR1<uint32>({}));
}
XLA_TEST_F(CopyOpTest, CopyR1S3U32) {
- TestCopyOp(*LiteralUtil::CreateR1<uint32>({1, 2, 3}));
+ TestCopyOp(LiteralUtil::CreateR1<uint32>({1, 2, 3}));
}
XLA_TEST_F(CopyOpTest, CopyR3F32_2x2x3) {
- TestCopyOp(
- *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
- {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
+ TestCopyOp(LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+ {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
}
XLA_TEST_F(CopyOpTest, CopyR4S32_2x2x3x2) {
- TestCopyOp(*LiteralUtil::CreateR4(
+ TestCopyOp(LiteralUtil::CreateR4(
{{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
{{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
}
XLA_TEST_F(CopyOpTest, CopyR4S32_0x2x3x2) {
- TestCopyOp(*LiteralUtil::CreateR4FromArray4D(Array4D<int32>(0, 2, 3, 2)));
+ TestCopyOp(LiteralUtil::CreateR4FromArray4D(Array4D<int32>(0, 2, 3, 2)));
}
XLA_TEST_F(CopyOpTest, CopyParameterScalar) {
@@ -90,7 +89,7 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) {
// Copy literal to device to use as parameter.
auto literal = LiteralUtil::CreateR0<float>(42.0);
- Shape shape = literal->shape();
+ Shape shape = literal.shape();
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param0"));
@@ -102,9 +101,8 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) {
auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
- std::unique_ptr<Literal> result =
- ExecuteAndTransfer(std::move(module), {literal.get()});
- LiteralTestUtil::ExpectR0Near<float>(42.0f, *result, error_spec_);
+ Literal result = ExecuteAndTransfer(std::move(module), {&literal});
+ LiteralTestUtil::ExpectR0Near<float>(42.0f, result, error_spec_);
}
XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) {
@@ -123,19 +121,17 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) {
auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
- std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
- LiteralTestUtil::ExpectR2Near<float>({{1.0, 2.0}, {3.0, 4.0}}, *result,
+ Literal result = ExecuteAndTransfer(std::move(module), {});
+ LiteralTestUtil::ExpectR2Near<float>({{1.0, 2.0}, {3.0, 4.0}}, result,
error_spec_);
}
XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) {
HloComputation::Builder builder(TestName());
- std::unique_ptr<Literal> literal =
- LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ Literal literal = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
// Reverse the minor-to-major order of the literal.
- Layout* literal_layout =
- literal->mutable_shape_do_not_use()->mutable_layout();
+ Layout* literal_layout = literal.mutable_shape_do_not_use()->mutable_layout();
ASSERT_EQ(2, literal_layout->minor_to_major_size());
literal_layout->mutable_minor_to_major()->SwapElements(0, 1);
@@ -149,11 +145,11 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) {
auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
- std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
+ Literal result = ExecuteAndTransfer(std::move(module), {});
// The result of the computation has the default layout, which is the inverse
// of the layout of the source literal.
- LiteralTestUtil::ExpectR2Near<float>({{1.0, 3.0}, {2.0, 4.0}}, *result,
+ LiteralTestUtil::ExpectR2Near<float>({{1.0, 3.0}, {2.0, 4.0}}, result,
error_spec_);
}
@@ -169,7 +165,7 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) {
HloComputation::Builder builder(TestName());
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR3FromArray3D(a);
+ Literal literal = LiteralUtil::CreateR3FromArray3D(a);
HloInstruction* constant = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
@@ -182,9 +178,9 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) {
auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
ForceResultLayout(module.get(), LayoutUtil::MakeLayout({1, 2, 0}));
- std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
+ Literal result = ExecuteAndTransfer(std::move(module), {});
- LiteralTestUtil::ExpectR3EqualArray3D(a, *result);
+ LiteralTestUtil::ExpectR3EqualArray3D(a, result);
}
void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3,
@@ -203,7 +199,7 @@ void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3,
HloComputation::Builder builder(TestName());
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR4FromArray4D(a);
+ Literal literal = LiteralUtil::CreateR4FromArray4D(a);
HloInstruction* constant = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
@@ -216,9 +212,9 @@ void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3,
auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
ForceResultLayout(module.get(), LayoutUtil::MakeLayout(permutation));
- std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
+ Literal result = ExecuteAndTransfer(std::move(module), {});
- LiteralTestUtil::ExpectR4EqualArray4D(a, *result);
+ LiteralTestUtil::ExpectR4EqualArray4D(a, result);
}
XLA_TEST_F(CopyOpTest, CopyConstantR3Layout021_SingleIncompleteTilePerLayer) {
@@ -250,11 +246,11 @@ XLA_TEST_F(CopyOpClientTest, Copy0x0) {
XlaBuilder builder(TestName());
Parameter(&builder, 0, in_shape, "input");
- auto input_data = client_->TransferToServer(*empty).ConsumeValueOrDie();
+ auto input_data = client_->TransferToServer(empty).ConsumeValueOrDie();
auto actual = ExecuteAndTransfer(&builder, {input_data.get()}, &out_shape)
.ConsumeValueOrDie();
- EXPECT_TRUE(LiteralTestUtil::Equal(*empty, *actual));
+ EXPECT_TRUE(LiteralTestUtil::Equal(empty, actual));
}
} // namespace
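
The copy_test.cc hunks above all apply one mechanical change: LiteralUtil factories and ExecuteAndTransfer now return a Literal by value instead of a std::unique_ptr<Literal>, so the leading dereferences and .get() calls drop out. A minimal before/after sketch of that pattern, assuming the CopyOpTest fixture, the module built in the test, and its error_spec_ member:

    // Before: the result arrived behind a unique_ptr and had to be dereferenced.
    //   std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
    //   LiteralTestUtil::ExpectR0Near<float>(42.0f, *result, error_spec_);

    // After: Literal is a move-only value type, so no wrapper and no dereference.
    Literal result = ExecuteAndTransfer(std::move(module), {});
    LiteralTestUtil::ExpectR0Near<float>(42.0f, result, error_spec_);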
diff --git a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
index d12a4e7fcd..410732c07b 100644
--- a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
+++ b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
@@ -46,7 +46,7 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) {
auto module =
ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
auto literal = LiteralUtil::CreateR1<float>({1, 2, 3});
- EXPECT_EQ(*literal, *ExecuteAndTransfer(std::move(module), {literal.get()}));
+ EXPECT_EQ(literal, ExecuteAndTransfer(std::move(module), {&literal}));
}
XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) {
@@ -68,9 +68,8 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) {
ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3});
auto literal1 = LiteralUtil::CreateR1<float>({10, 20});
- EXPECT_EQ(
- *LiteralUtil::MakeTuple({literal0.get(), literal1.get()}),
- *ExecuteAndTransfer(std::move(module), {literal0.get(), literal1.get()}));
+ EXPECT_EQ(LiteralUtil::MakeTuple({&literal0, &literal1}),
+ ExecuteAndTransfer(std::move(module), {&literal0, &literal1}));
}
// On the GPU backend, constants get special handling. Someone might pass a
@@ -95,8 +94,8 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) {
ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3});
auto literal1 = LiteralUtil::CreateR1<float>({10, 20});
- EXPECT_EQ(*LiteralUtil::MakeTuple({literal0.get(), literal1.get()}),
- *ExecuteAndTransfer(std::move(module), {literal0.get()}));
+ EXPECT_EQ(LiteralUtil::MakeTuple({&literal0, &literal1}),
+ ExecuteAndTransfer(std::move(module), {&literal0}));
}
} // namespace
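
The cross_replica_sum_test.cc hunks change the argument lists in the same spirit: execution arguments are plain Literal locals whose addresses are passed, and LiteralUtil::MakeTuple now takes raw Literal pointers. A short sketch of the new call shape, assuming a module parsed as in the tests above:

    // Operands are stack-owned Literal values; callers pass &literal instead of
    // literal.get(), and MakeTuple likewise takes Literal*.
    auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3});
    auto literal1 = LiteralUtil::CreateR1<float>({10, 20});
    EXPECT_EQ(LiteralUtil::MakeTuple({&literal0, &literal1}),
              ExecuteAndTransfer(std::move(module), {&literal0, &literal1}));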
diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc
index 6f7fc0e6e5..a693fa3595 100644
--- a/tensorflow/compiler/xla/tests/custom_call_test.cc
+++ b/tensorflow/compiler/xla/tests/custom_call_test.cc
@@ -80,8 +80,8 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) {
module->AddEntryComputation(builder.Build());
- std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
- LiteralTestUtil::ExpectR0Near<float>(44.0f, *result, error_spec_);
+ Literal result = ExecuteAndTransfer(std::move(module), {});
+ LiteralTestUtil::ExpectR0Near<float>(44.0f, result, error_spec_);
}
XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) {
@@ -101,8 +101,8 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) {
module->AddEntryComputation(builder.Build());
- std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
- LiteralTestUtil::ExpectR0Near<float>(10.0f, *result, error_spec_);
+ Literal result = ExecuteAndTransfer(std::move(module), {});
+ LiteralTestUtil::ExpectR0Near<float>(10.0f, result, error_spec_);
}
XLA_TEST_F(CustomCallTest,
@@ -125,9 +125,9 @@ XLA_TEST_F(CustomCallTest,
module->AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
+ Literal result = ExecuteAndTransfer(std::move(module), {});
LiteralTestUtil::ExpectR3EqualArray3D<float>(
- Array3D<float>{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, *result);
+ Array3D<float>{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, result);
}
class CustomCallClientAPITest : public ClientLibraryTestBase {};
diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
index eb15fc0593..e0f23b0fa8 100644
--- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
+++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
@@ -64,11 +64,11 @@ TEST_F(DeconstructTupleTest, DeconstructTuple) {
// Try copying the elements back and comparing it
auto handles = result_status.ConsumeValueOrDie();
- std::unique_ptr<Literal> literal;
+ Literal literal;
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0]));
- LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1]));
- LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
}
TEST_F(DeconstructTupleTest, DeconstructTupleTwice) {
@@ -86,19 +86,19 @@ TEST_F(DeconstructTupleTest, DeconstructTupleTwice) {
auto handles1 = result_status1.ConsumeValueOrDie();
auto handles2 = result_status2.ConsumeValueOrDie();
- std::unique_ptr<Literal> literal;
+ Literal literal;
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles1[0]));
- LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles1[1]));
- LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
handles1[0].reset();
handles1[1].reset();
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles2[0]));
- LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles2[1]));
- LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
}
XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) {
@@ -116,15 +116,15 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) {
// the same as handle[3] and handle[1] should be the same as handle[2].
auto handles = result_status.ConsumeValueOrDie();
- std::unique_ptr<Literal> literal;
+ Literal literal;
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0]));
- LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1]));
- LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2]));
- LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[3]));
- LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
}
TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) {
@@ -142,19 +142,19 @@ TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) {
// should not have been deallocated because of reference counting.
global_data.reset();
- std::unique_ptr<Literal> literal;
+ Literal literal;
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0]));
- LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1]));
- LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2]));
- LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
/// Try deallocating one of the repeated elements, then copy
handles[0].reset();
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2]));
- LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
}
TEST_F(DeconstructTupleTest, DeconstructNonTuple) {
@@ -170,10 +170,9 @@ TEST_F(DeconstructTupleTest, DeconstructNonTuple) {
XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
- LiteralUtil::CreateR1<float>({3.14f, -100.25f});
+ Literal param0_literal = LiteralUtil::CreateR1<float>({3.14f, -100.25f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
auto p = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0");
Tuple(&builder, {p});
auto global_data = ExecuteAndCheckTransfer(&builder, {param0_data.get()});
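
The deconstruct_tuple_test.cc changes show the client-side half of the migration: TransferToServer takes the literal by const reference and Transfer yields a Literal value, so a single Literal local can be reassigned across TF_ASSERT_OK_AND_ASSIGN calls. A sketch under those assumptions, with client_, handles, and the expected {1, 2, 3, 4} element taken from the tests above:

    // Uploading: TransferToServer now accepts a const Literal&.
    Literal param0_literal = LiteralUtil::CreateR1<float>({3.14f, -100.25f});
    std::unique_ptr<GlobalData> param0_data =
        client_->TransferToServer(param0_literal).ConsumeValueOrDie();

    // Downloading: Transfer returns a Literal that assigns into a reusable local.
    Literal literal;
    TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0]));
    LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);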
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index 5873516442..0171f51583 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -68,16 +68,16 @@ XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) {
XlaOp param;
auto param_data = CreateParameterAndTransferLiteral(
0,
- *LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}).get(),
- LiteralUtil::CreateR2<float>({{5, 6}, {7, 8}}).get()}),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}),
+ LiteralUtil::CreateR2<float>({{5, 6}, {7, 8}})}),
"arg0", &builder, &param);
auto lhs = GetTupleElement(param, 0);
auto rhs = GetTupleElement(param, 1);
Dot(lhs, rhs);
ComputeAndCompareLiteral(&builder,
- *LiteralUtil::CreateR2<float>({{19, 22}, {43, 50}}),
+ LiteralUtil::CreateR2<float>({{19, 22}, {43, 50}}),
{param_data.get()});
}
@@ -196,11 +196,11 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, FusedDot) {
auto lhs_handle =
this->client_
- ->TransferToServer(*LiteralUtil::CreateR2FromArray2D<T>(
+ ->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
{{1.0f, 2.0f, 3.0f, 4.0f}, {-1.0f, -2.0f, -3.0f, -4.0f}}))
.ConsumeValueOrDie();
auto rhs_handle = this->client_
- ->TransferToServer(*LiteralUtil::CreateR2FromArray2D<T>(
+ ->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
{{1.0f}, {2.0f}, {3.0f}, {4.0f}}))
.ConsumeValueOrDie();
@@ -219,14 +219,14 @@ class SquareMatrixDot : public DotOperationTest {
void TestImpl(bool lhs_row_major, bool rhs_row_major) {
auto lhs_handle =
client_
- ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
+ ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 2.0f}, {3.0f, -4.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(lhs_row_major))))
.ConsumeValueOrDie();
auto rhs_handle =
client_
- ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
+ ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 6.0f}, {7.0f, -4.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(rhs_row_major))))
@@ -286,24 +286,23 @@ void ParametricDotTest::TestImpl() {
std::unique_ptr<Array2D<NativeT>> dot_lhs_data =
MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.m, param.k);
- std::unique_ptr<Literal> dot_lhs_lit =
- LiteralUtil::CreateR2FromArray2DWithLayout(
- *dot_lhs_data, LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(
- param.dot_lhs_row_major)));
+ Literal dot_lhs_lit = LiteralUtil::CreateR2FromArray2DWithLayout(
+ *dot_lhs_data, LayoutUtil::MakeLayout(
+ MinorToMajorForIsRowMajor(param.dot_lhs_row_major)));
std::unique_ptr<GlobalData> dot_lhs_handle =
- client_->TransferToServer(*dot_lhs_lit).ConsumeValueOrDie();
+ client_->TransferToServer(dot_lhs_lit).ConsumeValueOrDie();
std::unique_ptr<Array2D<NativeT>> dot_rhs_data =
MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.k, param.n);
Layout rhs_layout = LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(param.dot_rhs_row_major));
- std::unique_ptr<Literal> dot_rhs_lit =
+ Literal dot_rhs_lit =
LiteralUtil::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout);
std::unique_ptr<GlobalData> dot_rhs_handle =
- client_->TransferToServer(*dot_rhs_lit).ConsumeValueOrDie();
+ client_->TransferToServer(dot_rhs_lit).ConsumeValueOrDie();
std::unique_ptr<Array2D<NativeT>> addend_data;
- std::unique_ptr<Literal> addend_lit;
+ Literal addend_lit;
std::unique_ptr<GlobalData> addend_handle;
if (param.has_addend) {
@@ -311,7 +310,7 @@ void ParametricDotTest::TestImpl() {
addend_lit = LiteralUtil::CreateR2FromArray2DWithLayout(
*addend_data, LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(param.addend_row_major)));
- addend_handle = client_->TransferToServer(*addend_lit).ConsumeValueOrDie();
+ addend_handle = client_->TransferToServer(addend_lit).ConsumeValueOrDie();
}
XlaBuilder builder(TestName());
@@ -477,14 +476,14 @@ class NonsquareMatrixDot : public DotOperationTest {
void TestImpl(bool lhs_row_major, bool rhs_row_major) {
auto lhs_handle =
client_
- ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
+ ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(lhs_row_major))))
.ConsumeValueOrDie();
auto rhs_handle =
client_
- ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
+ ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(rhs_row_major))))
@@ -511,12 +510,12 @@ XLA_TYPED_TEST(NonsquareMatrixDot, TestTT) { this->TestImpl(true, true); }
XLA_TEST_F(DotOperationTest, MatrixVectorC64) {
auto lhs_handle =
client_
- ->TransferToServer(*LiteralUtil::CreateR2WithLayout<complex64>(
+ ->TransferToServer(LiteralUtil::CreateR2WithLayout<complex64>(
{{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0})))
.ConsumeValueOrDie();
auto rhs_handle =
client_
- ->TransferToServer(*LiteralUtil::CreateR2WithLayout<complex64>(
+ ->TransferToServer(LiteralUtil::CreateR2WithLayout<complex64>(
{{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}},
LayoutUtil::MakeLayout({1, 0})))
.ConsumeValueOrDie();
@@ -584,7 +583,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) {
Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2});
auto x_data = this->client_
- ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
+ ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
{{{{1000.0f, 100.0f}, {10.0f, 1.0f}},
{{2000.0f, 200.0f}, {20.0f, 2.0f}}},
{{{3000.0f, 300.0f}, {30.0f, 3.0f}},
@@ -592,7 +591,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) {
.ConsumeValueOrDie();
auto y_data =
this->client_
- ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
+ ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
{{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
{{{11.0f, 22.0f}, {33.0f, 44.0f}},
{{55.0f, 66.0f}, {77.0f, 88.0f}}}}))
@@ -630,13 +629,13 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMul) {
auto x_data =
this->client_
- ->TransferToServer(*LiteralUtil::CreateR3FromArray3D<T>(
+ ->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}))
.ConsumeValueOrDie();
auto y_data =
this->client_
- ->TransferToServer(*LiteralUtil::CreateR3FromArray3D<T>(
+ ->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}}))
.ConsumeValueOrDie();
@@ -668,7 +667,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) {
auto x_data =
this->client_
- ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
+ ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
{{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
{{{9.0f, 10.0f}, {11.0f, 12.0f}},
{{13.0f, 14.0f}, {15.0f, 16.0f}}}}))
@@ -676,7 +675,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) {
auto y_data =
this->client_
- ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
+ ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
{{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}},
{{{0.0f, 1.0f}, {1.0f, 0.0f}}, {{0.0f, 1.0f}, {1.0f, 0.0f}}}}))
.ConsumeValueOrDie();
@@ -708,14 +707,14 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, TransposeFolding) {
auto lhs_handle =
this->client_
->TransferToServer(
- *LiteralUtil::CreateR2FromArray2DWithLayout<T>(
+ LiteralUtil::CreateR2FromArray2DWithLayout<T>(
*lhs, LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(row_major))))
.ConsumeValueOrDie();
auto rhs_handle =
this->client_
->TransferToServer(
- *LiteralUtil::CreateR2FromArray2DWithLayout<T>(
+ LiteralUtil::CreateR2FromArray2DWithLayout<T>(
*rhs, LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(row_major))))
.ConsumeValueOrDie();
@@ -778,15 +777,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64,
TF_ASSERT_OK_AND_ASSIGN(
auto arg_0_value,
this->client_->TransferToServer(
- *LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
+ LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_1_value,
this->client_->TransferToServer(
- *LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
+ LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_2_value,
this->client_->TransferToServer(
- *LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
+ LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
Array2D<T> expected({{53.0f, 74.0f}, {45.0f, 66.0f}});
this->template ComputeAndCompareR2<T>(
@@ -827,15 +826,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64,
TF_ASSERT_OK_AND_ASSIGN(
auto arg_0_value,
this->client_->TransferToServer(
- *LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
+ LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_1_value,
this->client_->TransferToServer(
- *LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
+ LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_2_value,
this->client_->TransferToServer(
- *LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
+ LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
Array2D<T> expected({{38.0f, 36.0f}, {93.0f, 91.0f}});
this->template ComputeAndCompareR2<T>(
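
The dot_operation_test.cc hunks mostly delete a single leading '*' where a freshly built literal is handed straight to TransferToServer, which now accepts a const Literal& temporary. A sketch of that call shape, assuming the client_ member of the fixture and float in place of the test's template parameter T:

    // A temporary Literal can be passed directly; no dereference is needed.
    auto lhs_handle =
        client_
            ->TransferToServer(LiteralUtil::CreateR2FromArray2D<float>(
                {{1.0f, 2.0f}, {3.0f, -4.0f}}))
            .ConsumeValueOrDie();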
diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
index 9bf3767ca3..7501c6d957 100644
--- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
@@ -124,13 +124,13 @@ class DynamicSliceTest : public ClientLibraryTestBase {
// vector<bool> is special so that it cannot be a Span<bool>, which
// is what the code below wants. So instead we do this.
Literal input_values =
- std::move(*LiteralUtil::CreateR1(input_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ LiteralUtil::CreateR1(input_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie();
Literal expected_values =
- std::move(*LiteralUtil::CreateR1(expected_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR1(expected_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@@ -150,13 +150,13 @@ class DynamicSliceTest : public ClientLibraryTestBase {
const std::vector<int64>& slice_sizes,
const Array2D<int>& expected_values_int) {
Literal input_values =
- std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR2FromArray2D(input_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal expected_values =
- std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR2FromArray2D(expected_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@@ -176,13 +176,13 @@ class DynamicSliceTest : public ClientLibraryTestBase {
const std::vector<int64>& slice_sizes,
const Array3D<int>& expected_values_int) {
Literal input_values =
- std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR3FromArray3D(input_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal expected_values =
- std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR3FromArray3D(expected_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@@ -359,17 +359,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
void RunR0(int input_value_int, int update_value_int,
const std::vector<IndexT> slice_starts, int expected_value_int) {
Literal input_value =
- std::move(*LiteralUtil::CreateR0(input_value_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR0(input_value_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal update_value =
- std::move(*LiteralUtil::CreateR0(update_value_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR0(update_value_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal expected_value =
- std::move(*LiteralUtil::CreateR0(expected_value_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR0(expected_value_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@@ -390,17 +390,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
const std::vector<IndexT> slice_starts,
absl::Span<const int> expected_values_int) {
Literal input_values =
- std::move(*LiteralUtil::CreateR1(input_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR1(input_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal update_values =
- std::move(*LiteralUtil::CreateR1(update_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR1(update_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal expected_values =
- std::move(*LiteralUtil::CreateR1(expected_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR1(expected_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@@ -421,17 +421,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
const std::vector<IndexT> slice_starts,
const Array2D<int>& expected_values_int) {
Literal input_values =
- std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR2FromArray2D(input_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal update_values =
- std::move(*LiteralUtil::CreateR2FromArray2D(update_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR2FromArray2D(update_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal expected_values =
- std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR2FromArray2D(expected_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@@ -452,17 +452,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
const std::vector<IndexT> slice_starts,
const Array3D<int>& expected_values_int) {
Literal input_values =
- std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR3FromArray3D(input_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal update_values =
- std::move(*LiteralUtil::CreateR3FromArray3D(update_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR3FromArray3D(update_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal expected_values =
- std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR3FromArray3D(expected_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@@ -529,9 +529,8 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
template <typename NativeT>
void DumpArray(const string& name, const Array3D<NativeT> values) {
- std::unique_ptr<Literal> literal =
- LiteralUtil::CreateR3FromArray3D<NativeT>(values);
- LOG(INFO) << name << ":" << literal->ToString();
+ Literal literal = LiteralUtil::CreateR3FromArray3D<NativeT>(values);
+ LOG(INFO) << name << ":" << literal.ToString();
}
};
@@ -719,7 +718,7 @@ void BM_DynamicSlice(int num_iters) {
auto input_literal = LiteralUtil::CreateR4(
{{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
{{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}});
- auto input = ConstantLiteral(&builder, *input_literal);
+ auto input = ConstantLiteral(&builder, input_literal);
// Create dynamic slice start indices as a parameter: shape [4]
auto start_indices_shape = ShapeUtil::MakeShape(S32, {4});
@@ -740,7 +739,7 @@ void BM_DynamicSlice(int num_iters) {
auto stream =
client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie();
ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(
- stream.get(), *start_indices_literal, buffer));
+ stream.get(), start_indices_literal, buffer));
std::unique_ptr<LocalExecutable> executable =
client
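
In dynamic_ops_test.cc the RunR0/R1/R2/R3 helpers keep the same Convert-then-ValueOrDie chain, only now invoked on the Literal value itself. A sketch of one such conversion, where input_values_int and DataT are the span parameter and element type of the surrounding helper:

    // Convert() is called on the Literal value; ValueOrDie() yields the
    // converted Literal, which std::move places into the local.
    Literal input_values =
        std::move(LiteralUtil::CreateR1(input_values_int)
                      .Convert(primitive_util::NativeToPrimitiveType<DataT>())
                      .ValueOrDie());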
diff --git a/tensorflow/compiler/xla/tests/execution_profile_test.cc b/tensorflow/compiler/xla/tests/execution_profile_test.cc
index 5116e60ca6..b08ece0e63 100644
--- a/tensorflow/compiler/xla/tests/execution_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/execution_profile_test.cc
@@ -31,7 +31,7 @@ XLA_TEST_F(ExecutionProfileTest, ExecuteWithExecutionProfile) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> input,
client_->TransferToServer(
- *LiteralUtil::CreateR2F32Linspace(1e0, 1e5, 256, 256)));
+ LiteralUtil::CreateR2F32Linspace(1e0, 1e5, 256, 256)));
XlaBuilder b(TestName() + ".add");
Dot(Parameter(&b, 0, shape, "param_0"), Parameter(&b, 1, shape, "param_1"));
diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc
index bf1de02ba9..51b50d456e 100644
--- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc
+++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc
@@ -38,29 +38,29 @@ class ExhaustiveF32ElementwiseOpTest
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> input_literal =
+ Literal input_literal =
LiteralUtil::CreateFromDimensions(F32, {input_size});
for (int64 i = begin; i < end; i++) {
if (i >= known_incorrect_range.first &&
i < known_incorrect_range.second) {
// If the operation is known to be buggy on a specific input clamp that
// input to 0 under the assumption that the op is at least correct on 0.
- input_literal->Set({i - begin}, 0.0f);
+ input_literal.Set({i - begin}, 0.0f);
} else {
- input_literal->Set({i - begin}, tensorflow::bit_cast<float, int>(i));
+ input_literal.Set({i - begin}, tensorflow::bit_cast<float, int>(i));
}
}
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
- client_->TransferToServer(*input_literal));
+ client_->TransferToServer(input_literal));
- auto input = Parameter(&builder, 0, input_literal->shape(), "input");
+ auto input = Parameter(&builder, 0, input_literal.shape(), "input");
enqueue_op(&builder, input);
std::vector<float> expected_result;
expected_result.reserve(input_size);
for (int64 i = 0; i < input_size; i++) {
- expected_result.push_back(evaluate_op(input_literal->Get<float>({i})));
+ expected_result.push_back(evaluate_op(input_literal.Get<float>({i})));
}
ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
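
exhaustive_f32_elementwise_op_test.cc touches the element accessors: Set, Get, and shape() are now called with '.' on the Literal value rather than '->' on a pointer. A condensed sketch, assuming the builder and input_size from the test body above:

    Literal input_literal = LiteralUtil::CreateFromDimensions(F32, {input_size});
    input_literal.Set({0}, 0.0f);                 // write an element in place
    auto input = Parameter(&builder, 0, input_literal.shape(), "input");
    float first = input_literal.Get<float>({0});  // read it back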
diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc
index 7cb2f0cedf..9c94acb437 100644
--- a/tensorflow/compiler/xla/tests/fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/fusion_test.cc
@@ -117,9 +117,9 @@ class FusionTest : public HloTestBase {
auto expected = LiteralUtil::CreateR2FromArray2D(answer_data);
auto actual = ExecuteAndTransfer(std::move(hlo_module), {});
if (primitive_util::IsFloatingPointType(prim_type)) {
- EXPECT_TRUE(LiteralTestUtil::Near(*expected, *actual, ErrorSpec(1e-4)));
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, ErrorSpec(1e-4)));
} else {
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
}
@@ -222,8 +222,8 @@ XLA_TEST_F(FusionTest, Test) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>({{0.5}, {2.72}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
+ LiteralUtil::CreateR2<float>({{0.5}, {2.72}}),
+ ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
}
// Test whether we emit appropriate code for parameters of fusion instructions.
@@ -248,8 +248,8 @@ XLA_TEST_F(FusionTest, Parameter) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>({{-1.0, 0.0, 1.0}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
+ LiteralUtil::CreateR2<float>({{-1.0, 0.0, 1.0}}),
+ ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
}
XLA_TEST_F(FusionTest, RandomizedParallelPartition) {
@@ -283,7 +283,7 @@ XLA_TEST_F(FusionTest, RandomizedParallelPartition) {
// Every element of result should be y = x^2 = 4.0.
for (int i = 0; i < rand_dim0_size; ++i) {
for (int j = 0; j < dim1_size; ++j) {
- EXPECT_EQ(4.0, result->Get<float>({i, j}));
+ EXPECT_EQ(4.0, result.Get<float>({i, j}));
}
}
}
@@ -308,8 +308,8 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
+ LiteralUtil::CreateR2<float>({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}),
+ ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
}
XLA_TEST_F(FusionTest, ReshapeToScalar) {
@@ -323,8 +323,8 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(5),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(5),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) {
@@ -338,8 +338,8 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralUtil::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) {
@@ -353,8 +353,8 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape_1by1by1_) {
@@ -368,8 +368,8 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(7),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(7),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape__1by1by1) {
@@ -383,8 +383,8 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR3<int32>({{{7}}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR3<int32>({{{7}}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape__) {
@@ -398,8 +398,8 @@ XLA_TEST_F(FusionTest, Reshape__) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(7),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(7),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape_3by3_3by3) {
@@ -413,8 +413,8 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Transpose_2by3) {
@@ -428,8 +428,8 @@ XLA_TEST_F(FusionTest, Transpose_2by3) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{1, 4}, {2, 5}, {3, 6}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralUtil::CreateR2<int32>({{1, 4}, {2, 5}, {3, 6}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Transpose_3by3) {
@@ -443,8 +443,8 @@ XLA_TEST_F(FusionTest, Transpose_3by3) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralUtil::CreateR2<int32>({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reverse) {
@@ -459,8 +459,8 @@ XLA_TEST_F(FusionTest, Reverse) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({3, 2, 1}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({3, 2, 1}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, ReverseNegate) {
@@ -477,8 +477,8 @@ XLA_TEST_F(FusionTest, ReverseNegate) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({-3, -2, -1}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-3, -2, -1}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, BroadcastNegate) {
@@ -495,8 +495,8 @@ XLA_TEST_F(FusionTest, BroadcastNegate) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({-1, -1}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-1, -1}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, SliceNegate) {
@@ -513,8 +513,8 @@ XLA_TEST_F(FusionTest, SliceNegate) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({-1, -3}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-1, -3}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, DynamicSliceNegate) {
@@ -535,8 +535,8 @@ XLA_TEST_F(FusionTest, DynamicSliceNegate) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({-2, -3}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-2, -3}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, ReshapeNegate) {
@@ -552,9 +552,9 @@ XLA_TEST_F(FusionTest, ReshapeNegate) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reshape1},
HloInstruction::FusionKind::kLoop);
- EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{-1, -2}, {-3, -4}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{-1, -2}, {-3, -4}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, TransposeNegate) {
@@ -570,9 +570,9 @@ XLA_TEST_F(FusionTest, TransposeNegate) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, transpose1},
HloInstruction::FusionKind::kLoop);
- EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{-1, -3}, {-2, -4}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{-1, -3}, {-2, -4}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
std::unique_ptr<HloComputation> MakeReduceTestComputation() {
@@ -602,8 +602,8 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) {
HloInstruction::FusionKind::kInput);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(15),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(15),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
@@ -624,8 +624,8 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(-15),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(-15),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
@@ -674,8 +674,8 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{462, 2145}, {24871, 62491}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralUtil::CreateR2<int32>({{462, 2145}, {24871, 62491}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
// When a constant (or other op) which has multiple users is imported
@@ -710,8 +710,8 @@ XLA_TEST_F(FusionTest, SharedConstant) {
EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({8}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({8}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Add2D) { TestElementwise2D<float, 2>(HloOpcode::kAdd); }
@@ -782,19 +782,17 @@ ENTRY main {
}
)";
- std::unique_ptr<Literal> operand =
- LiteralUtil::CreateR2<float>({{0., 0.}, {1., 0.}});
+ Literal operand = LiteralUtil::CreateR2<float>({{0., 0.}, {1., 0.}});
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseHloString(hlo_text, config));
- TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
- test_runner_.Execute(std::move(module), {operand.get()},
- /*run_hlo_passes=*/false));
+ TF_ASSERT_OK_AND_ASSIGN(Literal result,
+ test_runner_.Execute(std::move(module), {&operand},
+ /*run_hlo_passes=*/false));
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR3<float>({{{0.}, {0.76159415595}}, {{0.}, {0.}}}),
- *result));
+ LiteralUtil::CreateR3<float>({{{0.}, {0.76159415595}}, {{0.}, {0.}}}),
+ result));
}
class FusionClientLibraryTest : public ClientLibraryTestBase {};
@@ -821,16 +819,16 @@ XLA_TEST_F(FusionClientLibraryTest, ManyLayoutTransformations) {
// where overflow is OK.
Array2D<uint32> arr(32, 32);
arr.FillUnique();
- std::unique_ptr<Literal> l1 = LiteralUtil::CreateR2FromArray2D(arr)->Relayout(
+ Literal l1 = LiteralUtil::CreateR2FromArray2D(arr).Relayout(
LayoutUtil::MakeLayout({0, 1}));
- std::unique_ptr<Literal> l2 = LiteralUtil::CreateR2FromArray2D(arr)->Relayout(
+ Literal l2 = LiteralUtil::CreateR2FromArray2D(arr).Relayout(
LayoutUtil::MakeLayout({1, 0}));
- XlaOp p0 = AddParam(*l1, &b);
+ XlaOp p0 = AddParam(l1, &b);
XlaOp sum = p0;
for (int i = 1; i < kNumParams; ++i) {
- auto pN = AddParam((i % 2 == 0 ? *l1 : *l2), &b);
+ auto pN = AddParam((i % 2 == 0 ? l1 : l2), &b);
sum = sum + p0 * pN * pN;
}
@@ -879,19 +877,19 @@ void BM_ParallelFusion(int num_iters) {
auto param0_literal =
LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1);
ScopedShapedBuffer buffer0 =
- client->LiteralToShapedBuffer(*param0_literal, device_ordinal)
+ client->LiteralToShapedBuffer(param0_literal, device_ordinal)
.ConsumeValueOrDie();
auto param1_literal =
LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1);
ScopedShapedBuffer buffer1 =
- client->LiteralToShapedBuffer(*param1_literal, device_ordinal)
+ client->LiteralToShapedBuffer(param1_literal, device_ordinal)
.ConsumeValueOrDie();
auto param2_literal =
LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1);
ScopedShapedBuffer buffer2 =
- client->LiteralToShapedBuffer(*param2_literal, device_ordinal)
+ client->LiteralToShapedBuffer(param2_literal, device_ordinal)
.ConsumeValueOrDie();
// Build executable.
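
fusion_test.cc adds two more variants of the same change: Relayout() returns a Literal value, and helpers such as AddParam and LiteralToShapedBuffer take the literal by reference. A sketch of the ManyLayoutTransformations setup under those assumptions, where b is the XlaBuilder from the test:

    Array2D<uint32> arr(32, 32);
    arr.FillUnique();
    // Relayout now yields a Literal value rather than a unique_ptr.
    Literal l1 = LiteralUtil::CreateR2FromArray2D(arr).Relayout(
        LayoutUtil::MakeLayout({0, 1}));
    XlaOp p0 = AddParam(l1, &b);  // AddParam takes const Literal&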
diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc
index 6d63498044..daa89398a6 100644
--- a/tensorflow/compiler/xla/tests/gather_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc
@@ -58,10 +58,10 @@ ENTRY main {
slice_sizes={1, 3}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherV2) {
@@ -79,10 +79,10 @@ ENTRY main {
slice_sizes={3, 1}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherMultipleBatchDims) {
@@ -100,11 +100,10 @@ ENTRY main {
slice_sizes={3, 1}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_0) {
@@ -122,11 +121,11 @@ ENTRY main {
slice_sizes={1, 1}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices =
+ Literal start_indices =
LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_1) {
@@ -144,11 +143,11 @@ ENTRY main {
slice_sizes={1, 1}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices =
+ Literal start_indices =
LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherNd) {
@@ -166,13 +165,12 @@ ENTRY main {
slice_sizes={1,1,2}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdNonDefaultIndexVectorDim) {
@@ -190,13 +188,12 @@ ENTRY main {
slice_sizes={1,1,2}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, DynamicSlice) {
@@ -214,10 +211,10 @@ ENTRY main {
slice_sizes={1,1}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({1, 1});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, BatchDynamicSlice) {
@@ -235,11 +232,10 @@ ENTRY main {
slice_sizes={1,1}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, ZeroDimBounds) {
@@ -257,9 +253,9 @@ ENTRY main {
slice_sizes={1, 0}
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
+ Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, OutOfBoundsIndex) {
@@ -281,11 +277,11 @@ ENTRY main {
ROOT result = s32[6]{0} reshape(gather)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<int32>(
+ Literal start_indices = LiteralUtil::CreateR2<int32>(
{{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, OutOfBoundsUnsignedIndex) {
@@ -307,11 +303,11 @@ ENTRY main {
ROOT result = s32[6]{0} reshape(gather)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<uint32>(
+ Literal start_indices = LiteralUtil::CreateR2<uint32>(
{{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, NegativeIndex) {
@@ -333,11 +329,11 @@ ENTRY main {
ROOT result = s32[6]{0} reshape(gather)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<int32>(
+ Literal start_indices = LiteralUtil::CreateR2<int32>(
{{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, NegativeIndexIntoUnsignedOperand) {
@@ -359,11 +355,11 @@ ENTRY main {
ROOT result = u32[6]{0} reshape(gather)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<uint32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<int32>(
+ Literal start_indices = LiteralUtil::CreateR2<int32>(
{{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, OneScalarIndex) {
@@ -381,10 +377,10 @@ ENTRY main {
slice_sizes={1,3,2}
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR3<int32>(
+ Literal operand = LiteralUtil::CreateR3<int32>(
{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR0<int32>(1);
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR0<int32>(1);
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, ScalarResult) {
@@ -402,9 +398,9 @@ ENTRY main {
slice_sizes={1}
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR0<int32>(1);
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
+ Literal start_indices = LiteralUtil::CreateR0<int32>(1);
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, ZeroSizedResult) {
@@ -422,10 +418,10 @@ ENTRY main {
slice_sizes={1, 3}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR1<int32>({});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherV2) {
@@ -446,10 +442,10 @@ ENTRY main {
ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherMultipleBatchDims) {
@@ -470,11 +466,10 @@ ENTRY main {
ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNdMultipleBatchDims) {
@@ -495,11 +490,11 @@ ENTRY main {
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices =
+ Literal start_indices =
LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNd) {
@@ -520,13 +515,12 @@ ENTRY main {
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest,
@@ -548,13 +542,12 @@ ENTRY main {
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, FusedDynamicSlice) {
@@ -575,10 +568,10 @@ ENTRY main {
ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({1, 1});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, FusedBatchDynamicSlice) {
@@ -599,11 +592,10 @@ ENTRY main {
ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
+ RunTest(hlo_text, &operand, &start_indices);
}
class GatherClientLibraryTest : public ClientLibraryTestBase {};
@@ -640,10 +632,10 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> operand_arg,
client_->TransferToServer(
- *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> indices_arg,
- client_->TransferToServer(*LiteralUtil::CreateR1<int32>({0, 2})));
+ client_->TransferToServer(LiteralUtil::CreateR1<int32>({0, 2})));
TF_ASSERT_OK_AND_ASSIGN(std::vector<xla::DeviceHandle> devices,
client_->GetDeviceHandles(1));
xla::ExecutionOptions execution_options = CreateDefaultExecutionOptions();
@@ -657,10 +649,9 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) {
TF_ASSERT_OK_AND_ASSIGN(
std::vector<std::unique_ptr<xla::GlobalData>> result_data,
client_->ExecuteParallel(computation_instances));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
+ TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
client_->Transfer(*(result_data[0])));
- LiteralTestUtil::ExpectR2Equal<int32>({{1, 2, 3}, {7, 8, 9}},
- *result_literal);
+ LiteralTestUtil::ExpectR2Equal<int32>({{1, 2, 3}, {7, 8, 9}}, result_literal);
}
} // namespace
} // namespace xla
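The gather-test hunks above are all instances of one mechanical migration: the LiteralUtil::Create* factories now return a Literal value rather than a std::unique_ptr<Literal>, so call sites drop the .get() dereference and pass the literal's address instead. A minimal sketch of the resulting call shape, using only names that appear in these hunks (RunTest is the fixture helper whose argument types are assumed from the call sites shown here):

  // Literals are now owned by value; the test runner still takes raw
  // Literal* arguments, so addresses replace .get().
  Literal operand =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
  RunTest(hlo_text, &operand, &start_indices);  // was operand.get(), ...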
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc
index 3df99aac7d..bdd4fd7e3d 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc
@@ -136,21 +136,21 @@ DebugOptions HloTestBase::GetDebugOptionsForTest() {
return debug_options;
}
-StatusOr<std::unique_ptr<Literal>> HloTestBase::Execute(
- std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments) {
+StatusOr<Literal> HloTestBase::Execute(std::unique_ptr<HloModule> module,
+ absl::Span<Literal* const> arguments) {
return test_runner_.Execute(std::move(module), arguments);
}
-std::unique_ptr<Literal> HloTestBase::ExecuteNoHloPasses(
- std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments) {
+Literal HloTestBase::ExecuteNoHloPasses(std::unique_ptr<HloModule> module,
+ absl::Span<Literal* const> arguments) {
return test_runner_
.Execute(std::move(module), arguments,
/*run_hlo_passes=*/false)
.ValueOrDie();
}
-std::unique_ptr<Literal> HloTestBase::ExecuteAndTransfer(
- std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments) {
+Literal HloTestBase::ExecuteAndTransfer(std::unique_ptr<HloModule> module,
+ absl::Span<Literal* const> arguments) {
return test_runner_.Execute(std::move(module), arguments).ValueOrDie();
}
@@ -188,7 +188,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
TF_ASSIGN_OR_RETURN(auto reference,
reference_runner_.Execute(std::move(reference_module),
arguments, run_hlo_passes));
- return LiteralTestUtil::NearOrEqual(/*expected=*/*reference, /*actual=*/*test,
+ return LiteralTestUtil::NearOrEqual(/*expected=*/reference, /*actual=*/test,
error);
}
@@ -223,13 +223,12 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
::testing::AssertionResult HloTestBase::RunAndCompare(
std::unique_ptr<HloModule> module, const optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor) {
- const auto& fake_arguments =
- MakeFakeArguments(module.get()).ConsumeValueOrDie();
+ auto fake_arguments = MakeFakeArguments(module.get()).ConsumeValueOrDie();
std::vector<Literal*> fake_argument_ptrs;
absl::c_transform(
fake_arguments, std::back_inserter(fake_argument_ptrs),
- [](const std::unique_ptr<Literal>& literal) { return literal.get(); });
+ [](const Literal& literal) { return const_cast<Literal*>(&literal); });
return RunAndCompare(std::move(module), fake_argument_ptrs, error,
reference_preprocessor);
@@ -243,7 +242,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
std::vector<Literal*> fake_argument_ptrs;
absl::c_transform(
fake_arguments, std::back_inserter(fake_argument_ptrs),
- [](const std::unique_ptr<Literal>& literal) { return literal.get(); });
+ [](const Literal& literal) { return const_cast<Literal*>(&literal); });
return RunAndCompareNoHloPasses(std::move(module), fake_argument_ptrs, error,
reference_preprocessor);
@@ -277,7 +276,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
std::vector<Literal*> fake_argument_ptrs;
absl::c_transform(
fake_arguments, std::back_inserter(fake_argument_ptrs),
- [](const std::unique_ptr<Literal>& literal) { return literal.get(); });
+ [](const Literal& literal) { return const_cast<Literal*>(&literal); });
return test_runner_
.Execute(std::move(module_or_status.ValueOrDie()),
fake_argument_ptrs, /*run_hlo_passes=*/true)
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
index 21d77c0cc4..0ae4bdc104 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -115,16 +115,16 @@ class HloTestBase : public ::testing::Test {
}
// Executes the given module and return the result as a Literal.
- StatusOr<std::unique_ptr<Literal>> Execute(
- std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments);
+ StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
+ absl::Span<Literal* const> arguments);
// Same as above, except the module will be executed without running any HLO
// passes on it.
- std::unique_ptr<Literal> ExecuteNoHloPasses(
- std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments);
+ Literal ExecuteNoHloPasses(std::unique_ptr<HloModule> module,
+ absl::Span<Literal* const> arguments);
- std::unique_ptr<Literal> ExecuteAndTransfer(
- std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments);
+ Literal ExecuteAndTransfer(std::unique_ptr<HloModule> module,
+ absl::Span<Literal* const> arguments);
// Executes the given hlo module on two backends and compares results.
//
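With the header now declaring Execute as StatusOr<Literal> and ExecuteNoHloPasses/ExecuteAndTransfer as plain Literal, callers bind the result by value instead of through a unique_ptr. A short usage sketch, assuming the TF_ASSERT_OK_AND_ASSIGN and LiteralTestUtil::Equal forms already used elsewhere in this patch (module, operand and expected stand in for values a real test would have built):

  // Consuming the value-returning signatures declared above; no dereference
  // of the result is needed any more.
  TF_ASSERT_OK_AND_ASSIGN(Literal result,
                          Execute(std::move(module), {&operand}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));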
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h
index 96f72212f3..43cca91f64 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.h
+++ b/tensorflow/compiler/xla/tests/literal_test_util.h
@@ -155,20 +155,20 @@ class LiteralTestUtil {
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected,
const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*LiteralUtil::CreateR0<NativeT>(expected), actual));
+ EXPECT_TRUE(Equal(LiteralUtil::CreateR0<NativeT>(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR1Equal(
absl::Span<const NativeT> expected, const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*LiteralUtil::CreateR1<NativeT>(expected), actual));
+ EXPECT_TRUE(Equal(LiteralUtil::CreateR1<NativeT>(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2Equal(
std::initializer_list<std::initializer_list<NativeT>> expected,
const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*LiteralUtil::CreateR2<NativeT>(expected), actual));
+ EXPECT_TRUE(Equal(LiteralUtil::CreateR2<NativeT>(expected), actual));
}
template <typename NativeT>
@@ -176,46 +176,46 @@ template <typename NativeT>
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
expected,
const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*LiteralUtil::CreateR3<NativeT>(expected), actual));
+ EXPECT_TRUE(Equal(LiteralUtil::CreateR3<NativeT>(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2EqualArray2D(
const Array2D<NativeT>& expected, const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*LiteralUtil::CreateR2FromArray2D(expected), actual));
+ EXPECT_TRUE(Equal(LiteralUtil::CreateR2FromArray2D(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR3EqualArray3D(
const Array3D<NativeT>& expected, const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*LiteralUtil::CreateR3FromArray3D(expected), actual));
+ EXPECT_TRUE(Equal(LiteralUtil::CreateR3FromArray3D(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR4EqualArray4D(
const Array4D<NativeT>& expected, const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*LiteralUtil::CreateR4FromArray4D(expected), actual));
+ EXPECT_TRUE(Equal(LiteralUtil::CreateR4FromArray4D(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected,
const LiteralSlice& actual,
const ErrorSpec& error) {
- EXPECT_TRUE(Near(*LiteralUtil::CreateR0<NativeT>(expected), actual, error));
+ EXPECT_TRUE(Near(LiteralUtil::CreateR0<NativeT>(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR1Near(
absl::Span<const NativeT> expected, const LiteralSlice& actual,
const ErrorSpec& error) {
- EXPECT_TRUE(Near(*LiteralUtil::CreateR1<NativeT>(expected), actual, error));
+ EXPECT_TRUE(Near(LiteralUtil::CreateR1<NativeT>(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2Near(
std::initializer_list<std::initializer_list<NativeT>> expected,
const LiteralSlice& actual, const ErrorSpec& error) {
- EXPECT_TRUE(Near(*LiteralUtil::CreateR2<NativeT>(expected), actual, error));
+ EXPECT_TRUE(Near(LiteralUtil::CreateR2<NativeT>(expected), actual, error));
}
template <typename NativeT>
@@ -223,7 +223,7 @@ template <typename NativeT>
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
expected,
const LiteralSlice& actual, const ErrorSpec& error) {
- EXPECT_TRUE(Near(*LiteralUtil::CreateR3<NativeT>(expected), actual, error));
+ EXPECT_TRUE(Near(LiteralUtil::CreateR3<NativeT>(expected), actual, error));
}
template <typename NativeT>
@@ -232,28 +232,28 @@ template <typename NativeT>
std::initializer_list<std::initializer_list<NativeT>>>>
expected,
const LiteralSlice& actual, const ErrorSpec& error) {
- EXPECT_TRUE(Near(*LiteralUtil::CreateR4<NativeT>(expected), actual, error));
+ EXPECT_TRUE(Near(LiteralUtil::CreateR4<NativeT>(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2NearArray2D(
const Array2D<NativeT>& expected, const LiteralSlice& actual,
const ErrorSpec& error) {
- EXPECT_TRUE(Near(*LiteralUtil::CreateR2FromArray2D(expected), actual, error));
+ EXPECT_TRUE(Near(LiteralUtil::CreateR2FromArray2D(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR3NearArray3D(
const Array3D<NativeT>& expected, const LiteralSlice& actual,
const ErrorSpec& error) {
- EXPECT_TRUE(Near(*LiteralUtil::CreateR3FromArray3D(expected), actual, error));
+ EXPECT_TRUE(Near(LiteralUtil::CreateR3FromArray3D(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR4NearArray4D(
const Array4D<NativeT>& expected, const LiteralSlice& actual,
const ErrorSpec& error) {
- EXPECT_TRUE(Near(*LiteralUtil::CreateR4FromArray4D(expected), actual, error));
+ EXPECT_TRUE(Near(LiteralUtil::CreateR4FromArray4D(expected), actual, error));
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc
index 4151bfae03..b6f9b8156b 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc
+++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc
@@ -31,11 +31,11 @@ namespace xla {
namespace {
TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) {
- std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple({
- LiteralUtil::CreateR0<int32>(42).get(),
- LiteralUtil::CreateR0<int32>(64).get(),
+ Literal literal = LiteralUtil::MakeTupleFromSlices({
+ LiteralUtil::CreateR0<int32>(42),
+ LiteralUtil::CreateR0<int32>(64),
});
- EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *literal));
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal, literal));
}
TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) {
@@ -43,15 +43,15 @@ TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) {
// un-fail an assertion failure. The CHECK-failure is death, so we can make a
// death assertion.
auto unequal_things_are_equal = [] {
- std::unique_ptr<Literal> lhs = LiteralUtil::MakeTuple({
- LiteralUtil::CreateR0<int32>(42).get(),
- LiteralUtil::CreateR0<int32>(64).get(),
+ Literal lhs = LiteralUtil::MakeTupleFromSlices({
+ LiteralUtil::CreateR0<int32>(42),
+ LiteralUtil::CreateR0<int32>(64),
});
- std::unique_ptr<Literal> rhs = LiteralUtil::MakeTuple({
- LiteralUtil::CreateR0<int32>(64).get(),
- LiteralUtil::CreateR0<int32>(42).get(),
+ Literal rhs = LiteralUtil::MakeTupleFromSlices({
+ LiteralUtil::CreateR0<int32>(64),
+ LiteralUtil::CreateR0<int32>(42),
});
- CHECK(LiteralTestUtil::Equal(*lhs, *rhs)) << "LHS and RHS are unequal";
+ CHECK(LiteralTestUtil::Equal(lhs, rhs)) << "LHS and RHS are unequal";
};
ASSERT_DEATH(unequal_things_are_equal(), "LHS and RHS are unequal");
}
@@ -61,7 +61,7 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
auto two = LiteralUtil::CreateR0<float>(2);
auto four = LiteralUtil::CreateR0<float>(4);
ErrorSpec error(0.001);
- CHECK(LiteralTestUtil::Near(*two, *four, error)) << "two is not near four";
+ CHECK(LiteralTestUtil::Near(two, four, error)) << "two is not near four";
};
tensorflow::Env* env = tensorflow::Env::Default();
@@ -86,14 +86,14 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
LiteralProto literal_proto;
TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), result,
&literal_proto));
- std::unique_ptr<Literal> literal =
+ Literal literal =
Literal::CreateFromProto(literal_proto).ConsumeValueOrDie();
if (result.find("expected") != string::npos) {
- EXPECT_EQ("2", literal->ToString());
+ EXPECT_EQ("2", literal.ToString());
} else if (result.find("actual") != string::npos) {
- EXPECT_EQ("4", literal->ToString());
+ EXPECT_EQ("4", literal.ToString());
} else if (result.find("mismatches") != string::npos) {
- EXPECT_EQ("true", literal->ToString());
+ EXPECT_EQ("true", literal.ToString());
} else {
FAIL() << "unknown file in temporary directory: " << result;
}
@@ -103,8 +103,7 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) {
auto expected = LiteralUtil::CreateR1<int32>({1, 2, 3});
auto actual = LiteralUtil::CreateR1<int32>({4, 5, 6});
- ::testing::AssertionResult result =
- LiteralTestUtil::Equal(*expected, *actual);
+ ::testing::AssertionResult result = LiteralTestUtil::Equal(expected, actual);
EXPECT_THAT(result.message(),
::testing::HasSubstr("Expected literal:\n{1, 2, 3}"));
EXPECT_THAT(result.message(),
@@ -116,7 +115,7 @@ TEST(LiteralTestUtilTest, NearComparatorR1) {
{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
auto b = LiteralUtil::CreateR1<float>(
{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
- EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001}));
+ EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001}));
}
TEST(LiteralTestUtilTest, NearComparatorR1Nan) {
@@ -124,7 +123,7 @@ TEST(LiteralTestUtilTest, NearComparatorR1Nan) {
{0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8});
auto b = LiteralUtil::CreateR1<float>(
{0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8});
- EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001}));
+ EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001}));
}
TEST(LiteralTestUtil, NearComparatorDifferentLengths) {
@@ -132,8 +131,8 @@ TEST(LiteralTestUtil, NearComparatorDifferentLengths) {
{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
auto b =
LiteralUtil::CreateR1<float>({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7});
- EXPECT_FALSE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001}));
- EXPECT_FALSE(LiteralTestUtil::Near(*b, *a, ErrorSpec{0.0001}));
+ EXPECT_FALSE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001}));
+ EXPECT_FALSE(LiteralTestUtil::Near(b, a, ErrorSpec{0.0001}));
}
} // namespace
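The tuple tests above show the second recurring substitution in this patch: LiteralUtil::MakeTuple, which took raw Literal* obtained via .get(), is replaced by LiteralUtil::MakeTupleFromSlices, which takes the element literals directly and returns a Literal value. A sketch of the new form, restricted to calls visible in these hunks:

  // Build a (s32[], s32[]) tuple literal with no unique_ptr plumbing.
  Literal tuple = LiteralUtil::MakeTupleFromSlices({
      LiteralUtil::CreateR0<int32>(42),
      LiteralUtil::CreateR0<int32>(64),
  });
  EXPECT_TRUE(LiteralTestUtil::Equal(tuple, tuple));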
diff --git a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc
index 237a4a361e..dbdd20daf0 100644
--- a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc
+++ b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc
@@ -45,7 +45,7 @@ XLA_TEST_F(LocalClientAllocationTest, AddVectors) {
TestAllocator* allocator = GetOrCreateAllocator(local_client_->platform());
auto x_array =
- LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
+ LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
int64 allocation_count_before = allocator_->allocation_count();
@@ -58,7 +58,7 @@ XLA_TEST_F(LocalClientAllocationTest, AddVectors) {
DefaultExecutableBuildOptions(), options);
LiteralTestUtil::ExpectR1Near<float>(
- {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(*result), error_spec_);
+ {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(*result), error_spec_);
// At least one allocation should have been performed when executing the
// computation.
@@ -92,7 +92,7 @@ XLA_TEST_F(LocalClientAllocationTest, RunOnDevices) {
computation, {}, ExecutableBuildOptions().set_device_ordinal(d),
ExecutableRunOptions().set_device_ordinal(d).set_allocator(allocator));
LiteralTestUtil::ExpectR1Near<float>(
- {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_);
+ {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_);
// At least one allocation should have been performed when executing the
// computation.
diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc
index 1a823cf189..a99b43f469 100644
--- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc
+++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc
@@ -58,7 +58,7 @@ XLA_TEST_F(LocalClientExecuteTest, Constant) {
ScopedShapedBuffer result =
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {});
- LiteralTestUtil::ExpectR0Near<float>(123.f, *ShapedBufferToLiteral(result),
+ LiteralTestUtil::ExpectR0Near<float>(123.f, ShapedBufferToLiteral(result),
error_spec_);
}
@@ -68,10 +68,10 @@ XLA_TEST_F(LocalClientExecuteTest, AddScalars) {
auto y = ConstantR0<float>(&builder, 123.0f);
Add(x, y);
- auto x_value = LiteralToShapedBuffer(*LiteralUtil::CreateR0<float>(42.0f));
+ auto x_value = LiteralToShapedBuffer(LiteralUtil::CreateR0<float>(42.0f));
ScopedShapedBuffer result =
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_value});
- LiteralTestUtil::ExpectR0Near<float>(165.f, *ShapedBufferToLiteral(result),
+ LiteralTestUtil::ExpectR0Near<float>(165.f, ShapedBufferToLiteral(result),
error_spec_);
}
@@ -81,10 +81,10 @@ XLA_TEST_F(LocalClientExecuteTest, AddZeroElementVectors) {
auto y = ConstantR1<float>(&builder, {});
Add(x, y);
- auto x_array = LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({}));
+ auto x_array = LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({}));
ScopedShapedBuffer result =
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array});
- LiteralTestUtil::ExpectR1Near<float>({}, *ShapedBufferToLiteral(result),
+ LiteralTestUtil::ExpectR1Near<float>({}, ShapedBufferToLiteral(result),
error_spec_);
}
@@ -95,11 +95,11 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectors) {
Add(x, y);
auto x_array =
- LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
+ LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
ScopedShapedBuffer result =
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array});
LiteralTestUtil::ExpectR1Near<float>(
- {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_);
+ {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_);
}
XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) {
@@ -109,14 +109,14 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) {
Add(x, y);
auto x_array =
- LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
+ LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
ExecutionProfile profile;
ScopedShapedBuffer result = ExecuteLocallyOrDie(
builder.Build().ValueOrDie(), {&x_array}, DefaultExecutableBuildOptions(),
DefaultExecutableRunOptions().set_execution_profile(&profile));
LiteralTestUtil::ExpectR1Near<float>(
- {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_);
+ {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_);
EXPECT_GT(profile.compute_and_transfer_time_ns(), 0);
}
@@ -128,13 +128,13 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) {
auto computation = builder.Build().ConsumeValueOrDie();
// Create x as a col-major array.
- auto x_array = LiteralToShapedBuffer(*LiteralUtil::CreateR2WithLayout(
+ auto x_array = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout(
{{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1})));
EXPECT_TRUE(LayoutUtil::Equal(x_array.on_device_shape().layout(),
LayoutUtil::MakeLayout({0, 1})));
// Create y as a row-major array.
- auto y_array = LiteralToShapedBuffer(*LiteralUtil::CreateR2WithLayout(
+ auto y_array = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout(
{{10.0f, 20.0f}, {30.0f, 40.0f}}, LayoutUtil::MakeLayout({1, 0})));
EXPECT_TRUE(LayoutUtil::Equal(y_array.on_device_shape().layout(),
LayoutUtil::MakeLayout({1, 0})));
@@ -142,15 +142,15 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) {
ScopedShapedBuffer result_colmaj =
ExecuteLocallyOrDie(computation, {&x_array, &y_array});
LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
- *ShapedBufferToLiteral(result_colmaj),
+ ShapedBufferToLiteral(result_colmaj),
error_spec_);
// Run with the parameter values in a different order.
ScopedShapedBuffer result_param_swap =
ExecuteLocallyOrDie(computation, {&y_array, &x_array});
- LiteralTestUtil::ExpectR2Near<float>(
- {{11.0f, 22.0f}, {33.0f, 44.0f}},
- *ShapedBufferToLiteral(result_param_swap), error_spec_);
+ LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
+ ShapedBufferToLiteral(result_param_swap),
+ error_spec_);
}
XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) {
@@ -161,9 +161,9 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) {
auto computation = builder.Build().ConsumeValueOrDie();
auto x_array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
+ LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
auto y_array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
+ LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
// Run with col-major result layout.
ScopedShapedBuffer result_colmaj = ExecuteLocallyOrDie(
@@ -174,7 +174,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) {
EXPECT_TRUE(LayoutUtil::Equal(result_colmaj.on_device_shape().layout(),
LayoutUtil::MakeLayout({0, 1})));
LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
- *ShapedBufferToLiteral(result_colmaj),
+ ShapedBufferToLiteral(result_colmaj),
error_spec_);
// Run with row-major result layout.
@@ -186,7 +186,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) {
EXPECT_TRUE(LayoutUtil::Equal(result_rowmaj.on_device_shape().layout(),
LayoutUtil::MakeLayout({1, 0})));
LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
- *ShapedBufferToLiteral(result_rowmaj),
+ ShapedBufferToLiteral(result_rowmaj),
error_spec_);
}
@@ -198,9 +198,9 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) {
auto computation = builder.Build().ConsumeValueOrDie();
auto x_array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
+ LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
auto y_array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
+ LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
ScopedShapedBuffer result =
ExecuteLocallyOrDie(computation, {&x_array, &y_array});
@@ -208,13 +208,13 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) {
EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape()));
EXPECT_EQ(3, ShapeUtil::TupleElementCount(result.on_host_shape()));
- std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+ Literal result_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralSlice(*result_literal, {0}));
+ LiteralSlice(result_literal, {0}));
LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}},
- LiteralSlice(*result_literal, {1}));
+ LiteralSlice(result_literal, {1}));
LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralSlice(*result_literal, {2}));
+ LiteralSlice(result_literal, {2}));
}
XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) {
@@ -226,9 +226,9 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) {
auto computation = builder.Build().ConsumeValueOrDie();
auto x_array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
+ LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
auto y_array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
+ LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
ScopedShapedBuffer result =
ExecuteLocallyOrDie(computation, {&x_array, &y_array});
@@ -236,15 +236,15 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) {
EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape()));
EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape()));
- std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+ Literal result_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralSlice(*result_literal, {1}));
+ LiteralSlice(result_literal, {1}));
LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralSlice(*result_literal, {0, 0}));
+ LiteralSlice(result_literal, {0, 0}));
LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}},
- LiteralSlice(*result_literal, {0, 1}));
+ LiteralSlice(result_literal, {0, 1}));
LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralSlice(*result_literal, {0, 2}));
+ LiteralSlice(result_literal, {0, 2}));
}
XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) {
@@ -255,7 +255,7 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) {
Tuple(&builder, {x, y});
auto array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
+ LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
ExecutableBuildOptions options = DefaultExecutableBuildOptions();
Shape shape_with_layout = ShapeUtil::MakeTupleShape(
@@ -268,11 +268,11 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) {
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&array, &array},
options, DefaultExecutableRunOptions());
- std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+ Literal result_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralSlice(*result_literal, {0}));
+ LiteralSlice(result_literal, {0}));
LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralSlice(*result_literal, {1}));
+ LiteralSlice(result_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, TupleArguments) {
@@ -298,15 +298,15 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) {
Tuple(&builder, {array_sum, vector_diff});
auto computation = builder.Build().ConsumeValueOrDie();
- auto x_literal = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get(),
- LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0}).get()});
- auto y_literal = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR1<float>({2.0, 4.0, 6.0}).get(),
- LiteralUtil::CreateR2<float>({{55.0, 44.0}, {33.0, 22.0}}).get()});
+ auto x_literal = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
+ LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0})});
+ auto y_literal = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>({2.0, 4.0, 6.0}),
+ LiteralUtil::CreateR2<float>({{55.0, 44.0}, {33.0, 22.0}})});
- auto x_buffer = LiteralToShapedBuffer(*x_literal);
- auto y_buffer = LiteralToShapedBuffer(*y_literal);
+ auto x_buffer = LiteralToShapedBuffer(x_literal);
+ auto y_buffer = LiteralToShapedBuffer(y_literal);
ScopedShapedBuffer result =
ExecuteLocallyOrDie(computation, {&x_buffer, &y_buffer});
@@ -314,11 +314,11 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) {
EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape()));
EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape()));
- std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+ Literal result_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR2Equal<float>({{56.0f, 46.0f}, {36.0f, 26.0f}},
- LiteralSlice(*result_literal, {0}));
+ LiteralSlice(result_literal, {0}));
LiteralTestUtil::ExpectR1Equal<float>({40.0f, 71.0f, 117.0f},
- LiteralSlice(*result_literal, {1}));
+ LiteralSlice(result_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) {
@@ -344,21 +344,20 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) {
Tuple(&builder, {negate_array, vector_sum});
auto computation = builder.Build().ConsumeValueOrDie();
- auto arg_literal = LiteralUtil::MakeTuple(
- {LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get(),
- LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0}).get()})
- .get(),
- LiteralUtil::CreateR1<float>({222.0, -2.0, 10.0}).get()});
- auto arg_buffer = LiteralToShapedBuffer(*arg_literal);
+ auto arg_literal = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
+ LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0})}),
+ LiteralUtil::CreateR1<float>({222.0, -2.0, 10.0})});
+ auto arg_buffer = LiteralToShapedBuffer(arg_literal);
ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
- std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+ Literal result_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR2Equal<float>({{-1.0, -2.0}, {-3.0, -4}},
- LiteralSlice(*result_literal, {0}));
+ LiteralSlice(result_literal, {0}));
LiteralTestUtil::ExpectR1Equal<float>({264.0, 73.0, 133.0},
- LiteralSlice(*result_literal, {1}));
+ LiteralSlice(result_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) {
@@ -377,24 +376,24 @@ XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) {
Tuple(&builder, {Neg(element_0), Add(element_1, element_1)});
auto computation = builder.Build().ConsumeValueOrDie();
- auto arg_literal = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get(),
- LiteralUtil::CreateR2<float>({{11.0, 3.0}, {4.0, 5.0}}).get()});
- auto arg_buffer = LiteralToShapedBuffer(*arg_literal);
+ auto arg_literal = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
+ LiteralUtil::CreateR2<float>({{11.0, 3.0}, {4.0, 5.0}})});
+ auto arg_buffer = LiteralToShapedBuffer(arg_literal);
ScopedShapedBuffer result_0 = ExecuteLocallyOrDie(computation, {&arg_buffer});
- std::unique_ptr<Literal> result_0_literal = ShapedBufferToLiteral(result_0);
+ Literal result_0_literal = ShapedBufferToLiteral(result_0);
LiteralTestUtil::ExpectR2Equal<float>({{-1.0, -2.0}, {-3.0, -4.0}},
- LiteralSlice(*result_0_literal, {0}));
+ LiteralSlice(result_0_literal, {0}));
LiteralTestUtil::ExpectR2Equal<float>({{22.0, 6.0}, {8.0, 10}},
- LiteralSlice(*result_0_literal, {1}));
+ LiteralSlice(result_0_literal, {1}));
ScopedShapedBuffer result_1 = ExecuteLocallyOrDie(computation, {&result_0});
- std::unique_ptr<Literal> result_1_literal = ShapedBufferToLiteral(result_1);
+ Literal result_1_literal = ShapedBufferToLiteral(result_1);
LiteralTestUtil::ExpectR2Equal<float>({{1.0, 2.0}, {3.0, 4.0}},
- LiteralSlice(*result_1_literal, {0}));
+ LiteralSlice(result_1_literal, {0}));
LiteralTestUtil::ExpectR2Equal<float>({{44.0, 12.0}, {16.0, 20}},
- LiteralSlice(*result_1_literal, {1}));
+ LiteralSlice(result_1_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, LargeTuple) {
@@ -427,20 +426,19 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) {
// Feed in a tuple where each two-element vector element is {tuple_index,
// -tuple_index}.
- std::vector<std::unique_ptr<Literal>> arg_elements;
+ std::vector<Literal> arg_elements;
for (int i = 0; i < kElementCount; ++i) {
arg_elements.push_back(LiteralUtil::CreateR1<float>({1.0f * i, -1.0f * i}));
}
- std::unique_ptr<Literal> arg_literal =
- LiteralUtil::MakeTupleOwned(std::move(arg_elements));
- auto arg_buffer = LiteralToShapedBuffer(*arg_literal);
+ Literal arg_literal = LiteralUtil::MakeTupleOwned(std::move(arg_elements));
+ auto arg_buffer = LiteralToShapedBuffer(arg_literal);
ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
- std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+ Literal result_literal = ShapedBufferToLiteral(result);
for (int i = 0; i < kElementCount; ++i) {
LiteralTestUtil::ExpectR1Near<float>(
- {2.0f * i, 0.0f}, LiteralSlice(*result_literal, {i}), error_spec_);
+ {2.0f * i, 0.0f}, LiteralSlice(result_literal, {i}), error_spec_);
}
}
@@ -476,9 +474,9 @@ XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) {
auto computation = builder.Build().ConsumeValueOrDie();
// Construct the argument to pass to the computation.
- std::vector<std::unique_ptr<Literal>> outer_tuple_elements;
+ std::vector<Literal> outer_tuple_elements;
for (int i = 0; i < kFanout; ++i) {
- std::vector<std::unique_ptr<Literal>> inner_tuple_elements;
+ std::vector<Literal> inner_tuple_elements;
for (int j = 0; j < kFanout; ++j) {
inner_tuple_elements.push_back(LiteralUtil::CreateR0<float>(i + j));
}
@@ -487,16 +485,16 @@ XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) {
}
auto arg_literal =
LiteralUtil::MakeTupleOwned(std::move(outer_tuple_elements));
- auto arg_buffer = LiteralToShapedBuffer(*arg_literal);
+ auto arg_buffer = LiteralToShapedBuffer(arg_literal);
ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
- std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+ Literal result_literal = ShapedBufferToLiteral(result);
for (int i = 0; i < kFanout; ++i) {
for (int j = 0; j < kFanout; ++j) {
- LiteralTestUtil::ExpectR0Near<float>(
- i + j + i * kFanout + j, LiteralSlice(*result_literal, {i, j}),
- error_spec_);
+ LiteralTestUtil::ExpectR0Near<float>(i + j + i * kFanout + j,
+ LiteralSlice(result_literal, {i, j}),
+ error_spec_);
}
}
}
@@ -525,23 +523,23 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) {
auto computation = builder.Build().ConsumeValueOrDie();
// Construct the argument to pass to the computation.
- std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR0<float>(123.0);
+ Literal arg_literal = LiteralUtil::CreateR0<float>(123.0);
for (int i = 0; i < kTupleDepth; ++i) {
- std::vector<std::unique_ptr<Literal>> arg_vector;
+ std::vector<Literal> arg_vector;
arg_vector.push_back(std::move(arg_literal));
arg_literal = LiteralUtil::MakeTupleOwned(std::move(arg_vector));
}
- auto arg_buffer = LiteralToShapedBuffer(*arg_literal);
+ auto arg_buffer = LiteralToShapedBuffer(arg_literal);
ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
- std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+ Literal result_literal = ShapedBufferToLiteral(result);
ShapeIndex index;
for (int i = 0; i < kTupleDepth; ++i) {
index.push_back(0);
}
LiteralTestUtil::ExpectR0Equal<float>(165.0,
- LiteralSlice(*result_literal, index));
+ LiteralSlice(result_literal, index));
}
XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) {
@@ -552,7 +550,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) {
Add(x, y);
auto x_array =
- LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f}));
+ LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f}));
auto execute_status =
ExecuteLocally(builder.Build().ValueOrDie(), {&x_array});
@@ -568,7 +566,7 @@ XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) {
Neg(x);
auto x_array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
+ LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
auto execute_status =
ExecuteLocally(builder.Build().ValueOrDie(), {&x_array});
@@ -585,7 +583,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidResultLayout) {
Neg(x);
auto x_array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
+ LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
auto execute_status = ExecuteLocally(
builder.Build().ValueOrDie(), {&x_array},
DefaultExecutableBuildOptions().set_result_layout(
@@ -622,7 +620,7 @@ XLA_TEST_F(LocalClientExecuteTest, RunOnAllDeviceOrdinals) {
DefaultExecutableRunOptions().set_device_ordinal(d));
EXPECT_EQ(d, result.device_ordinal());
LiteralTestUtil::ExpectR0Equal<float>(42.0f,
- *ShapedBufferToLiteral(result));
+ ShapedBufferToLiteral(result));
}
}
}
@@ -666,8 +664,7 @@ XLA_TEST_F(LocalClientExecuteTest, RunOnStream) {
// As a check to verify that the computation ran of the device associated
// with the stream. This is a weak check, but stronger verification is hard.
EXPECT_EQ(d, result.device_ordinal());
- LiteralTestUtil::ExpectR0Equal<float>(42.0f,
- *ShapedBufferToLiteral(result));
+ LiteralTestUtil::ExpectR0Equal<float>(42.0f, ShapedBufferToLiteral(result));
}
}
@@ -745,11 +742,11 @@ XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) {
ScopedShapedBuffer result =
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {});
- std::unique_ptr<Literal> tuple_literal = ShapedBufferToLiteral(result);
+ Literal tuple_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR1Equal<float>({2.0f, 4.0f, 6.0f},
- LiteralSlice(*tuple_literal, {0}));
+ LiteralSlice(tuple_literal, {0}));
LiteralTestUtil::ExpectR1Equal<float>({1.0f, 2.0f, 3.0f},
- LiteralSlice(*tuple_literal, {1}));
+ LiteralSlice(tuple_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
@@ -768,7 +765,7 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
executable_status.ConsumeValueOrDie();
auto x_array =
- LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
+ LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
ScopedShapedBuffer result =
executable->Run({&x_array}, DefaultExecutableRunOptions())
.ConsumeValueOrDie();
@@ -778,7 +775,7 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
->BlockHostUntilDone());
LiteralTestUtil::ExpectR1Near<float>(
- {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_);
+ {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_);
}
XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) {
@@ -792,33 +789,33 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) {
TF_ASSERT_OK_AND_ASSIGN(
auto transferred_literal,
local_client_->ShapedBufferToLiteral(shaped_buffer));
- EXPECT_EQ(literal, *transferred_literal);
+ EXPECT_EQ(literal, transferred_literal);
};
// Array shapes.
- test_to_device_and_back(*LiteralUtil::CreateR0<float>(42.0));
- test_to_device_and_back(*LiteralUtil::CreateR0<bool>(true));
- test_to_device_and_back(*LiteralUtil::CreateR1<float>({1.0, 42.0, 744.4}));
+ test_to_device_and_back(LiteralUtil::CreateR0<float>(42.0));
+ test_to_device_and_back(LiteralUtil::CreateR0<bool>(true));
+ test_to_device_and_back(LiteralUtil::CreateR1<float>({1.0, 42.0, 744.4}));
test_to_device_and_back(
- *LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
- test_to_device_and_back(*LiteralUtil::CreateR2<int32>({{2, 1}, {4444, 56}}));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
+ test_to_device_and_back(LiteralUtil::CreateR2<int32>({{2, 1}, {4444, 56}}));
// Null shape (empty tuple).
- test_to_device_and_back(*LiteralUtil::MakeTuple({}));
+ test_to_device_and_back(LiteralUtil::MakeTuple({}));
// Non-nested tuples.
- test_to_device_and_back(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(12223.0).get()}));
- test_to_device_and_back(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1.0, -42.0}).get(),
- LiteralUtil::CreateR0<float>(123456.0).get()}));
+ test_to_device_and_back(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(12223.0)}));
+ test_to_device_and_back(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>({1.0, -42.0}),
+ LiteralUtil::CreateR0<float>(123456.0)}));
// Nested tuple.
- test_to_device_and_back(*LiteralUtil::MakeTuple(
- {LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1.0, -42.0}).get(),
- LiteralUtil::CreateR0<float>(123456.0).get()})
- .get(),
- LiteralUtil::CreateR0<bool>(false).get()}));
+ test_to_device_and_back(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>({1.0, -42.0}),
+ LiteralUtil::CreateR0<float>(123456.0)}),
+ LiteralUtil::CreateR0<bool>(false)}));
}
XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) {
@@ -832,17 +829,17 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) {
TF_ASSERT_OK_AND_ASSIGN(
auto transferred_literal,
local_client_->ShapedBufferToLiteral(shaped_buffer));
- EXPECT_EQ(literal, *transferred_literal);
+ EXPECT_EQ(literal, transferred_literal);
};
test_to_device_and_back(
- *LiteralUtil::CreateR2<double>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
- test_to_device_and_back(*LiteralUtil::CreateR2<int64>({{2, 1}, {4444, 56}}));
+ LiteralUtil::CreateR2<double>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
+ test_to_device_and_back(LiteralUtil::CreateR2<int64>({{2, 1}, {4444, 56}}));
test_to_device_and_back(
- *LiteralUtil::CreateR2<uint64>({{20000000000ULL, 1}, {4444, 56}}));
- test_to_device_and_back(*LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR1<double>({1.0, -42.0}).get(),
- LiteralUtil::CreateR0<int64>(123456789000LL).get()}));
+ LiteralUtil::CreateR2<uint64>({{20000000000ULL, 1}, {4444, 56}}));
+ test_to_device_and_back(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<double>({1.0, -42.0}),
+ LiteralUtil::CreateR0<int64>(123456789000LL)}));
}
XLA_TEST_F(LocalClientExecuteTest, InfeedTest) {
@@ -852,7 +849,7 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedTest) {
auto constant = ConstantR1<float>(&builder, {1.0f, 2.0f, 3.0f});
Add(in, constant);
- std::unique_ptr<Literal> result;
+ Literal result;
std::unique_ptr<tensorflow::Thread> thread(
tensorflow::Env::Default()->StartThread(
tensorflow::ThreadOptions(), "execute_thread", [&] {
@@ -861,13 +858,13 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedTest) {
}));
ASSERT_IS_OK(local_client_->TransferToInfeedLocal(
- *LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
+ LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
local_client_->default_device_ordinal()));
// Join the thread.
thread.reset();
- LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, *result);
+ LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, result);
}
XLA_TEST_F(LocalClientExecuteTest, InfeedOutfeedTest) {
@@ -884,14 +881,14 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedOutfeedTest) {
[&] { ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); }));
ASSERT_IS_OK(local_client_->TransferToInfeedLocal(
- *LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
+ LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
local_client_->default_device_ordinal()));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
+ TF_ASSERT_OK_AND_ASSIGN(Literal result,
local_client_->TransferFromOutfeedLocal(
shape, local_client_->default_device_ordinal()));
- LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, *result);
+ LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, result);
}
// Benchmark that measures the overhead of the LocalClient API when running a
@@ -922,8 +919,8 @@ void BM_LocalClientOverhead(int num_iters) {
auto literal = LiteralUtil::CreateR2<float>({{0, 0, 0}, {0, 0, 0}});
auto stream =
client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie();
- ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(stream.get(), *literal,
- buffer));
+ ASSERT_IS_OK(
+ transfer_manager->TransferLiteralToDevice(stream.get(), literal, buffer));
const int kWarmups = 2;
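In the local-client tests the same change surfaces on the transfer path: LiteralToShapedBuffer, TransferToInfeedLocal and TransferLiteralToDevice now take the literal by const reference, and ShapedBufferToLiteral hands back a Literal value that is read through LiteralSlice views. A condensed sketch of that round trip, using only calls present in the hunks above (computation is assumed to have been built earlier, and error handling is elided as in the surrounding tests):

  // Host literal -> device buffer -> execute -> host literal, with value
  // semantics end to end.
  auto x_array =
      LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
  ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&x_array});
  Literal result_literal = ShapedBufferToLiteral(result);
  LiteralTestUtil::ExpectR1Near<float>({2.0f, 4.0f, 6.0f}, result_literal,
                                       error_spec_);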
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc
index a8c68fc7fd..f90ef22d2d 100644
--- a/tensorflow/compiler/xla/tests/local_client_test_base.cc
+++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc
@@ -136,7 +136,7 @@ ScopedShapedBuffer LocalClientTestBase::LiteralToShapedBuffer(
.ConsumeValueOrDie();
}
-std::unique_ptr<Literal> LocalClientTestBase::ShapedBufferToLiteral(
+Literal LocalClientTestBase::ShapedBufferToLiteral(
const ShapedBuffer& shaped_buffer) {
return local_client_->ShapedBufferToLiteral(shaped_buffer)
.ConsumeValueOrDie();
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h
index 90095c5d41..4027c7b124 100644
--- a/tensorflow/compiler/xla/tests/local_client_test_base.h
+++ b/tensorflow/compiler/xla/tests/local_client_test_base.h
@@ -86,8 +86,7 @@ class LocalClientTestBase : public ::testing::Test {
// Construct and return a literal containing the array represented by
// shaped_buffer.
- std::unique_ptr<Literal> ShapedBufferToLiteral(
- const ShapedBuffer& shaped_buffer);
+ Literal ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer);
// Execute the given computation on the local client. With and without
// options.
diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc
index 0732e195d4..4d327a6fe9 100644
--- a/tensorflow/compiler/xla/tests/map_test.cc
+++ b/tensorflow/compiler/xla/tests/map_test.cc
@@ -169,11 +169,11 @@ class MapTest : public ClientLibraryTestBase {
TEST_F(MapTest, MapEachElemPlusOneR0) {
// Applies lambda (x) (+ x 1)) to an input scalar.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(42.0);
+ Literal param0_literal = LiteralUtil::CreateR0<float>(42.0);
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param}, CreateAdderToOne(), {});
ComputeAndCompareR0<float>(&builder, 43.0, {param0_data.get()},
@@ -183,11 +183,11 @@ TEST_F(MapTest, MapEachElemPlusOneR0) {
XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) {
// Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1<float>({});
+ Literal param0_literal = LiteralUtil::CreateR1<float>({});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param}, CreateAdderToOne(), {0});
ComputeAndCompareR1<float>(&builder, {}, {param0_data.get()},
@@ -197,12 +197,12 @@ XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) {
TEST_F(MapTest, MapEachElemPlusOneR1S4) {
// Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param}, CreateAdderToOne(), {0});
ComputeAndCompareR1<float>(&builder, {3.2f, 4.3f, 5.4f, 6.5f},
@@ -211,12 +211,12 @@ TEST_F(MapTest, MapEachElemPlusOneR1S4) {
TEST_F(MapTest, MapEachF32ElementToS32Constant) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param}, CreateScalarOne<int32>(), {0});
ComputeAndCompareR1<int32>(&builder, {1, 1, 1, 1}, {param0_data.get()});
@@ -224,12 +224,12 @@ TEST_F(MapTest, MapEachF32ElementToS32Constant) {
TEST_F(MapTest, MapEachF32ElementToU32Constant) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param}, CreateScalarOne<uint32>(), {0});
ComputeAndCompareR1<uint32>(&builder, {1, 1, 1, 1}, {param0_data.get()});
@@ -238,12 +238,12 @@ TEST_F(MapTest, MapEachF32ElementToU32Constant) {
TEST_F(MapTest, MapEachElemLongerChainR1) {
// Maps (lambda (x) (* (+ x 1) x)) onto an input R1F32 vector.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param}, CreateAdderToOneTimesItself(), {0});
ComputeAndCompareR1<float>(
@@ -255,11 +255,11 @@ XLA_TEST_F(MapTest, MapMultipleMapsR1S0) {
// Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0, and then
// maps (lambda (x) (* x 2)) on the result.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1<float>({});
+ Literal param0_literal = LiteralUtil::CreateR1<float>({});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
auto map1 = Map(&builder, {param}, CreateAdderToOne(), {0});
Map(&builder, {map1}, CreateMulByTwo(), {0});
@@ -271,12 +271,12 @@ TEST_F(MapTest, MapMultipleMapsR1S4) {
// Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4, and then
// maps (lambda (x) (* x 2)) on the result.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
auto map1 = Map(&builder, {param}, CreateAdderToOne(), {0});
Map(&builder, {map1}, CreateMulByTwo(), {0});
@@ -287,12 +287,12 @@ TEST_F(MapTest, MapMultipleMapsR1S4) {
TEST_F(MapTest, MapEachElemPlusOneR2) {
// Maps (lambda (x) (+ x 1)) onto an input R2F32 vector.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR2<float>(
+ Literal param0_literal = LiteralUtil::CreateR2<float>(
{{13.25f, 14.0f}, {-7.1f, -7.2f}, {-8.8f, 8.8f}});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param}, CreateAdderToOne(), {0, 1});
Array2D<float> expected_array(
@@ -342,17 +342,17 @@ XLA_TEST_F(MapTest, ComplexNestedMaps) {
TEST_F(MapTest, MapBinaryAdder) {
// Maps (lambda (x y) (+ x y)) onto two R1F32 vectors.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal =
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
+ Literal param1_literal =
LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Map(&builder, {param0, param1}, CreateScalarAddComputation(F32, &builder),
{0});
@@ -365,18 +365,18 @@ TEST_F(MapTest, MapBinaryAdder) {
// for Map that used to fail in shape inference (b/28989438).
XLA_TEST_F(MapTest, AddWithMixedLayouts) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR2WithLayout(
+ Literal param0_literal = LiteralUtil::CreateR2WithLayout(
{{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0}));
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal = LiteralUtil::CreateR2WithLayout(
+ Literal param1_literal = LiteralUtil::CreateR2WithLayout(
{{10, 20}, {30, 40}}, LayoutUtil::MakeLayout({0, 1}));
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Map(&builder, {param0, param1}, CreateScalarAddComputation(S32, &builder),
{0, 1});
@@ -391,18 +391,18 @@ XLA_TEST_F(MapTest, AddWithMixedLayouts) {
XLA_TEST_F(MapTest, AddR3_3x0x2) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR3FromArray3D<int32>(Array3D<int32>(3, 0, 2));
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal =
+ Literal param1_literal =
LiteralUtil::CreateR3FromArray3D<int32>(Array3D<int32>(3, 0, 2));
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Map(&builder, {param0, param1}, CreateScalarAddComputation(S32, &builder),
{0, 1, 2});
@@ -413,22 +413,22 @@ XLA_TEST_F(MapTest, AddR3_3x0x2) {
TEST_F(MapTest, MapTernaryAdder) {
// Maps (lambda (x y z) (+ x y z)) onto three R1F32 vectors.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal =
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
+ Literal param1_literal =
LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param2_literal =
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
+ Literal param2_literal =
LiteralUtil::CreateR1<float>({-10.0f, -100.0f, -900.0f, -400.0f});
std::unique_ptr<GlobalData> param2_data =
- client_->TransferToServer(*param2_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param2_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
- auto param2 = Parameter(&builder, 2, param2_literal->shape(), "param2");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
+ auto param2 = Parameter(&builder, 2, param2_literal.shape(), "param2");
Map(&builder, {param0, param1, param2}, CreateTernaryAdder(), {0});
ComputeAndCompareR1<float>(
@@ -475,17 +475,17 @@ TEST_F(MapTest, MapOperantionWithBuildError) {
Add(x, y);
auto error_add = sub_builder->BuildAndNoteError();
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal =
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
+ Literal param1_literal =
LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Map(&builder, {param0, param1}, error_add, {0});
StatusOr<XlaComputation> computation_status = builder.Build();
@@ -513,15 +513,15 @@ TEST_F(MapTestWithFullOpt, MapScalarPower) {
Pow(x, y);
auto power = sub_builder->BuildAndNoteError();
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(2.0f);
- std::unique_ptr<Literal> param1_literal = LiteralUtil::CreateR0<float>(5.0f);
+ Literal param0_literal = LiteralUtil::CreateR0<float>(2.0f);
+ Literal param1_literal = LiteralUtil::CreateR0<float>(5.0f);
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Map(&builder, {param0, param1}, power, {});
ComputeAndCompareR0<float>(&builder, 32.0f,
@@ -540,15 +540,15 @@ TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) {
Sub(y, x); // note that this is y - x, not x - y
auto sub_opposite = sub_builder->BuildAndNoteError();
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(2.0f);
- std::unique_ptr<Literal> param1_literal = LiteralUtil::CreateR0<float>(5.0f);
+ Literal param0_literal = LiteralUtil::CreateR0<float>(2.0f);
+ Literal param1_literal = LiteralUtil::CreateR0<float>(5.0f);
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Map(&builder, {param0, param1}, sub_opposite, {});
ComputeAndCompareR0<float>(
@@ -565,11 +565,11 @@ TEST_F(MapTestWithFullOpt, MapSquare) {
Mul(x, x);
auto square = sub_builder->BuildAndNoteError();
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(10.0f);
+ Literal param0_literal = LiteralUtil::CreateR0<float>(10.0f);
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param0}, square, {});
ComputeAndCompareR0<float>(&builder, 100.0f, {param0_data.get()},
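All of the map_test.cc hunks above apply the same change: Literal is now a plain value type, so the unique_ptr wrapper, the dereference at TransferToServer, and the '->' shape accessor go away. A minimal fragment of the new parameter-setup pattern, taken from the hunks above (illustrative only, not part of the diff; assumes the ClientLibraryTestBase fixture that provides client_):

  XlaBuilder builder(TestName());
  Literal param0_literal =
      LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
  // TransferToServer now takes a const Literal& rather than a Literal*.
  std::unique_ptr<GlobalData> param0_data =
      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
  // The shape is read with '.' instead of '->'.
  auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");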
diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
index edb592f43e..3f278115e0 100644
--- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
@@ -63,11 +63,11 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, ExpTwoByTwoValues) {
});
Exp(data);
- std::unique_ptr<Literal> expected =
+ Literal expected =
LiteralUtil::CreateR2FromArray2D<T>({{2.71828f, 1.00000f}, // row 0
{0.36788f, 1.64872f}}); // row 1
- this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5));
+ this->ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-5));
}
XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) {
@@ -92,10 +92,10 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) {
});
Map(&builder, {data}, add_half, {0, 1});
- std::unique_ptr<Literal> expected =
+ Literal expected =
LiteralUtil::CreateR2FromArray2D<T>({{1.5f, 0.5f}, // row 0
{-0.5f, 1.0f}}); // row 1
- this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5));
+ this->ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-5));
}
XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) {
@@ -111,10 +111,10 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) {
});
Max(lhs, rhs);
- std::unique_ptr<Literal> expected =
+ Literal expected =
LiteralUtil::CreateR2FromArray2D<T>({{7.0f, 6.0f}, // row 0
{3.0f, -4.0f}}); // row 1
- this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6));
+ this->ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6));
}
struct TestLinspaceMaxParam {
@@ -200,14 +200,12 @@ class MatOpsDotAddTest
TF_ASSERT_OK_AND_ASSIGN(
auto lhs_handle,
- client_->TransferToServer(
- *LiteralUtil::CreateR2FromArray2DWithLayout<T>(
- lhs, LayoutUtil::MakeLayout(minor_to_major(row_major)))));
+ client_->TransferToServer(LiteralUtil::CreateR2FromArray2DWithLayout<T>(
+ lhs, LayoutUtil::MakeLayout(minor_to_major(row_major)))));
TF_ASSERT_OK_AND_ASSIGN(
auto rhs_handle,
- client_->TransferToServer(
- *LiteralUtil::CreateR2FromArray2DWithLayout<T>(
- rhs, LayoutUtil::MakeLayout(minor_to_major(row_major)))));
+ client_->TransferToServer(LiteralUtil::CreateR2FromArray2DWithLayout<T>(
+ rhs, LayoutUtil::MakeLayout(minor_to_major(row_major)))));
XlaBuilder builder(TestName());
auto lhs_arg = Parameter(&builder, 0, lhs_shape, "lhs");
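The matrix_ops_simple_test.cc hunks follow the same pattern for expected values: ComputeAndCompareLiteral takes the expected Literal by const reference, and literals created inline go to TransferToServer without dereferencing. Sketch from the hunks above (illustrative only; T is the typed-test element type):

  Literal expected =
      LiteralUtil::CreateR2FromArray2D<T>({{1.5f, 0.5f},     // row 0
                                           {-0.5f, 1.0f}});  // row 1
  // The expected literal is passed by const reference, not via '*expected'.
  this->ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-5));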
diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
index c5e0b9b097..56aaeb0e68 100644
--- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
@@ -114,10 +114,10 @@ class MultiOutputFusionTest : public HloTestBase {
Literal expect(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, size}));
expect.PopulateWithValue<float>(size * 1.5f * 3.5f);
+ Literal literal_r0 = LiteralUtil::CreateR0<float>(-9.0f);
auto actual =
- ExecuteAndTransfer(std::move(hlo_module),
- {LiteralUtil::CreateR0<float>(-9.0f).get(), &arg1});
- EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_));
+ ExecuteAndTransfer(std::move(hlo_module), {&literal_r0, &arg1});
+ EXPECT_TRUE(LiteralTestUtil::Near(expect, actual, error_spec_));
}
void RunTest1D(bool manual_fusion, int size) {
@@ -178,10 +178,9 @@ class MultiOutputFusionTest : public HloTestBase {
Literal input1(ShapeUtil::MakeShapeWithDescendingLayout(F64, {size}));
input1.PopulateWithValue(1.);
- Literal expect =
- std::move(*LiteralUtil::CreateR1<float>({size * 1.5f * 3.5f}));
+ Literal expect = LiteralUtil::CreateR1<float>({size * 1.5f * 3.5f});
auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1});
- EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(expect, actual, error_spec_));
}
};
@@ -218,10 +217,9 @@ XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) {
LiteralUtil::CreateR0<float>(1.0)),
LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<float>(3.0),
LiteralUtil::CreateR0<int32>(4)));
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<int32>(42)), *result));
+ LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<int32>(42)), result));
}
XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) {
@@ -247,9 +245,8 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) {
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
auto param = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, -1.0});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
- LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0, 1.0}, *result);
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
+ LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0, 1.0}, result);
}
XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) {
@@ -280,9 +277,8 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) {
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
auto param = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
- LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0}, *result);
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
+ LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0}, result);
}
const char* const kScalarOps = R"(
@@ -324,13 +320,12 @@ XLA_TEST_F(MultiOutputFusionTest,
.ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}})),
- *result));
+ result));
}
XLA_TEST_F(MultiOutputFusionTest,
@@ -356,13 +351,12 @@ XLA_TEST_F(MultiOutputFusionTest,
.ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR2<float>({{6, 8}, {10, 12}}),
LiteralUtil::CreateR2<float>({{25, 36}, {49, 64}})),
- *result));
+ result));
}
XLA_TEST_F(MultiOutputFusionTest,
@@ -389,13 +383,12 @@ XLA_TEST_F(MultiOutputFusionTest,
.ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({14, 22}),
- LiteralUtil::CreateR1<float>({36, 64}),
- LiteralUtil::CreateR1<float>({66, 138})),
- *result));
+ LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({14, 22}),
+ LiteralUtil::CreateR1<float>({36, 64}),
+ LiteralUtil::CreateR1<float>({66, 138})),
+ result));
}
XLA_TEST_F(MultiOutputFusionTest,
@@ -422,14 +415,13 @@ XLA_TEST_F(MultiOutputFusionTest,
.ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}),
LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}})),
- *result));
+ result));
}
XLA_TEST_F(MultiOutputFusionTest,
@@ -456,15 +448,14 @@ XLA_TEST_F(MultiOutputFusionTest,
.ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR2<float>({{6, 8}, {10, 12}}),
LiteralUtil::CreateR3<float>(
{{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}),
LiteralUtil::CreateR2<float>({{25, 36}, {49, 64}})),
- *result));
+ result));
}
XLA_TEST_F(MultiOutputFusionTest,
@@ -492,16 +483,15 @@ XLA_TEST_F(MultiOutputFusionTest,
.ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR1<float>({14, 22}),
LiteralUtil::CreateR3<float>(
{{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}),
LiteralUtil::CreateR3<float>(
{{{5, 10}, {15, 20}}, {{25, 30}, {35, 40}}})),
- *result));
+ result));
}
XLA_TEST_F(MultiOutputFusionTest,
@@ -530,13 +520,13 @@ XLA_TEST_F(MultiOutputFusionTest,
LiteralUtil::CreateR3<float>({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
auto init1 = LiteralUtil::CreateR0<float>(5);
auto init2 = LiteralUtil::CreateR0<float>(6);
- std::unique_ptr<Literal> result = ExecuteNoHloPasses(
- std::move(module), {param.get(), init1.get(), init2.get()});
+ Literal result =
+ ExecuteNoHloPasses(std::move(module), {&param, &init1, &init2});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR2<float>({{167, 172}, {176, 180}}),
LiteralUtil::CreateR2<float>({{6, 6}, {6, 8}})),
- *result));
+ result));
}
XLA_TEST_F(MultiOutputFusionTest,
@@ -565,10 +555,9 @@ XLA_TEST_F(MultiOutputFusionTest,
auto param = LiteralUtil::CreateR3<Eigen::half>(
{{{Eigen::half(1), Eigen::half(2)}, {Eigen::half(3), Eigen::half(4)}},
{{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}}),
LiteralUtil::CreateR3<Eigen::half>(
@@ -576,7 +565,7 @@ XLA_TEST_F(MultiOutputFusionTest,
{Eigen::half(3), Eigen::half(4)}},
{{Eigen::half(5), Eigen::half(6)},
{Eigen::half(7), Eigen::half(8)}}})),
- *result));
+ result));
}
} // namespace
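In multioutput_fusion_test.cc the same change shows up on the execution side: ExecuteNoHloPasses returns a Literal by value and takes its arguments as const Literal*, so value literals are passed by address and results are compared without dereferencing. Fragment from the hunks above (illustrative only; module is built from the testcase string as shown in those hunks):

  auto param = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, -1.0});
  // Arguments are now passed as {&param} instead of {param.get()}.
  Literal result = ExecuteNoHloPasses(std::move(module), {&param});
  // LiteralTestUtil takes const Literal& on both sides.
  LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0, 1.0}, result);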
diff --git a/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc b/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc
index 0a0426adcb..f2460822a6 100644
--- a/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc
+++ b/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc
@@ -70,7 +70,7 @@ XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInWhile) {
GetTupleElement(result_tuple, 0);
TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, b.Build());
- std::unique_ptr<xla::Literal> comp_result;
+ Literal comp_result;
std::unique_ptr<tensorflow::Thread> thread(
tensorflow::Env::Default()->StartThread(
tensorflow::ThreadOptions(), "execute_thread", [&] {
@@ -81,41 +81,41 @@ XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInWhile) {
VLOG(1) << "Transferring trip count to computation";
// Transfer number of iterations to Infeed.
TF_ASSERT_OK(
- local_client_->TransferToInfeed(*LiteralUtil::CreateR0<int32_t>(1)));
+ local_client_->TransferToInfeed(LiteralUtil::CreateR0<int32_t>(1)));
// Pick up value from outfeed
{
VLOG(1) << "Reading from condition outfeed";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> r,
+ TF_ASSERT_OK_AND_ASSIGN(Literal r,
local_client_->TransferFromOutfeed(&int_shape));
- EXPECT_EQ(r->Get<int32>({}), 1);
+ EXPECT_EQ(r.Get<int32>({}), 1);
}
VLOG(1) << "Writing data to infeed";
// Transfer some stuff to Infeed for use inside of loop.
TF_ASSERT_OK(local_client_->TransferToInfeed(
- *LiteralUtil::CreateR1<int32_t>({10, 20})));
+ LiteralUtil::CreateR1<int32_t>({10, 20})));
// Pick up value from outfeed
{
VLOG(1) << "Reading from body outfeed";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> r,
+ TF_ASSERT_OK_AND_ASSIGN(Literal r,
local_client_->TransferFromOutfeed(&xfeed_shape));
- EXPECT_EQ(r->Get<int32>({0}), 11);
- EXPECT_EQ(r->Get<int32>({1}), 21);
+ EXPECT_EQ(r.Get<int32>({0}), 11);
+ EXPECT_EQ(r.Get<int32>({1}), 21);
}
{
VLOG(1) << "Reading from condition outfeed";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> r,
+ TF_ASSERT_OK_AND_ASSIGN(Literal r,
local_client_->TransferFromOutfeed(&int_shape));
- EXPECT_EQ(r->Get<int32>({}), 0);
+ EXPECT_EQ(r.Get<int32>({}), 0);
}
// Joins the thread
thread.reset();
- EXPECT_EQ(comp_result->Get<int32>({}), 0);
+ EXPECT_EQ(comp_result.Get<int32>({}), 0);
}
XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInConditional) {
@@ -145,7 +145,7 @@ XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInConditional) {
TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, b.Build());
- std::unique_ptr<xla::Literal> comp_result;
+ Literal comp_result;
std::unique_ptr<tensorflow::Thread> thread(
tensorflow::Env::Default()->StartThread(
tensorflow::ThreadOptions(), "execute_thread", [&] {
@@ -154,12 +154,12 @@ XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInConditional) {
}));
TF_ASSERT_OK(
- local_client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(true)));
+ local_client_->TransferToInfeed(LiteralUtil::CreateR0<bool>(true)));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> r,
+ TF_ASSERT_OK_AND_ASSIGN(Literal r,
local_client_->TransferFromOutfeed(&result_shape));
- EXPECT_EQ(r->Get<bool>({}), true);
+ EXPECT_EQ(r.Get<bool>({}), true);
// Join the thread
thread.reset();
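outfeed_in_nested_computation_test.cc shows the feed APIs after the change: TransferToInfeed accepts the literal by const reference, and TF_ASSERT_OK_AND_ASSIGN binds the outfeed result directly to a Literal value. Fragment from the hunks above (illustrative only; result_shape and local_client_ as defined in that test):

  TF_ASSERT_OK(
      local_client_->TransferToInfeed(LiteralUtil::CreateR0<bool>(true)));
  // The outfeed result is a Literal value, read with '.' accessors.
  TF_ASSERT_OK_AND_ASSIGN(Literal r,
                          local_client_->TransferFromOutfeed(&result_shape));
  EXPECT_EQ(r.Get<bool>({}), true);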
diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc
index cbeddffacf..6e98167739 100644
--- a/tensorflow/compiler/xla/tests/pad_test.cc
+++ b/tensorflow/compiler/xla/tests/pad_test.cc
@@ -93,8 +93,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS0Array) {
dimension->set_edge_padding_high(0);
dimension->set_interior_padding(0);
- Pad(AddParam(*LiteralUtil::CreateR1<float>({}), &b),
- AddParam(*LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
+ Pad(AddParam(LiteralUtil::CreateR1<float>({}), &b),
+ AddParam(LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
ComputeAndCompareR1<float>(&b, {}, {}, DefaultErrorSpec());
}
@@ -108,8 +108,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS5Array) {
dimension->set_edge_padding_high(4);
dimension->set_interior_padding(7);
- Pad(AddParam(*LiteralUtil::CreateR1<float>({}), &b),
- AddParam(*LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
+ Pad(AddParam(LiteralUtil::CreateR1<float>({}), &b),
+ AddParam(LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
ComputeAndCompareR1<float>(&b, std::vector<float>(5, 0.1), {},
DefaultErrorSpec());
}
@@ -123,8 +123,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS3Array) {
dimension->set_edge_padding_high(0);
dimension->set_interior_padding(1);
- Pad(AddParam(*LiteralUtil::CreateR1<float>({1, 2, 3}), &b),
- AddParam(*LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
+ Pad(AddParam(LiteralUtil::CreateR1<float>({1, 2, 3}), &b),
+ AddParam(LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
std::vector<float> expected({0.1, 0.1, 0.1, 1, 0.1, 2, 0.1, 3});
ComputeAndCompareR1<float>(&b, expected, {}, DefaultErrorSpec());
}
@@ -132,7 +132,7 @@ XLA_TEST_P(PadTestFloat, Pad1DS3Array) {
XLA_TEST_P(PadTestFloat, Pad4D_2x0x3x2_FloatArray) {
XlaBuilder b(TestName());
Pad(AddParam(Array4D<float>(2, 0, 3, 2), &b),
- AddParam(*LiteralUtil::CreateR0<float>(1.5), &b),
+ AddParam(LiteralUtil::CreateR0<float>(1.5), &b),
r4_padding_on_dim0_dim1_);
ComputeAndCompareR4<float>(&b, Array4D<float>(5, 2, 3, 2, 1.5f), {},
DefaultErrorSpec());
@@ -148,7 +148,7 @@ TEST_P(PadTestFloat, Pad4DFloat_1x1x3x2_Array) {
});
input->FillWithYX(input_xy);
- Pad(AddParam(*input, &b), AddParam(*LiteralUtil::CreateR0<float>(1.5), &b),
+ Pad(AddParam(*input, &b), AddParam(LiteralUtil::CreateR0<float>(1.5), &b),
r4_padding_on_dim0_dim1_);
auto expected = absl::make_unique<Array4D<float>>(2, 3, 3, 2);
@@ -168,7 +168,7 @@ TEST_P(PadTestFloat, Pad4DFloatArrayWithInteriorPadding) {
const float pad_value = 1.5f;
Array4D<float> input(3, 2, 1, 1, {1, 2, 3, 4, 5, 6});
Pad(AddParam(input, &b),
- AddParam(*LiteralUtil::CreateR0<float>(pad_value), &b),
+ AddParam(LiteralUtil::CreateR0<float>(pad_value), &b),
r4_padding_on_dim0_dim1_);
auto expected = absl::make_unique<Array4D<float>>(8, 5, 1, 1);
@@ -208,10 +208,10 @@ TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstSmall) {
const float pad_value = -5.123f;
Array4D<float> input_array(1, 1, 2, 3, {1, 2, 3, 4, 5, 6});
auto input = LiteralUtil::CreateR4FromArray4D<float>(input_array);
- input = input->Relayout(layout);
+ input = input.Relayout(layout);
- Pad(AddParam(*input, &b),
- AddParam(*LiteralUtil::CreateR0<float>(pad_value), &b), padding_config);
+ Pad(AddParam(input, &b),
+ AddParam(LiteralUtil::CreateR0<float>(pad_value), &b), padding_config);
Array4D<float> expected_array(1, 1, 5, 8);
expected_array.Fill(pad_value);
@@ -254,10 +254,10 @@ XLA_TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) {
input_array(0, 24, 6, 6) = 2.0f;
input_array(0, 17, 2, 5) = 3.0f;
auto input = LiteralUtil::CreateR4FromArray4D<float>(input_array);
- input = input->Relayout(layout);
+ input = input.Relayout(layout);
- Pad(AddParam(*input, &b),
- AddParam(*LiteralUtil::CreateR0<float>(pad_value), &b), padding_config);
+ Pad(AddParam(input, &b),
+ AddParam(LiteralUtil::CreateR0<float>(pad_value), &b), padding_config);
Array4D<float> expected_array(1, 25, 17, 11);
expected_array.Fill(pad_value);
@@ -331,7 +331,7 @@ XLA_TEST_P(PadTestFloat, Large2DPad) {
padding_config.mutable_dimensions(dim)->set_edge_padding_high(58 +
100 * dim);
}
- Pad(input, AddParam(*LiteralUtil::CreateR0<float>(0.0f), &b), padding_config);
+ Pad(input, AddParam(LiteralUtil::CreateR0<float>(0.0f), &b), padding_config);
auto expected = ReferenceUtil::PadArray2D(*ones, padding_config, 0.0f);
ComputeAndCompareR2<float>(&b, *expected, {}, DefaultErrorSpec());
@@ -353,8 +353,7 @@ XLA_TEST_P(PadTestFloat, AllTypes2DPad) {
padding_config.mutable_dimensions(1)->set_edge_padding_low(6);
padding_config.mutable_dimensions(1)->set_edge_padding_high(4);
padding_config.mutable_dimensions(1)->set_interior_padding(2);
- Pad(input, AddParam(*LiteralUtil::CreateR0<float>(3.14f), &b),
- padding_config);
+ Pad(input, AddParam(LiteralUtil::CreateR0<float>(3.14f), &b), padding_config);
auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 3.14f);
ComputeAndCompareR2<float>(&b, *expected, {}, DefaultErrorSpec());
@@ -379,7 +378,7 @@ XLA_TEST_P(PadTestFloat, High2DPad) {
padding_config.mutable_dimensions(dim)->set_interior_padding(
interior_padding);
}
- Pad(input, AddParam(*LiteralUtil::CreateR0<float>(2.718f), &b),
+ Pad(input, AddParam(LiteralUtil::CreateR0<float>(2.718f), &b),
padding_config);
auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f);
@@ -407,7 +406,7 @@ XLA_TEST_P(PadTestFloat, NegativePadding2D) {
padding_config.mutable_dimensions(dim)->set_interior_padding(
interior_padding);
}
- Pad(input, AddParam(*LiteralUtil::CreateR0<float>(2.718f), &b),
+ Pad(input, AddParam(LiteralUtil::CreateR0<float>(2.718f), &b),
padding_config);
auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f);
@@ -435,7 +434,7 @@ XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) {
padding_config.mutable_dimensions(dim)->set_interior_padding(
interior_padding[dim]);
}
- Pad(input, AddParam(*LiteralUtil::CreateR0<float>(2.718f), &b),
+ Pad(input, AddParam(LiteralUtil::CreateR0<float>(2.718f), &b),
padding_config);
auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f);
@@ -452,13 +451,12 @@ XLA_TEST_P(PadTestFloat, ReducePad) {
XlaComputation add = CreateScalarAddComputation(FloatType(), &b);
auto reduce =
- Reduce(input, AddParam(*LiteralUtil::CreateR0<float>(0.0), &b), add, {0});
+ Reduce(input, AddParam(LiteralUtil::CreateR0<float>(0.0), &b), add, {0});
PaddingConfig padding_config = MakeNoPaddingConfig(3);
padding_config.mutable_dimensions(0)->set_edge_padding_low(1);
padding_config.mutable_dimensions(0)->set_edge_padding_high(1);
- Pad(reduce, AddParam(*LiteralUtil::CreateR0<float>(0.0f), &b),
- padding_config);
+ Pad(reduce, AddParam(LiteralUtil::CreateR0<float>(0.0f), &b), padding_config);
Array3D<float> expected({{{0.0, 0.0}, {0.0, 0.0}},
{{2.0, 2.0}, {2.0, 2.0}},
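In pad_test.cc the change is visible mainly at AddParam, which now takes a const Literal&, so temporaries from LiteralUtil::CreateR0/CreateR1 can be passed directly. Fragment from the hunks above (illustrative only; b and padding_config as set up in Pad1DS3Array):

  // No '*' on the CreateR1/CreateR0 temporaries any more.
  Pad(AddParam(LiteralUtil::CreateR1<float>({1, 2, 3}), &b),
      AddParam(LiteralUtil::CreateR0<float>(0.1), &b), padding_config);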
diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc
index f6c762e7a4..dcb4c11c3c 100644
--- a/tensorflow/compiler/xla/tests/params_test.cc
+++ b/tensorflow/compiler/xla/tests/params_test.cc
@@ -42,10 +42,9 @@ class ParamsTest : public ClientLibraryTestBase {};
XLA_TEST_F(ParamsTest, ConstantR0F32Param) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
- LiteralUtil::CreateR0<float>(3.14159f);
+ Literal param0_literal = LiteralUtil::CreateR0<float>(3.14159f);
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "param0");
@@ -55,9 +54,9 @@ XLA_TEST_F(ParamsTest, ConstantR0F32Param) {
XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1<float>({});
+ Literal param0_literal = LiteralUtil::CreateR1<float>({});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {0}), "param0");
@@ -67,10 +66,9 @@ XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) {
XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
- LiteralUtil::CreateR1<float>({3.14f, -100.25f});
+ Literal param0_literal = LiteralUtil::CreateR1<float>({3.14f, -100.25f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0");
@@ -81,9 +79,9 @@ XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) {
XLA_TEST_F(ParamsTest, ConstantR1U8Param) {
XlaBuilder builder(TestName());
string str("hello world");
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1U8(str);
+ Literal param0_literal = LiteralUtil::CreateR1U8(str);
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
Parameter(&builder, 0,
ShapeUtil::MakeShape(U8, {static_cast<int64>(str.size())}),
@@ -94,10 +92,10 @@ XLA_TEST_F(ParamsTest, ConstantR1U8Param) {
XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR2FromArray2D<float>(Array2D<float>(3, 0));
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 0}), "param0");
@@ -107,10 +105,10 @@ XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) {
XLA_TEST_F(ParamsTest, ConstantR2F32Param) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR2<float>(
+ Literal param0_literal = LiteralUtil::CreateR2<float>(
{{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 2}), "param0");
@@ -123,15 +121,15 @@ XLA_TEST_F(ParamsTest, ConstantR2F32Param) {
XLA_TEST_F(ParamsTest, TwoParameters) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>({1, 2});
+ Literal literal0 = LiteralUtil::CreateR1<float>({1, 2});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, literal0->shape(), "param0");
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
+ auto param0 = Parameter(&builder, 0, literal0.shape(), "param0");
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>({10, 20});
+ Literal literal1 = LiteralUtil::CreateR1<float>({10, 20});
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
- auto param1 = Parameter(&builder, 1, literal1->shape(), "param1");
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
+ auto param1 = Parameter(&builder, 1, literal1.shape(), "param1");
// Use both parameters
//
@@ -154,9 +152,9 @@ XLA_TEST_F(ParamsTest, TwoParameters) {
XLA_TEST_F(ParamsTest, MissingParameter) {
// Test that an error is returned when a computation with an incomplete set of
// parameters (parameter numbers not contiguous from 0) is executed.
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<float>(3.14159f);
+ Literal literal = LiteralUtil::CreateR0<float>(3.14159f);
std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*literal).ConsumeValueOrDie();
+ client_->TransferToServer(literal).ConsumeValueOrDie();
XlaBuilder builder(TestName());
Parameter(&builder, 2, ShapeUtil::MakeShape(F32, {}), "param2");
@@ -168,15 +166,15 @@ XLA_TEST_F(ParamsTest, MissingParameter) {
XLA_TEST_F(ParamsTest, UnusedParameter) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>({1, 2});
+ Literal literal0 = LiteralUtil::CreateR1<float>({1, 2});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
- Parameter(&builder, 0, literal0->shape(), "param0");
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
+ Parameter(&builder, 0, literal0.shape(), "param0");
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>({10, 20});
+ Literal literal1 = LiteralUtil::CreateR1<float>({10, 20});
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
- Parameter(&builder, 1, literal1->shape(), "param1");
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
+ Parameter(&builder, 1, literal1.shape(), "param1");
ComputeAndCompareR1<float>(&builder, {10, 20},
{param0_data.get(), param1_data.get()},
@@ -188,18 +186,17 @@ XLA_TEST_F(ParamsTest, UnusedParametersInUnusedExpression) {
// unused expression.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>({1, 2});
+ Literal literal0 = LiteralUtil::CreateR1<float>({1, 2});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 =
- LiteralUtil::CreateR1<float>({10, 20, 30});
+ Literal literal1 = LiteralUtil::CreateR1<float>({10, 20, 30});
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&builder, 1, literal1->shape(), "param1");
- auto param2 = Parameter(&builder, 2, literal1->shape(), "param2");
+ auto param0 = Parameter(&builder, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, literal1.shape(), "param1");
+ auto param2 = Parameter(&builder, 2, literal1.shape(), "param2");
// This add is unused.
Add(param1, param2);
@@ -233,10 +230,10 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) {
std::vector<float> sum_value = {{entry0, entry1}};
sum_value.resize(size);
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<float>(sum_value);
+ Literal literal = LiteralUtil::CreateR1<float>(sum_value);
param_data_owner.push_back(
- client_->TransferToServer(*literal).ConsumeValueOrDie());
- XlaOp param = Parameter(&builder, i, literal->shape(), "param");
+ client_->TransferToServer(literal).ConsumeValueOrDie());
+ XlaOp param = Parameter(&builder, i, literal.shape(), "param");
sum_handle = Add(sum_handle, param);
}
@@ -268,10 +265,10 @@ XLA_TEST_F(ParamsTest,
constexpr int kParamCount = 3000;
for (int i = 0; i < kParamCount; ++i) {
target += i;
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<float>(i);
+ Literal literal = LiteralUtil::CreateR0<float>(i);
param_data_owner.push_back(
- std::move(client_->TransferToServer(*literal)).ValueOrDie());
- XlaOp param = Parameter(&builder, i, literal->shape(), "param");
+ std::move(client_->TransferToServer(literal)).ValueOrDie());
+ XlaOp param = Parameter(&builder, i, literal.shape(), "param");
sum_handle = Add(sum_handle, param);
}
@@ -300,10 +297,10 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU(
std::vector<XlaOp> params;
for (int i = 0; i < kParamCount; ++i) {
target += i;
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<int32>({i, i});
+ Literal literal = LiteralUtil::CreateR1<int32>({i, i});
param_data_owner.push_back(
- std::move(client_->TransferToServer(*literal)).ValueOrDie());
- XlaOp param = Parameter(&builder, i, literal->shape(), "param");
+ std::move(client_->TransferToServer(literal)).ValueOrDie());
+ XlaOp param = Parameter(&builder, i, literal.shape(), "param");
params.push_back(param);
sum_handle = Add(sum_handle, param);
}
@@ -321,13 +318,14 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU(
param_data.push_back(data.get());
}
- std::vector<std::unique_ptr<Literal>> elements;
+ std::vector<Literal> elements;
std::vector<const Literal*> ptrs;
+ elements.reserve(kParamCount);
for (int i = 0; i < kParamCount; ++i) {
elements.push_back(LiteralUtil::CreateR1<int32>({target + i, target + i}));
- ptrs.push_back(elements.back().get());
+ ptrs.push_back(&elements.back());
}
- ComputeAndCompareTuple(&builder, *LiteralUtil::MakeTuple(ptrs), param_data);
+ ComputeAndCompareTuple(&builder, LiteralUtil::MakeTuple(ptrs), param_data);
}
// Test large number of parameters flowing into a while-loop.
@@ -356,23 +354,23 @@ XLA_TEST_F(ParamsTest,
std::vector<XlaOp> params;
std::vector<Shape> parameter_shapes;
for (int i = 0; i < kParamCount; ++i) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<int32>({i, i});
+ Literal literal = LiteralUtil::CreateR1<int32>({i, i});
param_data_owner.push_back(
- std::move(client_->TransferToServer(*literal)).ValueOrDie());
- XlaOp param = Parameter(&builder, i, literal->shape(), "param");
+ std::move(client_->TransferToServer(literal)).ValueOrDie());
+ XlaOp param = Parameter(&builder, i, literal.shape(), "param");
params.push_back(param);
- parameter_shapes.push_back(literal->shape());
+ parameter_shapes.push_back(literal.shape());
}
// Add bool parameter for the loop condition. Use a parameter HLO instead of a
// constant because DCE may eliminate the while-body otherwise.
- std::unique_ptr<Literal> bool_literal = LiteralUtil::CreateR0<bool>(false);
+ Literal bool_literal = LiteralUtil::CreateR0<bool>(false);
param_data_owner.push_back(
- std::move(client_->TransferToServer(*bool_literal)).ValueOrDie());
+ std::move(client_->TransferToServer(bool_literal)).ValueOrDie());
XlaOp bool_param =
- Parameter(&builder, kParamCount, bool_literal->shape(), "bool_param");
+ Parameter(&builder, kParamCount, bool_literal.shape(), "bool_param");
params.push_back(bool_param);
- parameter_shapes.push_back(bool_literal->shape());
+ parameter_shapes.push_back(bool_literal.shape());
auto init = Tuple(&builder, params);
@@ -420,13 +418,14 @@ XLA_TEST_F(ParamsTest,
param_data.push_back(data.get());
}
- std::vector<std::unique_ptr<Literal>> elements;
+ std::vector<Literal> elements;
std::vector<const Literal*> ptrs;
+ elements.reserve(kParamCount);
for (int i = 0; i < kParamCount; ++i) {
elements.push_back(LiteralUtil::CreateR1<int32>({i, i}));
- ptrs.push_back(elements.back().get());
+ ptrs.push_back(&elements.back());
}
- ComputeAndCompareTuple(&builder, *LiteralUtil::MakeTuple(ptrs), param_data);
+ ComputeAndCompareTuple(&builder, LiteralUtil::MakeTuple(ptrs), param_data);
}
#endif
@@ -443,9 +442,9 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) {
std::unique_ptr<GlobalData> data =
client_
- ->TransferToServer(*LiteralUtil::MakeTuple({
- LiteralUtil::CreateR1<float>({1, 2, 3}).get(),
- LiteralUtil::CreateR1<float>({4, 5, 6}).get(),
+ ->TransferToServer(LiteralUtil::MakeTupleFromSlices({
+ LiteralUtil::CreateR1<float>({1, 2, 3}),
+ LiteralUtil::CreateR1<float>({4, 5, 6}),
}))
.ConsumeValueOrDie();
@@ -457,34 +456,34 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) {
// Verifies that passing a 2x2 with {0, 1} layout returns the same value back
// when (transferred to the server and) passed through a parameter.
XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR2WithLayout<float>(
+ Literal literal = LiteralUtil::CreateR2WithLayout<float>(
{{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1}));
XlaBuilder builder(TestName());
- Parameter(&builder, 0, literal->shape(), "input");
+ Parameter(&builder, 0, literal.shape(), "input");
std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*literal).ConsumeValueOrDie();
- ComputeAndCompareLiteral(&builder, *literal, {data.get()}, ErrorSpec(1e-3));
+ client_->TransferToServer(literal).ConsumeValueOrDie();
+ ComputeAndCompareLiteral(&builder, literal, {data.get()}, ErrorSpec(1e-3));
}
// As above, but for {1, 0} layout.
XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR2WithLayout<float>(
+ Literal literal = LiteralUtil::CreateR2WithLayout<float>(
{{1, 3}, {2, 4}}, LayoutUtil::MakeLayout({1, 0}));
XlaBuilder builder(TestName());
- Parameter(&builder, 0, literal->shape(), "input");
+ Parameter(&builder, 0, literal.shape(), "input");
std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*literal).ConsumeValueOrDie();
- ComputeAndCompareLiteral(&builder, *literal, {data.get()}, ErrorSpec(1e-3));
+ client_->TransferToServer(literal).ConsumeValueOrDie();
+ ComputeAndCompareLiteral(&builder, literal, {data.get()}, ErrorSpec(1e-3));
}
XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR2<float>({
+ Literal literal = LiteralUtil::CreateR2<float>({
{1, 3},
{2, 4},
});
- const Shape original = literal->shape();
+ const Shape original = literal.shape();
{
// Reverse the layout present in original, and make that the layout of the
// literal.
@@ -492,9 +491,9 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) {
original.layout().minor_to_major().begin(),
original.layout().minor_to_major().end());
std::reverse(original_layout.begin(), original_layout.end());
- *literal->mutable_shape_do_not_use()->mutable_layout() =
+ *literal.mutable_shape_do_not_use()->mutable_layout() =
LayoutUtil::MakeLayout(original_layout);
- ASSERT_EQ(2, literal->Get<float>({0, 1}));
+ ASSERT_EQ(2, literal.Get<float>({0, 1}));
}
// Use the original shape in building the computation.
XlaBuilder builder(TestName());
@@ -503,7 +502,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) {
Slice(input, {0, 1}, {1, 2}, {1, 1});
std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*literal).ConsumeValueOrDie();
+ client_->TransferToServer(literal).ConsumeValueOrDie();
// Check that we got the off-diagonal value that we expected.
Array2D<float> expected(1, 1);
expected(0, 0) = 2;
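The params_test.cc hunks also switch the tuple-building code to std::vector<Literal>. Because the pointers handed to LiteralUtil::MakeTuple now point into the vector itself, the newly added reserve() call matters: it prevents reallocation from invalidating &elements.back(). Fragment from the hunks above (illustrative only; kParamCount, builder and param_data as in that test):

  std::vector<Literal> elements;
  std::vector<const Literal*> ptrs;
  // reserve() keeps the vector from reallocating, so the addresses taken
  // below stay valid until MakeTuple consumes them.
  elements.reserve(kParamCount);
  for (int i = 0; i < kParamCount; ++i) {
    elements.push_back(LiteralUtil::CreateR1<int32>({i, i}));
    ptrs.push_back(&elements.back());
  }
  ComputeAndCompareTuple(&builder, LiteralUtil::MakeTuple(ptrs), param_data);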
diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc
index 5f322b768d..8f2c26f0ee 100644
--- a/tensorflow/compiler/xla/tests/prng_test.cc
+++ b/tensorflow/compiler/xla/tests/prng_test.cc
@@ -37,8 +37,7 @@ namespace {
class PrngTest : public ClientLibraryTestBase {
protected:
template <typename T>
- std::unique_ptr<Literal> UniformTest(T a, T b, absl::Span<const int64> dims,
- int64 seed = 42);
+ Literal UniformTest(T a, T b, absl::Span<const int64> dims, int64 seed = 42);
// Computes the χ² statistic of a sample of the discrete uniform distribution
// of the given range size. `expected_count` is the number of times each
@@ -49,9 +48,8 @@ class PrngTest : public ClientLibraryTestBase {
};
template <typename T>
-std::unique_ptr<Literal> PrngTest::UniformTest(T a, T b,
- absl::Span<const int64> dims,
- int64 seed) {
+Literal PrngTest::UniformTest(T a, T b, absl::Span<const int64> dims,
+ int64 seed) {
XlaBuilder builder(TestName());
RngUniform(
ConstantR0<T>(&builder, a), ConstantR0<T>(&builder, b),
@@ -60,8 +58,8 @@ std::unique_ptr<Literal> PrngTest::UniformTest(T a, T b,
SetSeed(seed);
auto actual =
ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie();
- EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions()));
- actual->EachCell<T>([=](absl::Span<const int64>, T value) {
+ EXPECT_THAT(dims, ::testing::ElementsAreArray(actual.shape().dimensions()));
+ actual.EachCell<T>([=](absl::Span<const int64>, T value) {
EXPECT_LE(a, value);
EXPECT_LT(value, b);
});
@@ -116,11 +114,10 @@ XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16CountTests))) {
constexpr int64 count = 100;
for (int64 seed = 0; seed < count; ++seed) {
auto result = UniformTest<bfloat16>(low, high, {}, /*seed=*/seed);
- result->Literal::EachCell<bfloat16>(
- [&](absl::Span<const int64>, bfloat16 value) {
- int64 index = static_cast<int64>((value - low) / interval);
- counts[index]++;
- });
+ result.EachCell<bfloat16>([&](absl::Span<const int64>, bfloat16 value) {
+ int64 index = static_cast<int64>((value - low) / interval);
+ counts[index]++;
+ });
}
// Each bucket should have similar amount of counts. That is, not more than
// 10% of total counts. This mostly tests that we don't fall into a 1:2:2
@@ -149,7 +146,7 @@ double PrngTest::UniformChiSquared(int32 range_size, int32 expected_count,
auto actual =
ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie();
std::vector<int32> counts(range_size, 0);
- actual->EachCell<int32>(
+ actual.EachCell<int32>(
[&counts](absl::Span<const int64>, int32 value) { ++counts[value]; });
int64 sum = 0;
for (int32 i = 0; i < range_size; ++i) {
@@ -192,12 +189,12 @@ XLA_TEST_F(PrngTest, MapUsingRng) {
};
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 5.3f, 4.4f, 5.5f});
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> param0_data,
- client_->TransferToServer(*param0_literal));
+ client_->TransferToServer(param0_literal));
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
auto fn = build_sum_rng(builder);
Map(&builder, {param0}, fn, {0});
@@ -210,12 +207,11 @@ XLA_TEST_F(PrngTest, MapUsingRng) {
computation,
/*arguments=*/{param0_data.get()}, &execution_options));
- EXPECT_EQ(ShapeUtil::ElementsIn(actual->shape()),
- ShapeUtil::ElementsIn(param0_literal->shape()));
- for (int i = 0; i < ShapeUtil::ElementsIn(actual->shape()); ++i) {
- EXPECT_GE(actual->data<float>()[i], param0_literal->data<float>()[i]);
- EXPECT_LT(actual->data<float>()[i],
- param0_literal->data<float>()[i] + 1.0f);
+ EXPECT_EQ(ShapeUtil::ElementsIn(actual.shape()),
+ ShapeUtil::ElementsIn(param0_literal.shape()));
+ for (int i = 0; i < ShapeUtil::ElementsIn(actual.shape()); ++i) {
+ EXPECT_GE(actual.data<float>()[i], param0_literal.data<float>()[i]);
+ EXPECT_LT(actual.data<float>()[i], param0_literal.data<float>()[i] + 1.0f);
}
}
@@ -238,15 +234,15 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) {
ExecutionOptions execution_options2 = execution_options_;
execution_options2.set_seed(65);
- std::unique_ptr<Literal> result1;
+ Literal result1;
{
TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation());
TF_ASSERT_OK_AND_ASSIGN(
result1, client_->ExecuteAndTransfer(computation, /*arguments=*/{},
&execution_options1));
}
- std::unique_ptr<Literal> result2;
- std::unique_ptr<Literal> result3;
+ Literal result2;
+ Literal result3;
{
TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation());
TF_ASSERT_OK_AND_ASSIGN(
@@ -257,9 +253,9 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) {
&execution_options1));
}
- std::unique_ptr<Literal> result4;
- std::unique_ptr<Literal> result5;
- std::unique_ptr<Literal> result6;
+ Literal result4;
+ Literal result5;
+ Literal result6;
{
TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation());
TF_ASSERT_OK_AND_ASSIGN(
@@ -273,11 +269,11 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) {
&execution_options_));
}
- EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result2));
- EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result3));
- EXPECT_FALSE(LiteralTestUtil::Equal(*result1, *result4));
- EXPECT_FALSE(LiteralTestUtil::Equal(*result4, *result5));
- EXPECT_FALSE(LiteralTestUtil::Equal(*result5, *result6));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result1, result2));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result1, result3));
+ EXPECT_FALSE(LiteralTestUtil::Equal(result1, result4));
+ EXPECT_FALSE(LiteralTestUtil::Equal(result4, result5));
+ EXPECT_FALSE(LiteralTestUtil::Equal(result5, result6));
}
XLA_TEST_F(PrngTest, TenValuesN01) {
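prng_test.cc shows the same change on helper signatures: UniformTest now returns the Literal by value, and callers use '.' accessors such as EachCell on the result. Fragment from the hunks above (illustrative only; low, interval and counts as defined in the bucket-counting test):

  // Declaration inside PrngTest: the unique_ptr wrapper is gone.
  template <typename T>
  Literal UniformTest(T a, T b, absl::Span<const int64> dims, int64 seed = 42);

  // At a call site, the returned value is used directly.
  auto result = UniformTest<bfloat16>(low, high, {}, /*seed=*/seed);
  result.EachCell<bfloat16>([&](absl::Span<const int64>, bfloat16 value) {
    int64 index = static_cast<int64>((value - low) / interval);
    counts[index]++;
  });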
diff --git a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc
index 9af9ea4a22..c9096fb29b 100644
--- a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc
@@ -92,7 +92,7 @@ XLA_TEST_P(ReduceWithLayoutTest, DISABLED_ON_GPU(Reduce)) {
*reduce_input_shape->mutable_layout() =
LayoutUtil::MakeLayout(reduce_layout.input_minor_to_major);
- std::unique_ptr<Literal> reduce_input = LiteralUtil::CreateR4<float>(
+ Literal reduce_input = LiteralUtil::CreateR4<float>(
{{ /*i0=0*/
{/*i1=0*/
{-0.246092796, -0.179497838, -0.161181688},
diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc
index 0916a07f4f..26e2bfde5c 100644
--- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc
@@ -231,11 +231,10 @@ XLA_TEST_P(ReducePrecisionAccuracyTest, ReducePrecisionF32) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal =
- LiteralUtil::CreateR1<float>({input_values});
+ Literal a_literal = LiteralUtil::CreateR1<float>({input_values});
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
- auto a = Parameter(&builder, 0, a_literal->shape(), "a");
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
+ auto a = Parameter(&builder, 0, a_literal.shape(), "a");
ReducePrecision(a, exponent_bits, mantissa_bits);
@@ -255,10 +254,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest,
DISABLED_ON_INTERPRETER(ReducePrecisionBeforeFusion)) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001});
+ Literal a_literal = LiteralUtil::CreateR1<float>({1.00001});
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
- auto a = Parameter(&builder, 0, a_literal->shape(), "a");
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
+ auto a = Parameter(&builder, 0, a_literal.shape(), "a");
// Abs doesn't affect resolution.
auto abs = Abs(a);
@@ -284,10 +283,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest,
DISABLED_ON_INTERPRETER(ReducePrecisionSkippedAfterFusion)) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001});
+ Literal a_literal = LiteralUtil::CreateR1<float>({1.00001});
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
- auto a = Parameter(&builder, 0, a_literal->shape(), "a");
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
+ auto a = Parameter(&builder, 0, a_literal.shape(), "a");
// These two operations should be fused by any reasonable backend.
auto abs = Abs(a);
@@ -310,10 +309,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest,
DISABLED_ON_INTERPRETER(ReducePrecisionAddedAfterFusion)) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001});
+ Literal a_literal = LiteralUtil::CreateR1<float>({1.00001});
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
- auto a = Parameter(&builder, 0, a_literal->shape(), "a");
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
+ auto a = Parameter(&builder, 0, a_literal.shape(), "a");
// These two operations should be fused by any reasonable backend.
auto abs = Abs(a);
@@ -334,10 +333,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest,
DISABLED_ON_INTERPRETER(ReducePrecisionSkippedFusionContains)) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001});
+ Literal a_literal = LiteralUtil::CreateR1<float>({1.00001});
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
- auto a = Parameter(&builder, 0, a_literal->shape(), "a");
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
+ auto a = Parameter(&builder, 0, a_literal.shape(), "a");
// These two operations should be fused by any reasonable backend.
auto abs = Abs(a);
@@ -359,10 +358,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest,
DISABLED_ON_INTERPRETER(ReducePrecisionAddedFusionContains)) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001});
+ Literal a_literal = LiteralUtil::CreateR1<float>({1.00001});
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
- auto a = Parameter(&builder, 0, a_literal->shape(), "a");
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
+ auto a = Parameter(&builder, 0, a_literal.shape(), "a");
// These two operations should be fused by any reasonable backend.
auto abs = Abs(a);
diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc
index 57f7fed61f..83997cdac2 100644
--- a/tensorflow/compiler/xla/tests/reduce_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_test.cc
@@ -81,9 +81,9 @@ class ReduceTest : public ClientLibraryTestBase {
}, 4);
// clang-format on
CHECK(ShapeUtil::Equal(
- literal_3d_->shape(),
+ literal_3d_.shape(),
ShapeUtil::MakeShape(F32, {/*z=*/4, /*y=*/2, /*x=*/3})))
- << literal_3d_->shape().ShortDebugString();
+ << literal_3d_.shape().ShortDebugString();
}
// Runs an R1 => R0 reduction test with the given number of elements.
@@ -102,10 +102,9 @@ class ReduceTest : public ClientLibraryTestBase {
input_data[i] *= -1;
}
}
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR1(AsSlice(input_data));
+ Literal input_literal = LiteralUtil::CreateR1(AsSlice(input_data));
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
float expected = 0.0;
for (float item : input_data) {
@@ -134,9 +133,9 @@ class ReduceTest : public ClientLibraryTestBase {
Reduce(pred_values, init_value, reduce,
/*dimensions_to_reduce=*/{0});
- std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR1(input_data);
+ Literal input_literal = LiteralUtil::CreateR1(input_data);
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
bool expected = and_reduce;
for (bool item : input_data) {
@@ -175,12 +174,11 @@ class ReduceTest : public ClientLibraryTestBase {
Array2D<uint8> input_data(rows, cols);
input_data.FillRandom(0, 1);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR2FromArray2D(input_data);
+ Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
input_literal =
- input_literal->Relayout(LayoutUtil::MakeLayout({minor, major}));
+ input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::array<bool, cols> expected;
for (int64 colno = 0; colno < cols; ++colno) {
@@ -209,12 +207,11 @@ class ReduceTest : public ClientLibraryTestBase {
Array2D<float> input_data(rows, cols);
input_data.FillRandom(3.14f, 0.04);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR2FromArray2D(input_data);
+ Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
input_literal =
- input_literal->Relayout(LayoutUtil::MakeLayout({minor, major}));
+ input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
float expected = 0.0;
for (int64 rowno = 0; rowno < rows; ++rowno) {
@@ -237,12 +234,11 @@ class ReduceTest : public ClientLibraryTestBase {
Array2D<float> input_data(rows, cols);
input_data.FillRandom(3.14f, 0.04);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR2FromArray2D(input_data);
+ Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
input_literal =
- input_literal->Relayout(LayoutUtil::MakeLayout({minor, major}));
+ input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::vector<float> expected;
for (int64 colno = 0; colno < cols; ++colno) {
@@ -295,12 +291,11 @@ class ReduceTest : public ClientLibraryTestBase {
Array2D<NativeT> input_data(rows, cols);
input_data.FillUnique(initial_value);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR2FromArray2D(input_data);
+ Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
input_literal =
- input_literal->Relayout(LayoutUtil::MakeLayout({minor, major}));
+ input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
// NativeT can be bool, and std::vector<bool> does not convert to
// Span.
@@ -352,8 +347,8 @@ class ReduceTest : public ClientLibraryTestBase {
reference_reduction_function_for_uints, unsigned_int_identity);
}
- std::unique_ptr<Literal> literal_2d_;
- std::unique_ptr<Literal> literal_3d_;
+ Literal literal_2d_;
+ Literal literal_3d_;
uint32 seed_ = 0xdeadbeef;
};
@@ -450,11 +445,10 @@ XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) {
Array2D<float> input_data(rows, cols);
input_data.FillRandom(3.14f, 0.04);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR2FromArray2D(input_data);
- input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1}));
+ Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
+ input_literal = input_literal.Relayout(LayoutUtil::MakeLayout({0, 1}));
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::vector<float> expected;
for (int64 colno = 0; colno < cols; ++colno) {
@@ -482,11 +476,10 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) {
Array2D<float> input_data(rows, cols);
input_data.FillRandom(3.14f, 0.04);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR2FromArray2D(input_data);
- input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1}));
+ Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
+ input_literal = input_literal.Relayout(LayoutUtil::MakeLayout({0, 1}));
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::vector<float> expected;
for (int64 colno = 0; colno < cols; ++colno) {
@@ -511,10 +504,9 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceR3_12x111x50_To_R2) {
XlaOp transpose = Transpose(input, /*permutation=*/{1, 0, 2});
Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{0});
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> input_data,
- MakeFakeLiteral(input_shape));
+ TF_ASSERT_OK_AND_ASSIGN(Literal input_data, MakeFakeLiteral(input_shape));
- ComputeAndCompare(&builder, {std::move(*input_data)}, ErrorSpec(0.01, 1e-4));
+ ComputeAndCompare(&builder, {std::move(input_data)}, ErrorSpec(0.01, 1e-4));
}
XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) {
@@ -531,10 +523,9 @@ XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) {
Array3D<float> input_data(rows, 2, cols / 2);
input_data.FillRandom(3.14f, 0.04);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR3FromArray3D(input_data);
+ Literal input_literal = LiteralUtil::CreateR3FromArray3D(input_data);
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::vector<float> expected;
for (int64 major = 0; major < 2; ++major) {
@@ -595,7 +586,7 @@ XLA_TEST_F(ReduceTest, MaxReduce2DToR0) {
Array2D<float> input(300, 250);
input.FillRandom(214.0f);
auto input_literal = LiteralUtil::CreateR2FromArray2D(input);
- Reduce(ConstantLiteral(&builder, *input_literal),
+ Reduce(ConstantLiteral(&builder, input_literal),
ConstantR0<float>(&builder, FLT_MIN), max, {0, 1});
auto input_max = FLT_MIN;
input.Each(
@@ -610,7 +601,7 @@ XLA_TEST_F(ReduceTest, MinReduce2DToR0) {
Array2D<float> input(150, 130);
input.FillRandom(214.0f);
auto input_literal = LiteralUtil::CreateR2FromArray2D(input);
- Reduce(ConstantLiteral(&builder, *input_literal),
+ Reduce(ConstantLiteral(&builder, input_literal),
ConstantR0<float>(&builder, FLT_MAX), min, {0, 1});
auto input_min = FLT_MAX;
@@ -627,7 +618,7 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MinReduce) {
auto initial_value =
ConstantR0<uint32>(&builder, std::numeric_limits<uint32>::max());
- Reduce(ConstantLiteral(&builder, *input_literal), initial_value, min, {0, 1});
+ Reduce(ConstantLiteral(&builder, input_literal), initial_value, min, {0, 1});
ComputeAndCompareR0<uint32>(&builder, 1, {});
}
@@ -639,14 +630,14 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MaxReduce) {
auto initial_value =
ConstantR0<uint32>(&builder, std::numeric_limits<uint32>::min());
- Reduce(ConstantLiteral(&builder, *input_literal), initial_value, max, {0, 1});
+ Reduce(ConstantLiteral(&builder, input_literal), initial_value, max, {0, 1});
ComputeAndCompareR0<uint32>(&builder, 2, {});
}
// Reduces a matrix among dimension 1.
XLA_TEST_F(ReduceTest, Reduce2DAmong1) {
XlaBuilder builder(TestName());
- auto m = ConstantLiteral(&builder, *literal_2d_);
+ auto m = ConstantLiteral(&builder, literal_2d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {1});
@@ -657,7 +648,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmong1) {
XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) {
// Reduce a matrix among dimensions 0 and 1 (sum it up to a scalar).
XlaBuilder builder(TestName());
- auto m = ConstantLiteral(&builder, *literal_2d_);
+ auto m = ConstantLiteral(&builder, literal_2d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0, 1});
@@ -667,7 +658,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) {
// Tests 2D matrix ReduceToRow operation.
XLA_TEST_F(ReduceTest, Reduce2DAmongY) {
XlaBuilder builder("reduce_among_y");
- auto m = ConstantLiteral(&builder, *literal_2d_);
+ auto m = ConstantLiteral(&builder, literal_2d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0});
@@ -677,7 +668,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmongY) {
XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) {
XlaBuilder builder(TestName());
- auto m = ConstantLiteral(&builder, *literal_3d_);
+ auto m = ConstantLiteral(&builder, literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {1, 2});
@@ -687,7 +678,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) {
XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) {
XlaBuilder builder(TestName());
- auto m = ConstantLiteral(&builder, *literal_3d_);
+ auto m = ConstantLiteral(&builder, literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0, 1});
@@ -697,7 +688,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) {
XLA_TEST_F(ReduceTest, ReduceR3ToR0) {
XlaBuilder builder(TestName());
- auto m = ConstantLiteral(&builder, *literal_3d_);
+ auto m = ConstantLiteral(&builder, literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0, 1, 2});
@@ -707,7 +698,7 @@ XLA_TEST_F(ReduceTest, ReduceR3ToR0) {
XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) {
XlaBuilder builder(TestName());
- auto m = ConstantLiteral(&builder, *literal_3d_);
+ auto m = ConstantLiteral(&builder, literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0});
@@ -722,7 +713,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) {
XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) {
XlaBuilder builder(TestName());
- auto m = ConstantLiteral(&builder, *literal_3d_);
+ auto m = ConstantLiteral(&builder, literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {1});
@@ -739,7 +730,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) {
XLA_TEST_F(ReduceTest, ReduceR3AmongDim2) {
XlaBuilder builder(TestName());
- auto m = ConstantLiteral(&builder, *literal_3d_);
+ auto m = ConstantLiteral(&builder, literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {2});
@@ -824,12 +815,12 @@ XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) {
auto input_literal = LiteralUtil::CreateR3FromArray3D(input_array);
input_literal =
- input_literal->Relayout(LayoutUtil::MakeLayout(GetParam().layout));
+ input_literal.Relayout(LayoutUtil::MakeLayout(GetParam().layout));
std::unique_ptr<GlobalData> input_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
auto input_activations =
- Parameter(&builder, 0, input_literal->shape(), "input");
+ Parameter(&builder, 0, input_literal.shape(), "input");
XlaComputation add = CreateScalarAddComputation(F32, &builder);
Reduce(input_activations, ConstantR0<float>(&builder, 0.0f), add,
GetParam().reduce_dims);
@@ -873,11 +864,10 @@ XLA_TEST_F(ReduceTest, OperationOnConstantAsInitValue) {
auto a = ConstantR0<float>(&builder, 2.0f);
auto a2 = Abs(a);
- std::unique_ptr<Literal> b_literal =
- LiteralUtil::CreateR1<float>({1.0f, 4.0f});
+ Literal b_literal = LiteralUtil::CreateR1<float>({1.0f, 4.0f});
std::unique_ptr<GlobalData> b_data =
- client_->TransferToServer(*b_literal).ConsumeValueOrDie();
- auto b = Parameter(&builder, 0, b_literal->shape(), "b");
+ client_->TransferToServer(b_literal).ConsumeValueOrDie();
+ auto b = Parameter(&builder, 0, b_literal.shape(), "b");
Reduce(b, a2, max_f32, {0});
ComputeAndCompareR0<float>(&builder, 4.0f, {b_data.get()});
@@ -904,9 +894,9 @@ class ReduceInitializerTest : public ReduceTest {
std::vector<T> input_arr(num_elems, std::numeric_limits<T>::lowest());
auto input_literal = LiteralUtil::CreateR1<T>(input_arr);
auto input_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
- Reduce(Parameter(&builder, 0, input_literal->shape(), "input"), init,
- max_fn, {0});
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
+ Reduce(Parameter(&builder, 0, input_literal.shape(), "input"), init, max_fn,
+ {0});
ComputeAndCompareR0<T>(&builder, initializer, {input_data.get()});
}
@@ -952,13 +942,12 @@ XLA_TEST_F(ReduceTest, ReduceIdentity) {
float operand[] = {42.0f};
float init = 58.5f;
float expected = 42.0f;
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR1<float>(operand);
+ Literal input_literal = LiteralUtil::CreateR1<float>(operand);
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> input_literal2 = LiteralUtil::CreateR0<float>(init);
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
+ Literal input_literal2 = LiteralUtil::CreateR0<float>(init);
std::unique_ptr<GlobalData> input_global_data2 =
- client_->TransferToServer(*input_literal2).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal2).ConsumeValueOrDie();
ComputeAndCompareR0<float>(
&builder, expected, {input_global_data.get(), input_global_data2.get()},
ErrorSpec(0.0001));
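Within ReduceTest the Relayout calls follow the same value-semantics change: Relayout returns a fresh Literal that is assigned back over the original variable rather than replacing a unique_ptr. A short sketch under that assumption, with illustrative sizes and layout in place of the tests' parameters:

const int64 rows = 2, cols = 3;  // illustrative sizes, not from the diff
Array2D<float> data(rows, cols);
data.FillRandom(3.14f, 0.04);
Literal lit = LiteralUtil::CreateR2FromArray2D(data);
// Relayout produces a new Literal with the requested minor-to-major order;
// with value semantics the result is simply reassigned.
lit = lit.Relayout(LayoutUtil::MakeLayout({0, 1}));
std::unique_ptr<GlobalData> gd =
    client_->TransferToServer(lit).ConsumeValueOrDie();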
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index a1001296a1..63491a90bf 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -73,7 +73,7 @@ class ReduceWindowTest : public ::testing::WithParamInterface<bool>,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
Padding padding) {
- auto init = CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(0.0f),
+ auto init = CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0.0f),
&builder_);
ReduceWindow(input, init,
CreateScalarAddComputation(FloatType(), &builder_),
@@ -107,9 +107,9 @@ class ReduceWindowTest : public ::testing::WithParamInterface<bool>,
TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) {
const auto input = CreateConstantFromLiteral(
- *LiteralUtil::CreateR1<float>({1, 1, 1, 1}), &builder_);
+ LiteralUtil::CreateR1<float>({1, 1, 1, 1}), &builder_);
const auto init_value =
- CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(0), &builder_);
+ CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0), &builder_);
TF_ASSERT_OK(builder_.first_error());
ReduceWindow(input, init_value,
CreateScalarAddComputation(FloatType(), &builder_),
@@ -124,31 +124,31 @@ TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) {
// Regression test for b/68964348.
TEST_P(ReduceWindowTest, R0ReduceWindow) {
const auto input =
- CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(42.0), &builder_);
+ CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(42.0), &builder_);
const auto init =
- CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(1.0), &builder_);
+ CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(1.0), &builder_);
ReduceWindow(input, init, CreateScalarAddComputation(FloatType(), &builder_),
/*window_dimensions=*/{},
/*window_strides=*/{}, Padding::kSame);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR0<float>(43.0), {},
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR0<float>(43.0), {},
ErrorSpec(0.00001));
}
TEST_P(ReduceWindowTest, Min3In5Stride2) {
const auto input = CreateConstantFromLiteral(
- *LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
+ LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
ReduceWindowMin(input, {3}, {2}, Padding::kValid);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1<float>({100, 1}),
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1<float>({100, 1}),
{}, ErrorSpec(0.00001));
}
TEST_P(ReduceWindowTest, Min3In5Stride1WithSamePadding) {
const auto input = CreateConstantFromLiteral(
- *LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
+ LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
ReduceWindowMin(input, /*window_dimensions=*/{3}, /*window_strides=*/{1},
Padding::kSame);
ComputeAndCompareLiteral(&builder_,
- *LiteralUtil::CreateR1<float>({1000, 100, 10, 1, 1}),
+ LiteralUtil::CreateR1<float>({1000, 100, 10, 1, 1}),
{}, ErrorSpec(0.00001));
}
@@ -161,7 +161,7 @@ XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) {
auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1},
{1, 1, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {},
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {},
DefaultErrorSpec());
}
@@ -176,7 +176,7 @@ TEST_P(ReduceWindowTest, NonSquareSmall) {
auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1},
{1, 1, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {},
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {},
DefaultErrorSpec());
}
@@ -190,7 +190,7 @@ TEST_P(ReduceWindowTest, MiddleDimsSmall) {
auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 1, 1},
{1, 2, 2, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {},
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {},
DefaultErrorSpec());
}
@@ -207,7 +207,7 @@ TEST_P(ReduceWindowTest, Along2ndMinorDim) {
auto res = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {},
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {},
DefaultErrorSpec());
}
@@ -229,8 +229,8 @@ TEST_P(ReduceWindowTest, AmongMajor2Dims) {
input_array, 0.0f, {win_len, win_len, 1, 1},
{win_stride, win_stride, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result),
- {}, DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
+ DefaultErrorSpec());
}
TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) {
@@ -252,8 +252,8 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) {
input_array, 0.0f, {win_len, win_len, 1, 1},
{win_stride, win_stride, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result),
- {}, DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
+ DefaultErrorSpec());
}
// Tests the super windowing logic w.r.t handling prime number of windows in a
@@ -277,8 +277,8 @@ TEST_P(ReduceWindowTest, PrimeWindowsInReductionDimension) {
input_array, 0.0f, {win_len, win_len, 1, 1},
{win_stride, win_stride, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result),
- {}, DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
+ DefaultErrorSpec());
}
TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) {
@@ -294,8 +294,8 @@ TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) {
auto result = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {1, 1, 1, 11}, {1, 1, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result),
- {}, DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
+ DefaultErrorSpec());
}
// Tests a reduction function that is not a simple add/min/max/etc.
@@ -313,12 +313,12 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) {
auto lhs = Parameter(b.get(), 0, scalar, "lhs");
auto rhs = Parameter(b.get(), 1, scalar, "rhs");
Min(Add(lhs, rhs),
- CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(8.0f), b.get()));
+ CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(8.0f), b.get()));
XlaComputation reduce_fn = b->BuildAndNoteError();
ReduceWindow(
input,
- CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(0.0f), &builder_),
+ CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0.0f), &builder_),
reduce_fn,
/*window_dimensions=*/{1, 1, 2, 1},
/*window_strides=*/{1, 1, 1, 1}, padding);
@@ -332,19 +332,18 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) {
/*window=*/{1, 1, 2, 1},
/*stride=*/{1, 1, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*expected),
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
{}, DefaultErrorSpec());
}
TEST_P(ReduceWindowTest, R4UnitWindow) {
Array4D<float> input_array(13, 12, 8, 15);
input_array.FillRandom(2.f, 2.f);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input_array, LayoutUtil::MakeLayout({0, 3, 2, 1}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input_array, LayoutUtil::MakeLayout({0, 3, 2, 1}));
XlaOp input;
auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "parameter", &builder_, &input);
+ 0, input_literal, "parameter", &builder_, &input);
Padding padding = Padding::kSame;
ReduceWindowAdd(input, {1, 1, 7, 1}, {1, 4, 1, 1}, padding);
@@ -352,7 +351,7 @@ TEST_P(ReduceWindowTest, R4UnitWindow) {
auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 7, 1},
{1, 4, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res),
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res),
{input_data.get()}, DefaultErrorSpec());
}
@@ -360,9 +359,9 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) {
std::vector<int64> input_dims(6, 8);
auto shape = ShapeUtil::MakeShape(F32, input_dims);
- auto arg_literal = absl::make_unique<Literal>(shape);
- arg_literal->PopulateWithValue(1.0f);
- const auto input = CreateConstantFromLiteral(*arg_literal, &builder_);
+ Literal arg_literal(shape);
+ arg_literal.PopulateWithValue(1.0f);
+ const auto input = CreateConstantFromLiteral(arg_literal, &builder_);
Padding padding = Padding::kValid;
ReduceWindowAdd(input, {3, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding);
@@ -371,39 +370,38 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) {
std::vector<int64> output_dims = {6, 8, 6, 6, 8, 8};
Shape result_shape =
ShapeUtil::MakeShapeWithLayout(F32, output_dims, output_layout);
- auto expected = absl::make_unique<Literal>(result_shape);
- expected->PopulateWithValue(27.0f);
- ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec());
+ Literal expected(result_shape);
+ expected.PopulateWithValue(27.0f);
+ ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec());
}
XLA_TEST_P(ReduceWindowTest, R6Add) {
std::vector<int64> input_dims(6, 8);
auto shape = ShapeUtil::MakeShape(F32, input_dims);
- std::unique_ptr<Literal> arg_literal =
+ Literal arg_literal =
LiteralUtil::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
- const auto input = CreateConstantFromLiteral(*arg_literal, &builder_);
+ const auto input = CreateConstantFromLiteral(arg_literal, &builder_);
Padding padding = Padding::kValid;
ReduceWindowAdd(input, {1, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding);
std::vector<int64> output_dims = {8, 8, 6, 6, 8, 8};
- std::unique_ptr<Literal> expected =
+ Literal expected =
LiteralUtil::CreateFullWithDescendingLayout<float>(output_dims, 9.0f);
- ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec());
}
XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) {
Array4D<float> input_array(2, 1, 27, 119);
input_array.FillRandom(2.0f);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp input;
auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "parameter", &builder_, &input);
+ 0, input_literal, "parameter", &builder_, &input);
int win_len = 1;
int stride = 8;
@@ -413,19 +411,18 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) {
auto res = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res),
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res),
{input_data.get()}, DefaultErrorSpec());
}
XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) {
Array4D<float> input_array(3, 2, 4, 64);
input_array.FillRandom(2.0f);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp input;
auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "parameter", &builder_, &input);
+ 0, input_literal, "parameter", &builder_, &input);
int win_len = 3;
int stride = 1;
@@ -435,19 +432,18 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) {
auto res = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res),
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res),
{input_data.get()}, DefaultErrorSpec());
}
XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) {
Array4D<float> input_array(1, 3, 12, 200);
input_array.FillRandom(2.0f);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp input;
auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "parameter", &builder_, &input);
+ 0, input_literal, "parameter", &builder_, &input);
int win_len = 8;
int stride = 5;
@@ -457,7 +453,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) {
auto res = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res),
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res),
{input_data.get()}, DefaultErrorSpec());
}
@@ -478,18 +474,18 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMultipleMinor) {
auto result = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {win_len, win_len, 1, 1},
{win_stride, win_stride, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result),
- {}, DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
+ DefaultErrorSpec());
}
XLA_TEST_P(ReduceWindowTest, Add24In1152_NoOverlap) {
std::vector<float> input_vector(128 * 9, 1);
const auto input = CreateConstantFromLiteral(
- *LiteralUtil::CreateR1<float>(input_vector), &builder_);
+ LiteralUtil::CreateR1<float>(input_vector), &builder_);
ReduceWindowAdd(input, {32}, {128}, Padding::kValid);
ComputeAndCompareLiteral(
&builder_,
- *LiteralUtil::CreateR1<float>({32, 32, 32, 32, 32, 32, 32, 32, 32}), {},
+ LiteralUtil::CreateR1<float>({32, 32, 32, 32, 32, 32, 32, 32, 32}), {},
DefaultErrorSpec());
}
@@ -504,9 +500,9 @@ XLA_TEST_P(ReduceWindowTest, Add128In128Stride128) {
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
const auto input = CreateConstantFromLiteral(
- *LiteralUtil::CreateR1<float>(input_vector), &builder_);
+ LiteralUtil::CreateR1<float>(input_vector), &builder_);
ReduceWindowAdd(input, {128}, {128}, Padding::kValid);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1<float>({1088}), {},
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1<float>({1088}), {},
DefaultErrorSpec());
}
@@ -521,9 +517,9 @@ XLA_TEST_P(ReduceWindowTest, Add128In128) {
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
const auto input = CreateConstantFromLiteral(
- *LiteralUtil::CreateR1<float>(input_vector), &builder_);
+ LiteralUtil::CreateR1<float>(input_vector), &builder_);
ReduceWindowAdd(input, {128}, {1}, Padding::kValid);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1<float>({1088}), {},
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1<float>({1088}), {},
DefaultErrorSpec());
}
@@ -540,9 +536,8 @@ TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) {
auto res = ReferenceUtil::ReduceWindow2DAdd(
input_array, 0.0f, {win_len, win_len}, {stride, stride}, padding);
- ComputeAndCompareLiteral(&builder_,
- *LiteralUtil::CreateFromArray<float>(*res), {},
- DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray<float>(*res),
+ {}, DefaultErrorSpec());
}
TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) {
@@ -556,9 +551,8 @@ TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) {
auto res = ReferenceUtil::ReduceWindow2DAdd(input_array, 0.0f, {4, 2}, {3, 3},
padding);
- ComputeAndCompareLiteral(&builder_,
- *LiteralUtil::CreateFromArray<float>(*res), {},
- DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray<float>(*res),
+ {}, DefaultErrorSpec());
}
INSTANTIATE_TEST_CASE_P(ReduceWindowTestInstance, ReduceWindowTest,
@@ -594,7 +588,7 @@ string R4ReduceWindowTestDataToString(
// Test names are not allowed to contain the '-' character.
std::replace(str.begin(), str.end(), '-', 'n');
if (::testing::get<1>(data.param)) {
- str = absl::StrCat(str, "_bfloat16");
+ absl::StrAppend(&str, "_bfloat16");
}
return str;
}
@@ -614,11 +608,10 @@ class R4ReduceWindowTest : public ReduceWindowTestBase,
Array4D<float> input(param.base_bounds[0], param.base_bounds[1],
param.base_bounds[2], param.base_bounds[3]);
input.FillRandom(0.1f, 0.1f);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout(param.layout));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout(param.layout));
XlaOp parameter;
- auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
+ auto input_arg = CreateParameterAndTransferLiteral(0, input_literal, "p0",
&b, &parameter);
std::vector<std::pair<int64, int64>> padding(4);
@@ -627,7 +620,7 @@ class R4ReduceWindowTest : public ReduceWindowTestBase,
}
auto init_value =
- CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b);
+ CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
CHECK(param.reducer == kAdd || param.reducer == kMax);
auto reducer = param.reducer;
if (use_bfloat16() && Product(param.window_bounds) > 128) {
@@ -659,12 +652,11 @@ class R4ReduceWindowTest : public ReduceWindowTestBase,
/*window=*/param.window_bounds,
/*stride=*/param.strides,
/*padding=*/padding);
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateFromArray(*expected);
+ Literal expected_literal = LiteralUtil::CreateFromArray(*expected);
const Shape& expected_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
- input_literal->shape().element_type(),
- AsInt64Slice(expected_literal->shape().dimensions()), param.layout);
- ComputeAndCompareLiteral(&b, *expected_literal, {input_arg.get()},
+ input_literal.shape().element_type(),
+ AsInt64Slice(expected_literal.shape().dimensions()), param.layout);
+ ComputeAndCompareLiteral(&b, expected_literal, {input_arg.get()},
DefaultErrorSpec(), &expected_shape_with_layout);
}
};
@@ -988,7 +980,7 @@ string R3ReduceWindowTestDataToString(
param.layout[0], "_", param.layout[1], "_", param.layout[2], "__reducer_",
param.reducer == kAdd ? "add" : "max");
if (::testing::get<1>(data.param)) {
- str = absl::StrCat(str, "_bfloat16");
+ absl::StrAppend(&str, "_bfloat16");
}
return str;
}
@@ -1008,12 +1000,11 @@ TEST_P(R3ReduceWindowTest, DoIt) {
Array3D<float> input(param.base_bounds[0], param.base_bounds[1],
param.base_bounds[2]);
input.FillRandom(0.1f, 0.1f);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR3FromArray3DWithLayout(
- input, LayoutUtil::MakeLayout(param.layout));
+ Literal input_literal = LiteralUtil::CreateR3FromArray3DWithLayout(
+ input, LayoutUtil::MakeLayout(param.layout));
auto reducer = param.reducer;
if (use_bfloat16()) {
- input_literal = LiteralUtil::ConvertF32ToBF16(*input_literal);
+ input_literal = LiteralUtil::ConvertF32ToBF16(input_literal);
if (Product(param.window_bounds) > 128) {
// To avoid numerical issues, force the reducer to be kMax for large bf16
// windows.
@@ -1021,9 +1012,9 @@ TEST_P(R3ReduceWindowTest, DoIt) {
}
}
- XlaOp parameter = Parameter(&b, 0, input_literal->shape(), "input");
+ XlaOp parameter = Parameter(&b, 0, input_literal.shape(), "input");
auto init_value =
- CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b);
+ CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
auto computation = reducer == kAdd
? CreateScalarAddComputation(FloatType(), &b)
@@ -1035,7 +1026,7 @@ TEST_P(R3ReduceWindowTest, DoIt) {
/*window_dimensions=*/param.window_bounds,
/*window_strides=*/param.strides, /*padding=*/param.padding);
- ComputeAndCompare(&b, {std::move(*input_literal)}, DefaultErrorSpec());
+ ComputeAndCompare(&b, {std::move(input_literal)}, DefaultErrorSpec());
}
INSTANTIATE_TEST_CASE_P(
@@ -1130,7 +1121,7 @@ string R2ReduceWindowTestDataToString(
param.layout[1], //
"__reducer_", param.reducer == kAdd ? "add" : "max");
if (::testing::get<1>(data.param)) {
- str = absl::StrCat(str, "_bfloat16");
+ absl::StrAppend(&str, "_bfloat16");
}
return str;
}
@@ -1147,12 +1138,11 @@ class R2ReduceWindowTest : public ReduceWindowTestBase,
const float kInitValue = 0.0f;
Array2D<float> input(param.base_bounds[0], param.base_bounds[1], 1.0f);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR2FromArray2DWithLayout(
- input, LayoutUtil::MakeLayout(param.layout));
+ Literal input_literal = LiteralUtil::CreateR2FromArray2DWithLayout(
+ input, LayoutUtil::MakeLayout(param.layout));
XlaOp parameter;
- auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
+ auto input_arg = CreateParameterAndTransferLiteral(0, input_literal, "p0",
&b, &parameter);
std::vector<std::pair<int64, int64>> padding(2);
for (int i = 0; i < 2; ++i) {
@@ -1162,7 +1152,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase,
? CreateScalarAddComputation(FloatType(), &b)
: CreateScalarMaxComputation(FloatType(), &b);
auto init_value =
- CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b);
+ CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
ReduceWindowWithGeneralPadding(
/*operand=*/parameter,
/*init_value=*/init_value,
@@ -1178,7 +1168,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase,
/*window=*/param.window_bounds,
/*stride=*/param.strides, /*padding=*/padding);
- ComputeAndCompareLiteral(&b, *LiteralUtil::CreateFromArray(*expected),
+ ComputeAndCompareLiteral(&b, LiteralUtil::CreateFromArray(*expected),
{input_arg.get()}, DefaultErrorSpec());
}
};
@@ -1332,7 +1322,7 @@ string R1ReduceWindowTestDataToString(
"__pad_high_", absl::StrJoin(param.pad_high, "x"),
"__reducer_", param.reducer == kAdd ? "add" : "max");
if (::testing::get<1>(data.param)) {
- str = absl::StrCat(str, "_bfloat16");
+ absl::StrAppend(&str, "_bfloat16");
}
return str;
}
@@ -1352,11 +1342,11 @@ TEST_P(R1ReduceWindowTest, DoIt) {
const float kInitValue = 0.0f;
std::vector<float> input_vector(param.base_bounds[0]);
std::iota(std::begin(input_vector), std::end(input_vector), 0);
- std::unique_ptr<Literal> input_literal =
+ Literal input_literal =
LiteralUtil::CreateR1(absl::Span<const float>(input_vector));
XlaOp parameter;
- auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
- &b, &parameter);
+ auto input_arg =
+ CreateParameterAndTransferLiteral(0, input_literal, "p0", &b, &parameter);
std::vector<std::pair<int64, int64>> padding(1);
padding[0] = {param.pad_low[0], param.pad_high[0]};
@@ -1365,7 +1355,7 @@ TEST_P(R1ReduceWindowTest, DoIt) {
? CreateScalarAddComputation(FloatType(), &b)
: CreateScalarMaxComputation(FloatType(), &b);
auto init_value =
- CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b);
+ CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
ReduceWindowWithGeneralPadding(
/*operand=*/parameter,
/*init_value=*/init_value,
@@ -1384,7 +1374,7 @@ TEST_P(R1ReduceWindowTest, DoIt) {
/*stride=*/param.strides,
/*padding=*/padding);
- ComputeAndCompareLiteral(&b, *LiteralUtil::CreateR1<float>(*expected),
+ ComputeAndCompareLiteral(&b, LiteralUtil::CreateR1<float>(*expected),
{input_arg.get()}, DefaultErrorSpec());
}
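A smaller cleanup appears in the parameterized test-name formatters above: the self-assignment form of absl::StrCat is replaced with absl::StrAppend, which appends in place. Sketch, with an illustrative name fragment:

std::string str = "R4ReduceWindow__layout_3_2_1_0__reducer_add";  // illustrative
// Old: str = absl::StrCat(str, "_bfloat16");  // builds a new string, then copies it back
// New: append in place, avoiding the intermediate copy.
absl::StrAppend(&str, "_bfloat16");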
diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc
index d891451381..5cf87e565b 100644
--- a/tensorflow/compiler/xla/tests/replay_test.cc
+++ b/tensorflow/compiler/xla/tests/replay_test.cc
@@ -58,13 +58,13 @@ TEST_F(ReplayTest, TwoPlusTwoReplay) {
ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape));
// Run it.
- std::unique_ptr<Literal> literal =
+ Literal literal =
client_
->ExecuteAndTransfer(replayed, /*arguments=*/{}, &execution_options_)
.ConsumeValueOrDie();
// Expect 4.
- LiteralTestUtil::ExpectR0Equal<int32>(4, *literal);
+ LiteralTestUtil::ExpectR0Equal<int32>(4, literal);
}
XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) {
@@ -91,12 +91,12 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) {
// Run it.
std::unique_ptr<GlobalData> x_data =
- client_->TransferToServer(*LiteralUtil::CreateR0<int32>(2))
+ client_->TransferToServer(LiteralUtil::CreateR0<int32>(2))
.ConsumeValueOrDie();
std::unique_ptr<GlobalData> y_data =
- client_->TransferToServer(*LiteralUtil::CreateR0<int32>(3))
+ client_->TransferToServer(LiteralUtil::CreateR0<int32>(3))
.ConsumeValueOrDie();
- std::unique_ptr<Literal> literal =
+ Literal literal =
client_
->ExecuteAndTransfer(replayed,
/*arguments=*/{x_data.get(), y_data.get()},
@@ -104,7 +104,7 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) {
.ConsumeValueOrDie();
// Expect 5.
- LiteralTestUtil::ExpectR0Equal<int32>(5, *literal);
+ LiteralTestUtil::ExpectR0Equal<int32>(5, literal);
}
TEST_F(ReplayTest, MapPlusTwoOverR1) {
@@ -136,13 +136,13 @@ TEST_F(ReplayTest, MapPlusTwoOverR1) {
ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape));
// Run it.
- std::unique_ptr<Literal> literal =
+ Literal literal =
client_
->ExecuteAndTransfer(replayed, /*arguments=*/{}, &execution_options_)
.ConsumeValueOrDie();
// Expect result.
- LiteralTestUtil::ExpectR1Equal<int32>({3, 4, 5}, *literal);
+ LiteralTestUtil::ExpectR1Equal<int32>({3, 4, 5}, literal);
}
} // namespace
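The replay tests above show the same migration on the client result path: ExecuteAndTransfer now yields a Literal value, and the LiteralTestUtil expectation helpers take it directly (presumably by const reference) instead of through a dereferenced unique_ptr. A sketch, assuming a computation and execution_options_ set up as in those tests:

Literal literal =
    client_
        ->ExecuteAndTransfer(computation, /*arguments=*/{}, &execution_options_)
        .ConsumeValueOrDie();
// The expectation helper receives the Literal itself, no '*' needed.
LiteralTestUtil::ExpectR0Equal<int32>(4, literal);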
diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc
index 17d12715f6..dedc95b5ae 100644
--- a/tensorflow/compiler/xla/tests/reshape_test.cc
+++ b/tensorflow/compiler/xla/tests/reshape_test.cc
@@ -57,12 +57,12 @@ XLA_TEST_P(ReshapeTest, CollapseTrivial1x1) {
input_array.Fill(1.0f);
auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
auto expected_literal = LiteralUtil::CreateR1<float>({1.0f});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -70,12 +70,12 @@ XLA_TEST_P(ReshapeTest, CollapseTrivialR1EmptyDims) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({1.0f});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{});
auto expected_literal = LiteralUtil::CreateR1<float>({1.0f});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -83,12 +83,12 @@ XLA_TEST_P(ReshapeTest, CollapseTrivialR1OnlyDim) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({1.0f});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{0});
auto expected_literal = LiteralUtil::CreateR1<float>({1.0f});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -99,29 +99,29 @@ XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) {
input_array.Fill(1.0f);
auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter",
&builder, &parameter);
auto reshape = Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
/*new_sizes=*/{});
auto new_shape = builder.GetShape(reshape).ConsumeValueOrDie();
auto expected_literal = LiteralUtil::CreateR0<float>(1.0f);
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(1.0f);
+ Literal param0_literal = LiteralUtil::CreateR0<float>(1.0f);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0",
+ auto input = CreateParameterAndTransferLiteral(0, param0_literal, "param0",
&builder, &parameter);
auto a = Neg(parameter);
Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1});
auto expected_literal = LiteralUtil::CreateR1<float>({-1.0f});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -130,25 +130,25 @@ XLA_TEST_P(ReshapeTest, Trivial0x3) {
Array2D<float> input_array(0, 3);
auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
auto expected_literal = LiteralUtil::CreateR1<float>({});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
XLA_TEST_P(ReshapeTest, Trivial0x3WithParameter) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR2FromArray2D<float>(Array2D<float>(0, 3));
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0",
+ auto input = CreateParameterAndTransferLiteral(0, param0_literal, "param0",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
auto expected_literal = LiteralUtil::CreateR1<float>({});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -157,11 +157,11 @@ XLA_TEST_P(ReshapeTest, Trivial3x0) {
Array2D<float> input_array(3, 0);
auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
auto expected_literal = LiteralUtil::CreateR1<float>({});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -170,11 +170,11 @@ XLA_TEST_P(ReshapeTest, Trivial1x3) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateR2<float>({{1.0f, 2.0f, 3.0f}});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
auto expected_literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -183,11 +183,11 @@ XLA_TEST_P(ReshapeTest, Trivial3x1) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateR2<float>({{1.0f}, {2.0f}, {3.0f}});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
auto expected_literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -196,12 +196,12 @@ XLA_TEST_P(ReshapeTest, R1ToR2_0_To_2x0) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0},
/*new_sizes=*/{2, 0});
auto expected_literal = LiteralUtil::CreateR2<float>({{}, {}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -211,13 +211,13 @@ XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) {
auto input_literal =
LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0},
/*new_sizes=*/{2, 3});
auto expected_literal =
LiteralUtil::CreateR2<float>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -226,12 +226,12 @@ XLA_TEST_P(ReshapeTest, Reshape0x2To2x0) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(Array2D<float>(0, 2));
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
/*new_sizes=*/{2, 0});
auto expected_literal = LiteralUtil::CreateR2<float>({{}, {}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -241,14 +241,14 @@ XLA_TEST_P(ReshapeTest, ReshapeRowToCol) {
auto simple = MakeLinspaceArray2D(1.0f, 3.0f, 1, 3);
auto input_literal = LiteralUtil::CreateFromArray(*simple);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
/*new_sizes=*/{3, 1});
auto expected = ReferenceUtil::TransposeArray2D(*simple);
auto expected_literal = LiteralUtil::CreateFromArray(*expected);
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -258,14 +258,14 @@ XLA_TEST_P(ReshapeTest, TransposeAsReshape) {
auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
auto input_literal = LiteralUtil::CreateFromArray(*a4x3);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0},
/*new_sizes=*/{3, 4});
auto expected = ReferenceUtil::TransposeArray2D(*a4x3);
auto expected_literal = LiteralUtil::CreateFromArray(*expected);
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -274,11 +274,11 @@ XLA_TEST_P(ReshapeTest, Transpose0x4) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(Array2D<float>(0, 4));
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Transpose(parameter, {1, 0});
auto expected_literal = LiteralUtil::CreateR2<float>({{}, {}, {}, {}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -288,13 +288,13 @@ XLA_TEST_P(ReshapeTest, Transpose4x3) {
auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
auto input_literal = LiteralUtil::CreateFromArray(*a4x3);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Transpose(parameter, {1, 0});
auto expected = ReferenceUtil::TransposeArray2D(*a4x3);
auto expected_literal = LiteralUtil::CreateFromArray(*expected);
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -304,13 +304,13 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffleZeroElements) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(Array2D<float>(6, 0));
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
/*new_sizes=*/{2, 3, 0, 0});
auto expected_literal =
LiteralUtil::CreateFromArray(Array4D<float>(2, 3, 0, 0));
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -318,12 +318,12 @@ XLA_TEST_P(ReshapeTest, ReshapeR4ToR2ZeroElements) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(Array4D<float>(2, 3, 4, 0));
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3},
/*new_sizes=*/{24, 0});
auto expected_literal = LiteralUtil::CreateFromArray(Array2D<float>(24, 0));
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -334,14 +334,14 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) {
auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
auto input_literal = LiteralUtil::CreateFromArray(*a4x3);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
/*new_sizes=*/{2, 6});
auto expected = MakeLinspaceArray2D(1.0f, 12.0f, 2, 6);
auto expected_literal = LiteralUtil::CreateFromArray(*expected);
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -349,12 +349,12 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffleZeroElements) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(Array2D<float>(0, 6));
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0},
/*new_sizes=*/{3, 0});
auto expected_literal = LiteralUtil::CreateFromArray(Array2D<float>(3, 0));
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -365,14 +365,14 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffle) {
auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
auto input_literal = LiteralUtil::CreateFromArray(*a4x3);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0},
/*new_sizes=*/{2, 6});
Array2D<float> expected({{1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f},
{8.0f, 11.0f, 3.0f, 6.0f, 9.0f, 12.0f}});
auto expected_literal = LiteralUtil::CreateFromArray(expected);
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -391,14 +391,14 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests());
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2},
/*new_sizes=*/{24});
auto expected_literal = LiteralUtil::CreateR1<float>(
{10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27,
30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -406,7 +406,7 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests());
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2},
/*new_sizes=*/{8, 3});
@@ -418,7 +418,7 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) {
{35, 36, 37},
{40, 41, 42},
{45, 46, 47}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -426,14 +426,14 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests());
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0},
/*new_sizes=*/{24});
auto expected_literal = LiteralUtil::CreateR1<float>(
{10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42,
15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -441,7 +441,7 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests());
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0},
/*new_sizes=*/{8, 3});
@@ -453,7 +453,7 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) {
{45, 16, 26},
{36, 46, 17},
{27, 37, 47}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -461,14 +461,14 @@ XLA_TEST_P(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests());
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0},
/*new_sizes=*/{2, 6, 2});
auto expected_literal = LiteralUtil::CreateR3<float>(
{{{10, 20}, {30, 40}, {11, 21}, {31, 41}, {12, 22}, {32, 42}},
{{15, 25}, {35, 45}, {16, 26}, {36, 46}, {17, 27}, {37, 47}}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -494,14 +494,14 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapse) {
t2x2x2x3.FillWithYX(*filler2x3);
auto input_literal = LiteralUtil::CreateFromArray(t2x2x2x3);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{1, 2, 3});
auto expected_literal = LiteralUtil::CreateR2<float>(
{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f},
{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
6.0f}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -519,14 +519,14 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) {
t(1, 0, 1, 1) = 7;
auto input_literal = LiteralUtil::CreateFromArray(t);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3},
/*new_sizes=*/{2, 4});
auto expected_literal =
LiteralUtil::CreateR2<float>({{0, 1, 2, 3}, {4, 5, 6, 7}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -547,7 +547,7 @@ XLA_TEST_P(ReshapeTest, ToScalar) {
Reshape(parameter, dimensions, {});
auto expected_literal = LiteralUtil::CreateR0<float>(83.0f);
- ComputeAndCompareLiteral(&b, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&b, expected_literal, {input.get()},
zero_error_spec_);
}
}
@@ -556,7 +556,7 @@ XLA_TEST_P(ReshapeTest, BadDimensions) {
XlaBuilder b(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({1.0f});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b,
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &b,
&parameter);
Reshape(parameter, {}, {});
EXPECT_THAT(
@@ -568,7 +568,7 @@ XLA_TEST_P(ReshapeTest, BadNewSizes) {
XlaBuilder b(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b,
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &b,
&parameter);
Reshape(parameter, {1}, {});
EXPECT_THAT(ExecuteToString(&b, {}),
@@ -604,7 +604,7 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) {
LayoutUtil::MakeLayout({0, 1, 2, 3}));
// clang-format on
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8});
@@ -619,27 +619,26 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) {
*execution_options.mutable_shape_with_output_layout() =
ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {2, 8},
{1, 0});
- std::unique_ptr<Literal> actual =
+ Literal actual =
client_
->ExecuteAndTransfer(computation, {input.get()}, &execution_options)
.ConsumeValueOrDie();
- std::unique_ptr<Literal> expected =
- LiteralUtil::CreateR2FromArray2D<float>(expected_array);
+ Literal expected = LiteralUtil::CreateR2FromArray2D<float>(expected_array);
if (use_bfloat16()) {
- expected = LiteralUtil::ConvertF32ToBF16(*expected);
+ expected = LiteralUtil::ConvertF32ToBF16(expected);
}
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR2<float>({
+ Literal input_literal = LiteralUtil::CreateR2<float>({
{0, 1, 2, 3, 4, 5, 6, 7},
{100, 101, 102, 103, 104, 105, 106, 107},
{200, 201, 202, 203, 204, 205, 206, 207},
});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4});
@@ -653,20 +652,20 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) {
{{204, 205, 206, 207}}}
});
// clang-format on
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
// Tests R2->R4 reshape with the reshape dimensions {1, 0}.
XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR2<float>({
+ Literal input_literal = LiteralUtil::CreateR2<float>({
{0, 1, 2, 3, 4, 5, 6, 7},
{100, 101, 102, 103, 104, 105, 106, 107},
{200, 201, 202, 203, 204, 205, 206, 207},
});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4});
@@ -680,7 +679,7 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) {
{{206, 7, 107, 207}}}
});
// clang-format on
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -691,17 +690,15 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) {
Array4D<float> input(2, 1, 1, 1);
input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1});
- std::unique_ptr<Literal> expected =
- LiteralUtil::ReshapeSlice({2, 1}, {1, 0}, *input_literal);
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
+ Literal expected = LiteralUtil::ReshapeSlice({2, 1}, {1, 0}, input_literal);
+ ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
zero_error_spec_);
}
@@ -712,17 +709,15 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) {
Array4D<float> input(2, 1, 4, 1);
input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2});
- std::unique_ptr<Literal> expected =
- LiteralUtil::ReshapeSlice({4, 2}, {1, 0}, *input_literal);
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
+ Literal expected = LiteralUtil::ReshapeSlice({4, 2}, {1, 0}, input_literal);
+ ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
zero_error_spec_);
}
@@ -734,12 +729,11 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) {
Array4D<float> input(5, 10, 2, 3);
input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 2, 1, 3},
/*new_sizes=*/{5, 60});
@@ -749,7 +743,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) {
*cell;
});
auto expected = LiteralUtil::CreateR2FromArray2D(expected_array);
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
+ ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
zero_error_spec_);
}
@@ -761,12 +755,11 @@ XLA_TEST_P(ReshapeTest, NoopReshape) {
input_array.Each(
[&rng, &distribution](absl::Span<const int64> /* indices */,
float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input_array, LayoutUtil::MakeLayout({1, 2, 3, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input_array, LayoutUtil::MakeLayout({1, 2, 3, 0}));
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{3, 0, 1, 2},
/*new_sizes=*/{7, 2, 3, 5});
XlaComputation computation = builder.Build().ConsumeValueOrDie();
@@ -775,7 +768,7 @@ XLA_TEST_P(ReshapeTest, NoopReshape) {
*execution_options.mutable_shape_with_output_layout() =
ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {7, 2, 3, 5},
{2, 3, 0, 1});
- std::unique_ptr<Literal> output_literal =
+ Literal output_literal =
client_
->ExecuteAndTransfer(computation, {input_data.get()},
&execution_options)
@@ -784,10 +777,10 @@ XLA_TEST_P(ReshapeTest, NoopReshape) {
// Since the reshape is a no-op, verify that it does not change the underlying
// data.
if (use_bfloat16()) {
- auto expected = LiteralUtil::ConvertF32ToBF16(*input_literal);
- EXPECT_EQ(expected->data<bfloat16>(), output_literal->data<bfloat16>());
+ auto expected = LiteralUtil::ConvertF32ToBF16(input_literal);
+ EXPECT_EQ(expected.data<bfloat16>(), output_literal.data<bfloat16>());
} else {
- EXPECT_EQ(input_literal->data<float>(), output_literal->data<float>());
+ EXPECT_EQ(input_literal.data<float>(), output_literal.data<float>());
}
}
@@ -798,12 +791,12 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape_Trivial) {
{{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input",
+ auto input = CreateParameterAndTransferLiteral(0, literal_1x2x3x4, "input",
&builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1, 2, 3},
/*new_sizes=*/{1, 2, 3, 4});
- ComputeAndCompareLiteral(&builder, *literal_1x2x3x4, {input.get()});
+ ComputeAndCompareLiteral(&builder, literal_1x2x3x4, {input.get()});
}
XLA_TEST_P(ReshapeTest, R4ToR4Reshape) {
@@ -813,7 +806,7 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape) {
XlaBuilder builder(TestName());
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input",
+ auto input = CreateParameterAndTransferLiteral(0, literal_1x2x3x4, "input",
&builder, &parameter);
Reshape(parameter, /*dimensions=*/{1, 3, 2, 0},
/*new_sizes=*/{2, 4, 3, 1});
@@ -830,7 +823,7 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape) {
{{16}, {20}, {24}}}});
// clang-format on
- ComputeAndCompareLiteral(&builder, *expected_2x4x3x1, {input.get()});
+ ComputeAndCompareLiteral(&builder, expected_2x4x3x1, {input.get()});
}
XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) {
@@ -841,24 +834,23 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) {
Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaBuilder builder(TestName());
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
/*new_sizes=*/new_bounds);
- std::unique_ptr<Literal> expected =
- LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
- ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal expected =
+ LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal)
+ .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape
// actually corresponds to a two minor transpose.
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
- zero_error_spec_, &expected->shape());
+ ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
+ zero_error_spec_, &expected.shape());
}
XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) {
@@ -869,24 +861,23 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) {
Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaBuilder builder(TestName());
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
/*new_sizes=*/new_bounds);
- std::unique_ptr<Literal> expected =
- LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
- ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal expected =
+ LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal)
+ .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape
// actually corresponds to a two minor transpose.
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
- zero_error_spec_, &expected->shape());
+ ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
+ zero_error_spec_, &expected.shape());
}
XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) {
@@ -897,24 +888,23 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) {
Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaBuilder builder(TestName());
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
/*new_sizes=*/new_bounds);
- std::unique_ptr<Literal> expected =
- LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
- ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal expected =
+ LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal)
+ .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape
// actually corresponds to a two minor transpose.
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
- zero_error_spec_, &expected->shape());
+ ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
+ zero_error_spec_, &expected.shape());
}
XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) {
@@ -926,24 +916,23 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) {
Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaBuilder builder(TestName());
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
/*new_sizes=*/new_bounds);
- std::unique_ptr<Literal> expected =
- LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
- ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal expected =
+ LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal)
+ .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape
// actually corresponds to a two minor transpose.
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
- zero_error_spec_, &expected->shape());
+ ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
+ zero_error_spec_, &expected.shape());
}
XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) {
@@ -954,24 +943,23 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) {
Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout({0, 1, 2, 3}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({0, 1, 2, 3}));
XlaBuilder builder(TestName());
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{1, 0, 2, 3},
/*new_sizes=*/new_bounds);
- std::unique_ptr<Literal> expected =
- LiteralUtil::ReshapeSlice(new_bounds, {1, 0, 2, 3}, *input_literal)
- ->Relayout(input_literal->shape().layout());
+ Literal expected =
+ LiteralUtil::ReshapeSlice(new_bounds, {1, 0, 2, 3}, input_literal)
+ .Relayout(input_literal.shape().layout());
// Specify the requested output shape explicitly to ensure that this reshape
// actually corresponds to a two minor transpose.
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
- zero_error_spec_, &expected->shape());
+ ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
+ zero_error_spec_, &expected.shape());
}
#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16
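
The reshape hunks above all apply the same mechanical migration: the LiteralUtil factory and slicing helpers now return xla::Literal by value rather than std::unique_ptr<Literal>, so call sites drop the dereference and chain member calls with '.' instead of '->'. A condensed before/after sketch of that pattern, with the test scaffolding elided and the helper signatures taken from the hunks themselves:

    // Before: heap-allocated literal, dereferenced at every use.
    std::unique_ptr<Literal> expected_old =
        LiteralUtil::CreateR2FromArray2D<float>(expected_array);
    ComputeAndCompareLiteral(&builder, *expected_old, {input.get()},
                             zero_error_spec_);

    // After: Literal is a move-only value type; no dereference needed.
    Literal expected = LiteralUtil::CreateR2FromArray2D<float>(expected_array);
    ComputeAndCompareLiteral(&builder, expected, {input.get()},
                             zero_error_spec_);
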
diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc
index 74ded82ddf..4e55b0d7ac 100644
--- a/tensorflow/compiler/xla/tests/reverse_test.cc
+++ b/tensorflow/compiler/xla/tests/reverse_test.cc
@@ -83,25 +83,25 @@ TEST_P(FloatReverseTest, Reverses) {
ShapeUtil::ElementsIn(ShapeUtil::MakeShape(F32, spec.input_dims)));
std::iota(input_vector.begin(), input_vector.end(), 0.0);
auto r1_literal = LiteralUtil::CreateR1<float>(input_vector);
- auto input_literal = r1_literal->Reshape(spec.input_dims).ConsumeValueOrDie();
+ auto input_literal = r1_literal.Reshape(spec.input_dims).ConsumeValueOrDie();
XlaBuilder builder(TestName());
- auto a = AddParam(*input_literal, &builder);
+ auto a = AddParam(input_literal, &builder);
Rev(a, spec.reversal);
- std::unique_ptr<Literal> expected = input_literal->CloneToUnique();
+ Literal expected = input_literal.Clone();
std::vector<int64> output_indices(spec.input_dims.size());
- expected->EachCell<float>([&](absl::Span<const int64> indices, float) {
+ expected.EachCell<float>([&](absl::Span<const int64> indices, float) {
for (int64 i = 0; i < indices.size(); ++i) {
output_indices[i] = indices[i];
}
- float value = input_literal->Get<float>(indices);
+ float value = input_literal.Get<float>(indices);
for (int64 dim : spec.reversal) {
output_indices[dim] = (spec.input_dims[dim] - 1) - indices[dim];
}
- expected->Set<float>(output_indices, value);
+ expected.Set<float>(output_indices, value);
});
- ComputeAndCompareLiteral(&builder, *expected, {});
+ ComputeAndCompareLiteral(&builder, expected, {});
}
INSTANTIATE_TEST_CASE_P(FloatReverseInstance, FloatReverseTest,
diff --git a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
index e692b8c5d5..091a5d2cac 100644
--- a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
+++ b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
@@ -38,7 +38,7 @@ namespace {
class RoundTripPackedLiteralTest : public ClientLibraryTestBase {
protected:
// Sends the literal to the server and retrieves it back.
- std::unique_ptr<Literal> RoundTripToServer(const Literal& original) {
+ Literal RoundTripToServer(const Literal& original) {
std::unique_ptr<GlobalData> data =
client_->TransferToServer(original).ConsumeValueOrDie();
return client_->Transfer(*data).ConsumeValueOrDie();
@@ -59,12 +59,12 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR1F32Length2) {
std::unique_ptr<tensorflow::RandomAccessFile> f;
TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f));
PackedLiteralReader reader(f.release());
- std::unique_ptr<Literal> actual =
+ Literal actual =
reader.Read(ShapeUtil::MakeShape(F32, {2})).ConsumeValueOrDie();
EXPECT_TRUE(reader.IsExhausted());
- EXPECT_EQ(42.0, actual->Get<float>({0}));
- EXPECT_EQ(24.0, actual->Get<float>({1}));
+ EXPECT_EQ(42.0, actual.Get<float>({0}));
+ EXPECT_EQ(24.0, actual.Get<float>({1}));
}
TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) {
@@ -87,18 +87,17 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) {
std::unique_ptr<tensorflow::RandomAccessFile> f;
TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f));
PackedLiteralReader reader(f.release());
- std::unique_ptr<Literal> actual =
- reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout)
- .ConsumeValueOrDie();
+ Literal actual = reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout)
+ .ConsumeValueOrDie();
EXPECT_TRUE(reader.IsExhausted());
- EXPECT_EQ(42.0f, actual->Get<float>({0, 0}));
- EXPECT_EQ(24.0f, actual->Get<float>({0, 1}));
- EXPECT_EQ(64.0f, actual->Get<float>({1, 0}));
- EXPECT_EQ(46.0f, actual->Get<float>({1, 1}));
+ EXPECT_EQ(42.0f, actual.Get<float>({0, 0}));
+ EXPECT_EQ(24.0f, actual.Get<float>({0, 1}));
+ EXPECT_EQ(64.0f, actual.Get<float>({1, 0}));
+ EXPECT_EQ(46.0f, actual.Get<float>({1, 1}));
- std::unique_ptr<Literal> round_tripped = RoundTripToServer(*actual);
- EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual));
+ Literal round_tripped = RoundTripToServer(actual);
+ EXPECT_TRUE(LiteralTestUtil::Equal(round_tripped, actual));
}
TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) {
@@ -121,18 +120,17 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) {
std::unique_ptr<tensorflow::RandomAccessFile> f;
TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f));
PackedLiteralReader reader(f.release());
- std::unique_ptr<Literal> actual =
- reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout)
- .ConsumeValueOrDie();
+ Literal actual = reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout)
+ .ConsumeValueOrDie();
EXPECT_TRUE(reader.IsExhausted());
- EXPECT_EQ(42.0f, actual->Get<float>({0, 0}));
- EXPECT_EQ(24.0f, actual->Get<float>({1, 0}));
- EXPECT_EQ(64.0f, actual->Get<float>({0, 1}));
- EXPECT_EQ(46.0f, actual->Get<float>({1, 1}));
+ EXPECT_EQ(42.0f, actual.Get<float>({0, 0}));
+ EXPECT_EQ(24.0f, actual.Get<float>({1, 0}));
+ EXPECT_EQ(64.0f, actual.Get<float>({0, 1}));
+ EXPECT_EQ(46.0f, actual.Get<float>({1, 1}));
- std::unique_ptr<Literal> round_tripped = RoundTripToServer(*actual);
- EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual));
+ Literal round_tripped = RoundTripToServer(actual);
+ EXPECT_TRUE(LiteralTestUtil::Equal(round_tripped, actual));
}
} // namespace
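
The packed-literal tests show the same change flowing through StatusOr: PackedLiteralReader::Read and Client::Transfer now yield a Literal value from ConsumeValueOrDie(), which is inspected and round-tripped directly. A minimal sketch of the new call pattern, with the file and layout setup omitted and the signatures assumed from the hunks above:

    Literal actual = reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout)
                         .ConsumeValueOrDie();
    EXPECT_EQ(42.0f, actual.Get<float>({0, 0}));
    Literal round_tripped = RoundTripToServer(actual);  // now takes const Literal&
    EXPECT_TRUE(LiteralTestUtil::Equal(round_tripped, actual));
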
diff --git a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc
index a8193c2eac..cd5a531603 100644
--- a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc
+++ b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc
@@ -39,69 +39,67 @@ class RoundTripTransferTest : public ClientLibraryTestBase {
void RoundTripTest(const Literal& original) {
std::unique_ptr<GlobalData> data =
client_->TransferToServer(original).ConsumeValueOrDie();
- std::unique_ptr<Literal> result =
- client_->Transfer(*data).ConsumeValueOrDie();
- EXPECT_TRUE(LiteralTestUtil::Equal(original, *result));
+ Literal result = client_->Transfer(*data).ConsumeValueOrDie();
+ EXPECT_TRUE(LiteralTestUtil::Equal(original, result));
}
};
TEST_F(RoundTripTransferTest, R0S32) {
- RoundTripTest(*LiteralUtil::CreateR0<int32>(42));
+ RoundTripTest(LiteralUtil::CreateR0<int32>(42));
}
TEST_F(RoundTripTransferTest, R0F32) {
- RoundTripTest(*LiteralUtil::CreateR0<float>(42.0));
+ RoundTripTest(LiteralUtil::CreateR0<float>(42.0));
}
TEST_F(RoundTripTransferTest, R1F32_Len0) {
- RoundTripTest(*LiteralUtil::CreateR1<float>({}));
+ RoundTripTest(LiteralUtil::CreateR1<float>({}));
}
TEST_F(RoundTripTransferTest, R1F32_Len2) {
- RoundTripTest(*LiteralUtil::CreateR1<float>({42.0, 64.0}));
+ RoundTripTest(LiteralUtil::CreateR1<float>({42.0, 64.0}));
}
TEST_F(RoundTripTransferTest, R1F32_Len256) {
std::vector<float> values(256);
std::iota(values.begin(), values.end(), 1.0);
- RoundTripTest(*LiteralUtil::CreateR1<float>(values));
+ RoundTripTest(LiteralUtil::CreateR1<float>(values));
}
TEST_F(RoundTripTransferTest, R1F32_Len1024) {
std::vector<float> values(1024);
std::iota(values.begin(), values.end(), 1.0);
- RoundTripTest(*LiteralUtil::CreateR1<float>(values));
+ RoundTripTest(LiteralUtil::CreateR1<float>(values));
}
TEST_F(RoundTripTransferTest, R1F32_Len1025) {
std::vector<float> values(1025);
std::iota(values.begin(), values.end(), 1.0);
- RoundTripTest(*LiteralUtil::CreateR1<float>(values));
+ RoundTripTest(LiteralUtil::CreateR1<float>(values));
}
TEST_F(RoundTripTransferTest, R1F32_Len4096) {
std::vector<float> values(4096);
std::iota(values.begin(), values.end(), 1.0);
- RoundTripTest(*LiteralUtil::CreateR1<float>(values));
+ RoundTripTest(LiteralUtil::CreateR1<float>(values));
}
TEST_F(RoundTripTransferTest, R2F32_Len10x0) {
- RoundTripTest(
- *LiteralUtil::CreateR2FromArray2D<float>(Array2D<float>(10, 0)));
+ RoundTripTest(LiteralUtil::CreateR2FromArray2D<float>(Array2D<float>(10, 0)));
}
TEST_F(RoundTripTransferTest, R2F32_Len2x2) {
- RoundTripTest(*LiteralUtil::CreateR2<float>({{42.0, 64.0}, {77.0, 88.0}}));
+ RoundTripTest(LiteralUtil::CreateR2<float>({{42.0, 64.0}, {77.0, 88.0}}));
}
TEST_F(RoundTripTransferTest, R3F32) {
RoundTripTest(
- *LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
- {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}));
+ LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
+ {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}));
}
TEST_F(RoundTripTransferTest, R4F32) {
- RoundTripTest(*LiteralUtil::CreateR4<float>({{
+ RoundTripTest(LiteralUtil::CreateR4<float>({{
{{10, 11, 12, 13}, {14, 15, 16, 17}},
{{18, 19, 20, 21}, {22, 23, 24, 25}},
{{26, 27, 28, 29}, {30, 31, 32, 33}},
@@ -109,36 +107,35 @@ TEST_F(RoundTripTransferTest, R4F32) {
}
TEST_F(RoundTripTransferTest, EmptyTuple) {
- RoundTripTest(*LiteralUtil::MakeTuple({}));
+ RoundTripTest(LiteralUtil::MakeTuple({}));
}
TEST_F(RoundTripTransferTest, TupleOfR1F32) {
RoundTripTest(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2}).get(),
- LiteralUtil::CreateR1<float>({3, 4}).get()}));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({1, 2}),
+ LiteralUtil::CreateR1<float>({3, 4})}));
}
TEST_F(RoundTripTransferTest, TupleOfR1F32_Len0_Len2) {
RoundTripTest(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({}).get(),
- LiteralUtil::CreateR1<float>({3, 4}).get()}));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({}),
+ LiteralUtil::CreateR1<float>({3, 4})}));
}
TEST_F(RoundTripTransferTest, TupleOfR0F32AndR1S32) {
- RoundTripTest(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(1.0).get(),
- LiteralUtil::CreateR1<int>({2, 3}).get()}));
+ RoundTripTest(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(1.0), LiteralUtil::CreateR1<int>({2, 3})}));
}
// The two tests below are added to identify the cost of large data transfers.
TEST_F(RoundTripTransferTest, R2F32_Large) {
- RoundTripTest(*LiteralUtil::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512));
+ RoundTripTest(LiteralUtil::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512));
}
TEST_F(RoundTripTransferTest, R4F32_Large) {
Array4D<float> array4d(2, 2, 256, 256);
array4d.FillWithMultiples(1.0f);
- RoundTripTest(*LiteralUtil::CreateR4FromArray4D<float>(array4d));
+ RoundTripTest(LiteralUtil::CreateR4FromArray4D<float>(array4d));
}
} // namespace
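
Tuple construction changes shape as well: instead of LiteralUtil::MakeTuple over raw pointers obtained with .get(), the round-trip tests use LiteralUtil::MakeTupleFromSlices, which consumes the element literals by value. The two styles side by side, condensed from the hunks above:

    // Before: elements were unique_ptr<Literal>, passed as raw pointers.
    RoundTripTest(*LiteralUtil::MakeTuple(
        {LiteralUtil::CreateR1<float>({1, 2}).get(),
         LiteralUtil::CreateR1<float>({3, 4}).get()}));

    // After: element literals are values handed straight to MakeTupleFromSlices.
    RoundTripTest(LiteralUtil::MakeTupleFromSlices(
        {LiteralUtil::CreateR1<float>({1, 2}),
         LiteralUtil::CreateR1<float>({3, 4})}));
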
diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
index 07460a7e01..1dd937a6d0 100644
--- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc
+++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
@@ -161,9 +161,9 @@ XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) {
ConvertElementType(a, F32);
int64 value = 3LL << 35;
- std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR0<int64>(value);
+ Literal a_literal = LiteralUtil::CreateR0<int64>(value);
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
ComputeAndCompareR0<float>(&builder, static_cast<float>(value),
{a_data.get()});
}
@@ -225,20 +225,20 @@ XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsS32) {
XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR0<float>(2.1f);
- std::unique_ptr<Literal> b_literal = LiteralUtil::CreateR0<float>(5.5f);
- std::unique_ptr<Literal> c_literal = LiteralUtil::CreateR0<float>(0.5f);
+ Literal a_literal = LiteralUtil::CreateR0<float>(2.1f);
+ Literal b_literal = LiteralUtil::CreateR0<float>(5.5f);
+ Literal c_literal = LiteralUtil::CreateR0<float>(0.5f);
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> b_data =
- client_->TransferToServer(*b_literal).ConsumeValueOrDie();
+ client_->TransferToServer(b_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> c_data =
- client_->TransferToServer(*c_literal).ConsumeValueOrDie();
+ client_->TransferToServer(c_literal).ConsumeValueOrDie();
- XlaOp a = Parameter(&builder, 0, a_literal->shape(), "a");
- XlaOp b = Parameter(&builder, 1, b_literal->shape(), "b");
- XlaOp c = Parameter(&builder, 2, c_literal->shape(), "c");
+ XlaOp a = Parameter(&builder, 0, a_literal.shape(), "a");
+ XlaOp b = Parameter(&builder, 1, b_literal.shape(), "b");
+ XlaOp c = Parameter(&builder, 2, c_literal.shape(), "c");
Mul(Mul(a, b), c);
ComputeAndCompareR0<float>(&builder, 5.775f,
@@ -377,9 +377,9 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) {
auto dividend_literal = LiteralUtil::CreateR0<uint32>(dividend);
auto divisor_literal = LiteralUtil::CreateR0<uint32>(divisor);
TF_ASSERT_OK_AND_ASSIGN(auto dividend_data,
- client_->TransferToServer(*dividend_literal));
+ client_->TransferToServer(dividend_literal));
TF_ASSERT_OK_AND_ASSIGN(auto divisor_data,
- client_->TransferToServer(*divisor_literal));
+ client_->TransferToServer(divisor_literal));
auto actual_literal =
client_
->ExecuteAndTransfer(div_computation,
@@ -388,7 +388,7 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) {
.ConsumeValueOrDie();
auto expected_literal =
LiteralUtil::CreateR0<uint32>(dividend / divisor);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, actual_literal));
}
}
}
@@ -419,9 +419,9 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) {
auto dividend_literal = LiteralUtil::CreateR0<uint32>(dividend);
auto divisor_literal = LiteralUtil::CreateR0<uint32>(divisor);
TF_ASSERT_OK_AND_ASSIGN(auto dividend_data,
- client_->TransferToServer(*dividend_literal));
+ client_->TransferToServer(dividend_literal));
TF_ASSERT_OK_AND_ASSIGN(auto divisor_data,
- client_->TransferToServer(*divisor_literal));
+ client_->TransferToServer(divisor_literal));
auto actual_literal =
client_
->ExecuteAndTransfer(rem_computation,
@@ -430,7 +430,7 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) {
.ConsumeValueOrDie();
auto expected_literal =
LiteralUtil::CreateR0<uint32>(dividend % divisor);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, actual_literal));
}
}
}
@@ -441,8 +441,8 @@ XLA_TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) {
auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "x");
Rem(x, ConstantR0<int32>(&builder, 80000));
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<int32>(87919);
- TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(*literal));
+ Literal literal = LiteralUtil::CreateR0<int32>(87919);
+ TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(literal));
ComputeAndCompareR0<int32>(&builder, 7919, {input_data.get()});
}
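
The scalar-computation hunks illustrate the parameter-setup idiom after the change: Client::TransferToServer accepts a const Literal& and the parameter shape is read from literal.shape(). A trimmed sketch of that idiom, with the builder set up as in the tests above:

    Literal a_literal = LiteralUtil::CreateR0<float>(2.1f);
    std::unique_ptr<GlobalData> a_data =
        client_->TransferToServer(a_literal).ConsumeValueOrDie();
    XlaOp a = Parameter(&builder, 0, a_literal.shape(), "a");
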
diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc
index 1858dcea61..d20dba028a 100644
--- a/tensorflow/compiler/xla/tests/scatter_test.cc
+++ b/tensorflow/compiler/xla/tests/scatter_test.cc
@@ -62,13 +62,11 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatterV2_Update) {
@@ -92,13 +90,12 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates =
LiteralUtil::CreateR2<int32>({{10, 30}, {40, 60}, {70, 90}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatter_Add) {
@@ -123,13 +120,11 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatter_Mul) {
@@ -154,13 +149,11 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatter_F32) {
@@ -185,13 +178,12 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<float>(
+ Literal operand = LiteralUtil::CreateR2<float>(
{{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({2, 1});
- std::unique_ptr<Literal> updates =
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({2, 1});
+ Literal updates =
LiteralUtil::CreateR2<float>({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatter_RepeatedIndices) {
@@ -216,13 +208,11 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({1, 1});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatter_MultipleBatchDims) {
@@ -247,13 +237,12 @@ ENTRY main {
index_vector_dim=2
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
+ Literal updates = LiteralUtil::CreateR3<int32>(
{{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatterNd) {
@@ -277,15 +266,13 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ Literal updates = LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatterNd_NonDefaultIndexVectorDim) {
@@ -309,15 +296,13 @@ ENTRY main {
index_vector_dim=0
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ Literal updates = LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, DynamicUpdateSlice) {
@@ -341,12 +326,11 @@ ENTRY main {
index_vector_dim=0
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({1, 1});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{10}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, BatchDynamicUpdateSlice) {
@@ -370,13 +354,11 @@ ENTRY main {
index_vector_dim=0
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR3<int32>({{{10}}, {{20}}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
+ Literal updates = LiteralUtil::CreateR3<int32>({{{10}}, {{20}}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, ZeroDimBounds) {
@@ -400,11 +382,10 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{}, {}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{}, {}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, NoUpdateWindowDims) {
@@ -429,12 +410,11 @@ ENTRY main {
index_vector_dim=2
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
- std::unique_ptr<Literal> scatter_indices =
+ Literal operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
+ Literal scatter_indices =
LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, OutOfBoundsIndex) {
@@ -458,13 +438,13 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR2<int32>(
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>(
{{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+ Literal updates = LiteralUtil::CreateR3<int32>(
{{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, OutOfBoundsUnsignedIndex) {
@@ -488,13 +468,13 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR2<uint32>(
+ Literal scatter_indices = LiteralUtil::CreateR2<uint32>(
{{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+ Literal updates = LiteralUtil::CreateR3<int32>(
{{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, NegativeIndex) {
@@ -518,13 +498,13 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR2<int32>(
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>(
{{2, 7}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+ Literal updates = LiteralUtil::CreateR3<int32>(
{{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, OneScalarIndex) {
@@ -548,12 +528,12 @@ ENTRY main {
index_vector_dim=0
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR3<int32>(
+ Literal operand = LiteralUtil::CreateR3<int32>(
{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
- std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR0<int32>(1);
- std::unique_ptr<Literal> updates =
+ Literal scatter_indices = LiteralUtil::CreateR0<int32>(1);
+ Literal updates =
LiteralUtil::CreateR3<int32>({{{10, 20}, {30, 40}, {50, 60}}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, ScalarUpdate) {
@@ -577,10 +557,10 @@ ENTRY main {
index_vector_dim=0
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
- std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR0<int32>(1);
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR0<int32>(25);
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
+ Literal scatter_indices = LiteralUtil::CreateR0<int32>(1);
+ Literal updates = LiteralUtil::CreateR0<int32>(25);
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, EmptyIndices) {
@@ -604,10 +584,10 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2, 3});
- std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR1<int32>({});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR1<int32>({});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal operand = LiteralUtil::CreateR1<int32>({1, 2, 3});
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({});
+ Literal updates = LiteralUtil::CreateR1<int32>({});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
} // namespace
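
The scatter tests keep RunTest's Literal-pointer interface, so the now stack-allocated literals are passed by address rather than via unique_ptr::get(). The resulting call, condensed from the hunks above:

    Literal operand =
        LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
    Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
    Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
    RunTest(hlo_text, &operand, &scatter_indices, &updates);
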
diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc
index c9a58aefb4..a40c2d7de6 100644
--- a/tensorflow/compiler/xla/tests/slice_test.cc
+++ b/tensorflow/compiler/xla/tests/slice_test.cc
@@ -176,8 +176,8 @@ XLA_TEST_F(SliceTest, StridedSliceR4WithOutputLayout) {
XlaBuilder builder(TestName());
auto original = ConstantR4FromArray4D(&builder, values);
Slice(original, {0, 0, 0, 0}, {2, 4, 6, 8}, {1, 1, 2, 1});
- ComputeAndCompareLiteral(&builder, *expected_literal, {}, ErrorSpec(0.000001),
- &expected_literal->shape());
+ ComputeAndCompareLiteral(&builder, expected_literal, {}, ErrorSpec(0.000001),
+ &expected_literal.shape());
}
struct R1Spec {
@@ -201,7 +201,7 @@ class SliceR1Test : public ClientLibraryTestBase,
auto literal = LiteralUtil::CreateR1<NativeT>(input);
XlaBuilder builder(TestName());
- auto original = Parameter(&builder, 0, literal->shape(), "p0");
+ auto original = Parameter(&builder, 0, literal.shape(), "p0");
Slice(original, {spec.slice_start}, {spec.slice_limit},
{spec.slice_stride});
@@ -213,7 +213,7 @@ class SliceR1Test : public ClientLibraryTestBase,
}
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> arg,
- client_->TransferToServer(*literal));
+ client_->TransferToServer(literal));
ComputeAndCompareR1<NativeT>(&builder, expected, {arg.get()});
}
};
@@ -376,11 +376,11 @@ XLA_TEST_P(SliceR2Test, DoIt) {
input, LayoutUtil::MakeLayout(spec.layout));
XlaBuilder builder(TestName());
- auto a = Parameter(&builder, 0, literal->shape(), "p0");
+ auto a = Parameter(&builder, 0, literal.shape(), "p0");
Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides);
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> arg,
- client_->TransferToServer(*literal));
+ client_->TransferToServer(literal));
std::unique_ptr<Array2D<int32>> expected = ReferenceUtil::Slice2D(
input, spec.slice_starts, spec.slice_limits, spec.slice_strides);
ComputeAndCompareR2<int32>(&builder, *expected, {arg.get()});
@@ -467,9 +467,9 @@ class SliceR4Test : public ClientLibraryTestBase,
XlaBuilder builder(TestName());
auto literal = LiteralUtil::CreateR4FromArray4DWithLayout(
values, LayoutUtil::MakeLayout(spec.input_layout));
- auto parameter = Parameter(&builder, 0, literal->shape(), "p0");
+ auto parameter = Parameter(&builder, 0, literal.shape(), "p0");
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> arg,
- client_->TransferToServer(*literal));
+ client_->TransferToServer(literal));
Slice(parameter, spec.slice_starts, spec.slice_limits, spec.slice_strides);
ComputeAndCompareR4(&builder, *expected, {arg.get()}, ErrorSpec(0.000001));
}
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 3ae31191a0..5155f0c652 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -116,13 +116,14 @@ void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine,
// array. This uniqueness is best-effort only. Some types (half and bfloat16)
// are not supported and uniqueness cannot be guaranteed if the number of
// elements exceeds the number of different values supported by the type.
-StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal(
- const Shape& shape, std::minstd_rand0* engine, bool no_duplicates) {
+StatusOr<Literal> MakeFakeLiteralInternal(const Shape& shape,
+ std::minstd_rand0* engine,
+ bool no_duplicates) {
if (ShapeUtil::IsTuple(shape)) {
- std::vector<std::unique_ptr<Literal>> elements;
+ std::vector<Literal> elements;
for (const Shape& element_shape : shape.tuple_shapes()) {
TF_ASSIGN_OR_RETURN(
- std::unique_ptr<Literal> element,
+ Literal element,
MakeFakeLiteralInternal(element_shape, engine, no_duplicates));
elements.push_back(std::move(element));
}
@@ -131,60 +132,52 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal(
if (engine == nullptr) {
return Literal::CreateFromShape(shape);
}
- auto literal = absl::make_unique<Literal>(shape);
+ Literal literal(shape);
switch (shape.element_type()) {
case BF16:
- PopulateWithRandomFloatingPointData<bfloat16>(literal.get(), engine,
+ PopulateWithRandomFloatingPointData<bfloat16>(&literal, engine,
no_duplicates);
break;
case F16:
- PopulateWithRandomFloatingPointData<half>(literal.get(), engine,
+ PopulateWithRandomFloatingPointData<half>(&literal, engine,
no_duplicates);
break;
case F32:
- PopulateWithRandomFloatingPointData<float>(literal.get(), engine,
+ PopulateWithRandomFloatingPointData<float>(&literal, engine,
no_duplicates);
break;
case F64:
- PopulateWithRandomFloatingPointData<double>(literal.get(), engine,
+ PopulateWithRandomFloatingPointData<double>(&literal, engine,
no_duplicates);
break;
case S8:
- PopulateWithRandomIntegralData<int8>(literal.get(), engine,
- no_duplicates);
+ PopulateWithRandomIntegralData<int8>(&literal, engine, no_duplicates);
break;
case U8:
- PopulateWithRandomIntegralData<uint8>(literal.get(), engine,
- no_duplicates);
+ PopulateWithRandomIntegralData<uint8>(&literal, engine, no_duplicates);
break;
case S16:
- PopulateWithRandomIntegralData<int16>(literal.get(), engine,
- no_duplicates);
+ PopulateWithRandomIntegralData<int16>(&literal, engine, no_duplicates);
break;
case U16:
- PopulateWithRandomIntegralData<uint16>(literal.get(), engine,
- no_duplicates);
+ PopulateWithRandomIntegralData<uint16>(&literal, engine, no_duplicates);
break;
case S32:
- PopulateWithRandomIntegralData<int32>(literal.get(), engine,
- no_duplicates);
+ PopulateWithRandomIntegralData<int32>(&literal, engine, no_duplicates);
break;
case U32:
- PopulateWithRandomIntegralData<uint32>(literal.get(), engine,
- no_duplicates);
+ PopulateWithRandomIntegralData<uint32>(&literal, engine, no_duplicates);
break;
case S64:
- PopulateWithRandomIntegralData<int64>(literal.get(), engine,
- no_duplicates);
+ PopulateWithRandomIntegralData<int64>(&literal, engine, no_duplicates);
break;
case U64:
- PopulateWithRandomIntegralData<uint64>(literal.get(), engine,
- no_duplicates);
+ PopulateWithRandomIntegralData<uint64>(&literal, engine, no_duplicates);
break;
case PRED: {
std::uniform_int_distribution<int> generator(0, 1);
TF_CHECK_OK(
- literal->Populate<bool>([&](absl::Span<const int64> /*indices*/) {
+ literal.Populate<bool>([&](absl::Span<const int64> /*indices*/) {
return generator(*engine);
}));
break;
@@ -236,8 +229,8 @@ bool NeedsInitValue(const HloUse& use) {
// Generate random values that are constrained to the input_shape minus the
// output_shape so as not to produce wrapping slices, for instance.
-std::unique_ptr<Literal> MakeRandomIndex(absl::Span<const int64> index_space,
- std::minstd_rand0* engine) {
+Literal MakeRandomIndex(absl::Span<const int64> index_space,
+ std::minstd_rand0* engine) {
std::vector<int32> start_indices(index_space.size());
if (engine != nullptr) {
for (int i = 0; i < index_space.size(); ++i) {
@@ -293,7 +286,7 @@ std::vector<HloInstruction*> FindConstrainedUses(
// no constrained uses in the dataflow graph. If such constraints exist,
// generate a constrained literal (either bounded in the case of indices, or
// zero in the case of init_values for reductions).
-StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
+StatusOr<Literal> CreateLiteralForConstrainedUses(
const absl::Span<HloInstruction* const> constrained_uses,
const HloInstruction& param, std::minstd_rand0* engine) {
std::vector<int64> index_space;
@@ -358,9 +351,9 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
} else if (needs_constant) {
switch (constant_type) {
case ConstantType::kZero:
- return LiteralUtil::Zero(param.shape().element_type()).CloneToUnique();
+ return LiteralUtil::Zero(param.shape().element_type());
case ConstantType::kOne:
- return LiteralUtil::One(param.shape().element_type()).CloneToUnique();
+ return LiteralUtil::One(param.shape().element_type());
case ConstantType::kUnknown:
// We want the identity element for the computation, but we don't really
// know what it is - so any value we generate will be just as wrong.
@@ -374,34 +367,33 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
// Given a module entry parameter, use the dataflow analysis to see if a
// special case literal must be created, or if we can generate fake data.
-StatusOr<std::unique_ptr<Literal>> MakeConstrainedArgument(
- const HloDataflowAnalysis& dataflow, const HloInstruction& param,
- std::minstd_rand0* engine) {
+StatusOr<Literal> MakeConstrainedArgument(const HloDataflowAnalysis& dataflow,
+ const HloInstruction& param,
+ std::minstd_rand0* engine) {
const auto constrained_uses = FindConstrainedUses(dataflow, param);
return CreateLiteralForConstrainedUses(constrained_uses, param, engine);
}
} // namespace
-StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape,
- bool pseudo_random) {
+StatusOr<Literal> MakeFakeLiteral(const Shape& shape, bool pseudo_random) {
auto engine =
pseudo_random ? absl::make_unique<std::minstd_rand0>() : nullptr;
return MakeFakeLiteralInternal(shape, engine.get(), /*no_duplicates=*/false);
}
-StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
- HloModule* const module, bool pseudo_random) {
+StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
+ bool pseudo_random) {
auto engine =
pseudo_random ? absl::make_unique<std::minstd_rand0>() : nullptr;
return MakeFakeArguments(module, engine.get());
}
-StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
- HloModule* const module, std::minstd_rand0* engine) {
+StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
+ std::minstd_rand0* engine) {
TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module));
const auto params = module->entry_computation()->parameter_instructions();
- std::vector<std::unique_ptr<Literal>> arguments(params.size());
+ std::vector<Literal> arguments(params.size());
for (int i = 0; i < params.size(); ++i) {
arguments[i] =
MakeConstrainedArgument(*dataflow, *params[i], engine).ValueOrDie();
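
Note on the calling convention introduced above: MakeFakeArguments (and the MakeConstrainedArgument helper) now hand back Literal values rather than std::unique_ptr<Literal>. A minimal caller sketch, assuming an existing std::unique_ptr<HloModule> named module as in the tests further down (names are illustrative, not part of the patch):

    // Fake arguments arrive as values; no unique_ptr, no dereference.
    std::vector<xla::Literal> args =
        xla::MakeFakeArguments(module.get()).ValueOrDie();
    const xla::Literal& first_arg = args[0];
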
diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h
index a260271b1b..b3c8a73905 100644
--- a/tensorflow/compiler/xla/tests/test_utils.h
+++ b/tensorflow/compiler/xla/tests/test_utils.h
@@ -57,8 +57,8 @@ class PseudorandomGenerator {
// Generates fake data in a literal of the given shape, or returns an error
// status if the element type is currently unhandled for fake data
// generation. See below for documentation of pseudo_random.
-StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape,
- bool pseudo_random = true);
+StatusOr<Literal> MakeFakeLiteral(const Shape& shape,
+ bool pseudo_random = true);
// Generates a vector of arguments containing fake data. The number, shape and
// layout of the arguments is appropriate for given HLO module.
@@ -84,14 +84,14 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape,
// TODO(b/79942829): Make interesting argument generation fast enough that using
// pseudo_random does not save any noticeable amount of time so that the
// parameter can be removed.
-StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
- HloModule* const module, bool pseudo_random = true);
+StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
+ bool pseudo_random = true);
// Overload which accepts a random number generator. This enables generation of
// different random values with sequential calls to MakeFakeArguments by reusing
// the same generator.
-StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
- HloModule* const module, std::minstd_rand0* engine);
+StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
+ std::minstd_rand0* engine);
// Check that a given module satisfies various constraints before trying to
// execute it.
diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc
index 322c8ef090..181e5cbe29 100644
--- a/tensorflow/compiler/xla/tests/test_utils_test.cc
+++ b/tensorflow/compiler/xla/tests/test_utils_test.cc
@@ -85,10 +85,10 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) {
ROOT dynamic-slice.2 = f32[3,2,2] dynamic-slice(array_param.2, index_param), dynamic_slice_sizes={3,2,2}
})")
.ValueOrDie();
- TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args,
+ TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> args,
MakeFakeArguments(module.get()));
ASSERT_EQ(args.size(), 3);
- const Literal& index_arg = *args[0];
+ const Literal& index_arg = args[0];
EXPECT_EQ(index_arg.Get<int32>({0}), 0);
@@ -114,10 +114,10 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) {
ROOT dynamic-update-slice.2 = f32[3,3000,5] dynamic-update-slice(array_param.2, update_param.2, index_param)
})")
.ValueOrDie();
- TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args,
+ TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> args,
MakeFakeArguments(module.get()));
ASSERT_EQ(args.size(), 5);
- const Literal& index_arg = *args[0];
+ const Literal& index_arg = args[0];
EXPECT_EQ(index_arg.Get<int32>({0}), 0);
@@ -140,10 +140,10 @@ ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> (
}
)")
.ValueOrDie();
- TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args,
+ TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> args,
MakeFakeArguments(module.get()));
ASSERT_EQ(args.size(), 2);
- const Literal& key_arg = *args[0];
+ const Literal& key_arg = args[0];
tensorflow::gtl::FlatSet<uint32> key_set;
for (const float& value : key_arg.data<float>()) {
@@ -163,10 +163,10 @@ ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> (
}
)")
.ValueOrDie();
- TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args,
+ TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> args,
MakeFakeArguments(module.get()));
ASSERT_EQ(args.size(), 2);
- const Literal& key_arg = *args[0];
+ const Literal& key_arg = args[0];
tensorflow::gtl::FlatSet<int32> key_set;
for (const int32& value : key_arg.data<int32>()) {
diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc
index c7eb9e2dbe..b34fd0f2e8 100644
--- a/tensorflow/compiler/xla/tests/token_hlo_test.cc
+++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc
@@ -34,9 +34,8 @@ XLA_TEST_F(TokenHloTest, SingleTokenInstruction) {
module->AddEntryComputation(builder.Build());
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- Execute(std::move(module), {}));
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *LiteralUtil::CreateToken()));
+ TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {}));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, LiteralUtil::CreateToken()));
}
XLA_TEST_F(TokenHloTest, TokenTree) {
@@ -50,9 +49,8 @@ XLA_TEST_F(TokenHloTest, TokenTree) {
module->AddEntryComputation(builder.Build());
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- Execute(std::move(module), {}));
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *LiteralUtil::CreateToken()));
+ TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {}));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, LiteralUtil::CreateToken()));
}
XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) {
@@ -193,9 +191,8 @@ ENTRY %TokenInConditional (param.3: pred[]) -> s32[] {
std::unique_ptr<HloModule> module,
HloRunner::CreateModuleFromString(module_string, debug_options));
auto arg = LiteralUtil::CreateR0<bool>(true);
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- Execute(std::move(module), {arg.get()}));
- EXPECT_EQ(42, result->Get<int32>({}));
+ TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {&arg}));
+ EXPECT_EQ(42, result.Get<int32>({}));
}
{
@@ -204,9 +201,8 @@ ENTRY %TokenInConditional (param.3: pred[]) -> s32[] {
std::unique_ptr<HloModule> module,
HloRunner::CreateModuleFromString(module_string, debug_options));
auto arg = LiteralUtil::CreateR0<bool>(false);
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- Execute(std::move(module), {arg.get()}));
- EXPECT_EQ(7, result->Get<int32>({}));
+ TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {&arg}));
+ EXPECT_EQ(7, result.Get<int32>({}));
}
}
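
The token tests above show the new Execute shape: results come back as Literal values while arguments are still passed as Literal pointers, so stack literals are taken by address. A condensed sketch inside an HloTestBase-style test body (fixture plumbing omitted):

    xla::Literal arg = xla::LiteralUtil::CreateR0<bool>(true);
    TF_ASSERT_OK_AND_ASSIGN(xla::Literal result,
                            Execute(std::move(module), {&arg}));
    EXPECT_EQ(42, result.Get<xla::int32>({}));
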
diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc
index 125513ddfd..d6641d257a 100644
--- a/tensorflow/compiler/xla/tests/transfer_manager_test.cc
+++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc
@@ -69,90 +69,90 @@ class TransferManagerTest : public LocalClientTestBase {
};
XLA_TEST_F(TransferManagerTest, TransferR0U32) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<uint32>(42);
- const Shape& shape = literal->shape();
+ Literal literal = LiteralUtil::CreateR0<uint32>(42);
+ const Shape& shape = literal.shape();
auto device_buffer = AllocateDeviceBuffer(shape);
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- LiteralTestUtil::ExpectR0Equal<uint32>(42, *result);
+ LiteralTestUtil::ExpectR0Equal<uint32>(42, result);
}
XLA_TEST_F(TransferManagerTest, TransferR1F32) {
- std::unique_ptr<Literal> literal =
+ Literal literal =
LiteralUtil::CreateR1<float>({1.25f, 2.5f, -17.0f, -20.125f});
- const Shape& shape = literal->shape();
+ const Shape& shape = literal.shape();
auto device_buffer = AllocateDeviceBuffer(shape);
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
LiteralTestUtil::ExpectR1Equal<float>({1.25f, 2.5f, -17.0f, -20.125f},
- *result);
+ result);
}
XLA_TEST_F(TransferManagerTest, TransferR1LargeF32) {
std::vector<float> test_vector(1024 * 1024);
std::iota(test_vector.begin(), test_vector.end(), 0);
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<float>(test_vector);
- const Shape& shape = literal->shape();
+ Literal literal = LiteralUtil::CreateR1<float>(test_vector);
+ const Shape& shape = literal.shape();
auto device_buffer = AllocateDeviceBuffer(shape);
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- LiteralTestUtil::ExpectR1Equal<float>(test_vector, *result);
+ LiteralTestUtil::ExpectR1Equal<float>(test_vector, result);
}
XLA_TEST_F(TransferManagerTest, TransferR1U8) {
const char* test_string = "0123456789abcdef";
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1U8(test_string);
- const Shape& shape = literal->shape();
+ Literal literal = LiteralUtil::CreateR1U8(test_string);
+ const Shape& shape = literal.shape();
auto device_buffer = AllocateDeviceBuffer(shape);
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- EXPECT_EQ(result->GetR1U8AsString(), test_string);
+ EXPECT_EQ(result.GetR1U8AsString(), test_string);
}
XLA_TEST_F(TransferManagerTest, TransferR2F32) {
- std::unique_ptr<Literal> literal =
+ Literal literal =
LiteralUtil::CreateR2<float>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}});
- const Shape& shape = literal->shape();
+ const Shape& shape = literal.shape();
auto device_buffer = AllocateDeviceBuffer(shape);
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, *result);
+ {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, result);
}
XLA_TEST_F(TransferManagerTest,
TransferR2F32AndChangeLayoutTransferringToDevice) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR2WithLayout<float>(
+ Literal literal = LiteralUtil::CreateR2WithLayout<float>(
{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, LayoutUtil::MakeLayout({0, 1}));
const Shape ondevice_shape =
ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0});
@@ -160,101 +160,99 @@ XLA_TEST_F(TransferManagerTest,
// Round trip literal through device. Set the on-device layout to something
// different than the literal layout.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
EXPECT_FALSE(
- LayoutUtil::Equal(result->shape().layout(), literal->shape().layout()));
+ LayoutUtil::Equal(result.shape().layout(), literal.shape().layout()));
LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, *result);
+ {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, result);
}
XLA_TEST_F(TransferManagerTest, TransferTuple) {
- std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(123.0f).get(),
- LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(),
- LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f}).get()});
- auto device_buffer = AllocateDeviceBuffer(literal->shape());
+ Literal literal = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(123.0f),
+ LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}),
+ LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f})});
+ auto device_buffer = AllocateDeviceBuffer(literal.shape());
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));
}
XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) {
- std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple({});
- auto device_buffer = AllocateDeviceBuffer(literal->shape());
+ Literal literal = LiteralUtil::MakeTuple({});
+ auto device_buffer = AllocateDeviceBuffer(literal.shape());
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));
}
XLA_TEST_F(TransferManagerTest, TransferNestedTuple) {
- std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(123.0f).get(),
- LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(),
- LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f}).get()})
- .get(),
- LiteralUtil::CreateR1<float>({-10.0f, 123.0f}).get()});
- auto device_buffer = AllocateDeviceBuffer(literal->shape());
+ Literal literal = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(123.0f),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}),
+ LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f})}),
+ LiteralUtil::CreateR1<float>({-10.0f, 123.0f})});
+ auto device_buffer = AllocateDeviceBuffer(literal.shape());
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));
}
XLA_TEST_F(TransferManagerTest, TransferComplexValue) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<complex64>(
+ Literal literal = LiteralUtil::CreateR1<complex64>(
{complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)});
- auto device_buffer = AllocateDeviceBuffer(literal->shape());
+ auto device_buffer = AllocateDeviceBuffer(literal.shape());
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));
}
XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) {
- std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple(
+ Literal literal = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR1<complex64>(
- {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)})
- .get(),
- LiteralUtil::CreateR1<int32>({1, 2, 3, 4, 5, 6}).get(),
- LiteralUtil::CreateR0<complex64>(complex64(0.3f, -0.4f)).get()});
- auto device_buffer = AllocateDeviceBuffer(literal->shape());
+ {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}),
+ LiteralUtil::CreateR1<int32>({1, 2, 3, 4, 5, 6}),
+ LiteralUtil::CreateR0<complex64>(complex64(0.3f, -0.4f))});
+ auto device_buffer = AllocateDeviceBuffer(literal.shape());
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));
}
XLA_TEST_F(TransferManagerTest, TransferTokenFromDevice) {
@@ -264,54 +262,52 @@ XLA_TEST_F(TransferManagerTest, TransferTokenFromDevice) {
// supported.
auto device_buffer = AllocateDeviceBuffer(ShapeUtil::MakeTokenShape());
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- EXPECT_TRUE(LiteralTestUtil::Equal(*LiteralUtil::CreateToken(), *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateToken(), result));
}
XLA_TEST_F(TransferManagerTest, MultiStreamRoundTripSoak) {
const int64 kIterationCount = 5000;
- std::unique_ptr<Literal> literal1 = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(123.0f).get(),
- LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(),
- LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f}).get()})
- .get(),
- LiteralUtil::CreateR1<float>({-10.0f, 123.0f}).get()});
- std::unique_ptr<Literal> literal2 = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(456.0f).get(),
- LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{5.0f, 7.0f}, {9.0f, 4.0f}}).get(),
- LiteralUtil::CreateR1<float>({44.0f, -11.0f, 3333333.3f}).get()})
- .get(),
- LiteralUtil::CreateR1<float>({-98.0f, 153.0f}).get()});
-
- auto device_buffer1 = AllocateDeviceBuffer(literal1->shape());
- auto device_buffer2 = AllocateDeviceBuffer(literal2->shape());
+ Literal literal1 = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(123.0f),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}),
+ LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f})}),
+ LiteralUtil::CreateR1<float>({-10.0f, 123.0f})});
+ Literal literal2 = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(456.0f),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>({{5.0f, 7.0f}, {9.0f, 4.0f}}),
+ LiteralUtil::CreateR1<float>({44.0f, -11.0f, 3333333.3f})}),
+ LiteralUtil::CreateR1<float>({-98.0f, 153.0f})});
+
+ auto device_buffer1 = AllocateDeviceBuffer(literal1.shape());
+ auto device_buffer2 = AllocateDeviceBuffer(literal2.shape());
auto stream1 = stream_;
auto stream2 = stream_->GetOrCreateSubStream();
- std::unique_ptr<Literal> result1, result2;
+ Literal result1, result2;
// Round trip literals through device in multiple streams asynchronously.
for (int i = 0; i < kIterationCount; ++i) {
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream1, *literal1,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream1, literal1,
device_buffer1));
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream2, *literal2,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream2, literal2,
device_buffer2));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> this_result1,
+ Literal this_result1,
transfer_manager_->TransferLiteralFromDevice(stream1, device_buffer1));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> this_result2,
+ Literal this_result2,
transfer_manager_->TransferLiteralFromDevice(stream2, device_buffer2));
result1 = std::move(this_result1);
result2 = std::move(this_result2);
}
- EXPECT_TRUE(LiteralTestUtil::Equal(*literal1, *result1));
- EXPECT_TRUE(LiteralTestUtil::Equal(*literal2, *result2));
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal1, result1));
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal2, result2));
}
class TransferDeviceToHostBenchmark : public TransferManagerTest {
@@ -323,20 +319,19 @@ class TransferDeviceToHostBenchmark : public TransferManagerTest {
tensorflow::testing::StopTiming();
SetUp();
- std::vector<std::unique_ptr<Literal>> tuple_elements;
+ std::vector<Literal> tuple_elements;
for (int i = 0; i < num_tuple_elements; ++i) {
tuple_elements.push_back(
LiteralUtil::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size));
}
- std::unique_ptr<Literal> literal =
- LiteralUtil::MakeTupleOwned(std::move(tuple_elements));
- auto device_buffer = AllocateDeviceBuffer(literal->shape());
- TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ Literal literal = LiteralUtil::MakeTupleOwned(std::move(tuple_elements));
+ auto device_buffer = AllocateDeviceBuffer(literal.shape());
+ TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
}
tensorflow::testing::StopTiming();
@@ -355,17 +350,16 @@ class TransferHostToDeviceBenchmark : public TransferManagerTest {
tensorflow::testing::StopTiming();
SetUp();
- std::vector<std::unique_ptr<Literal>> tuple_elements;
+ std::vector<Literal> tuple_elements;
for (int i = 0; i < num_tuple_elements; ++i) {
tuple_elements.push_back(
LiteralUtil::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size));
}
- std::unique_ptr<Literal> literal =
- LiteralUtil::MakeTupleOwned(std::move(tuple_elements));
- auto device_buffer = AllocateDeviceBuffer(literal->shape());
+ Literal literal = LiteralUtil::MakeTupleOwned(std::move(tuple_elements));
+ auto device_buffer = AllocateDeviceBuffer(literal.shape());
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
- TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
}
tensorflow::testing::StopTiming();
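
Every test in transfer_manager_test.cc now follows the same value-based round trip: transfer a const Literal& to the device, read a Literal back. A condensed sketch using the fixture members referenced above (stream_, transfer_manager_, AllocateDeviceBuffer):

    xla::Literal literal = xla::LiteralUtil::CreateR1<float>({1.25f, 2.5f});
    auto device_buffer = AllocateDeviceBuffer(literal.shape());
    // To device: pass by const reference instead of dereferencing a unique_ptr.
    ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
                                                            device_buffer));
    // From device: the result is a Literal value.
    TF_ASSERT_OK_AND_ASSIGN(
        xla::Literal result,
        transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
    EXPECT_TRUE(xla::LiteralTestUtil::Equal(literal, result));
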
diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc
index f2b3b49015..619d2a388b 100644
--- a/tensorflow/compiler/xla/tests/tuple_test.cc
+++ b/tensorflow/compiler/xla/tests/tuple_test.cc
@@ -51,13 +51,13 @@ XLA_TEST_F(TupleTest, TupleConstant) {
{1.1f, 2.2f, 3.5f}, // row 0
{4.8f, 5.0f, 6.7f}, // row 1
};
- auto value = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(constant_scalar).get(),
- LiteralUtil::CreateR1<float>(constant_vector).get(),
- LiteralUtil::CreateR2<float>(constant_matrix).get()});
+ auto value = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(constant_scalar),
+ LiteralUtil::CreateR1<float>(constant_vector),
+ LiteralUtil::CreateR2<float>(constant_matrix)});
- ConstantLiteral(&builder, *value);
- ComputeAndCompareTuple(&builder, *value, {}, error_spec_);
+ ConstantLiteral(&builder, value);
+ ComputeAndCompareTuple(&builder, value, {}, error_spec_);
}
// Tests a tuple made of scalar constants.
@@ -66,12 +66,12 @@ XLA_TEST_F(TupleTest, TupleScalarConstant) {
const float constant_scalar1 = 7.3f;
const float constant_scalar2 = 1.2f;
- auto value = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(constant_scalar1).get(),
- LiteralUtil::CreateR0<float>(constant_scalar2).get()});
+ auto value = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(constant_scalar1),
+ LiteralUtil::CreateR0<float>(constant_scalar2)});
- ConstantLiteral(&builder, *value);
- ComputeAndCompareTuple(&builder, *value, {}, error_spec_);
+ ConstantLiteral(&builder, value);
+ ComputeAndCompareTuple(&builder, value, {}, error_spec_);
}
// Tests the creation of tuple data.
@@ -88,11 +88,11 @@ XLA_TEST_F(TupleTest, TupleCreate) {
ConstantR1<float>(&builder, constant_vector),
ConstantR2<float>(&builder, constant_matrix)});
- auto expected = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(constant_scalar).get(),
- LiteralUtil::CreateR1<float>(constant_vector).get(),
- LiteralUtil::CreateR2<float>(constant_matrix).get()});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(constant_scalar),
+ LiteralUtil::CreateR1<float>(constant_vector),
+ LiteralUtil::CreateR2<float>(constant_matrix)});
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
// Tests the creation of tuple data.
@@ -102,10 +102,9 @@ XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) {
Tuple(&builder,
{ConstantR0<float>(&builder, 7.0), ConstantR1<float>(&builder, {})});
- auto expected =
- LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(7.0).get(),
- LiteralUtil::CreateR1<float>({}).get()});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(7.0), LiteralUtil::CreateR1<float>({})});
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
// Tests the creation of an empty tuple.
@@ -113,7 +112,7 @@ XLA_TEST_F(TupleTest, EmptyTupleCreate) {
XlaBuilder builder(TestName());
Tuple(&builder, {});
auto expected = LiteralUtil::MakeTuple({});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
// Trivial test for extracting a tuple element with GetTupleElement.
@@ -196,10 +195,10 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) {
ConstantR2<float>(&builder, constant_matrix)});
Tuple(&builder,
{GetTupleElement(tuple_data, 1), GetTupleElement(tuple_data, 0)});
- auto expected = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>(constant_matrix).get(),
- LiteralUtil::CreateR1<float>(constant_vector).get()});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>(constant_matrix),
+ LiteralUtil::CreateR1<float>(constant_vector)});
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(TupleTest, SelectBetweenPredTuples) {
@@ -218,11 +217,11 @@ XLA_TEST_F(TupleTest, SelectBetweenPredTuples) {
auto v1_v2 = Tuple(&b, {v1_gt, v2_gt}); // {false, true}
auto v2_v1 = Tuple(&b, {v2_gt, v1_gt}); // {true, false}
Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1);
- auto expected =
- LiteralUtil::MakeTuple({LiteralUtil::CreateR0<bool>(direction).get(),
- LiteralUtil::CreateR0<bool>(!direction).get()});
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<bool>(direction),
+ LiteralUtil::CreateR0<bool>(!direction)});
- ComputeAndCompareTuple(&b, *expected, {v1_data.get(), v2_data.get()},
+ ComputeAndCompareTuple(&b, expected, {v1_data.get(), v2_data.get()},
error_spec_);
}
}
@@ -287,10 +286,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnFalse) {
ConstantR1<float>(&builder, vec1)});
Select(ConstantR0<bool>(&builder, false), tuple12, tuple21);
- auto expected =
- LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>(vec2).get(),
- LiteralUtil::CreateR1<float>(vec1).get()});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>(vec2), LiteralUtil::CreateR1<float>(vec1)});
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(TupleTest, TuplesInAMap) {
@@ -332,10 +330,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnTrue) {
ConstantR1<float>(&builder, vec1)});
Select(ConstantR0<bool>(&builder, true), tuple12, tuple21);
- auto expected =
- LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>(vec1).get(),
- LiteralUtil::CreateR1<float>(vec2).get()});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>(vec1), LiteralUtil::CreateR1<float>(vec2)});
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) {
@@ -408,10 +405,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesReuseConstants) {
Select(ConstantR0<bool>(&builder, false), tuple12, tuple21);
- auto expected =
- LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>(vec2).get(),
- LiteralUtil::CreateR1<float>(vec1).get()});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>(vec2), LiteralUtil::CreateR1<float>(vec1)});
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(TupleTest, NestedTuples) {
@@ -423,12 +419,11 @@ XLA_TEST_F(TupleTest, NestedTuples) {
auto expected_v1 = LiteralUtil::CreateR1<float>({1.0, 2.0});
auto expected_s = LiteralUtil::CreateR0<float>(42.0);
auto expected_inner_tuple =
- LiteralUtil::MakeTuple({expected_v1.get(), expected_s.get()});
+ LiteralUtil::MakeTuple({&expected_v1, &expected_s});
auto expected_v2 = LiteralUtil::CreateR1<float>({22.0, 44.0});
- auto expected =
- LiteralUtil::MakeTuple({expected_inner_tuple.get(), expected_v2.get()});
+ auto expected = LiteralUtil::MakeTuple({&expected_inner_tuple, &expected_v2});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) {
@@ -446,14 +441,12 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) {
std::unique_ptr<GlobalData> data =
client_
- ->TransferToServer(*LiteralUtil::MakeTuple({
- LiteralUtil::MakeTuple(
- {
- LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0}).get(),
- LiteralUtil::CreateR1<float>({4.0, 5.0, 6.0}).get(),
- })
- .get(),
- LiteralUtil::CreateR1<float>({7.0, 8.0, 9.0}).get(),
+ ->TransferToServer(LiteralUtil::MakeTupleFromSlices({
+ LiteralUtil::MakeTupleFromSlices({
+ LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0}),
+ LiteralUtil::CreateR1<float>({4.0, 5.0, 6.0}),
+ }),
+ LiteralUtil::CreateR1<float>({7.0, 8.0, 9.0}),
}))
.ConsumeValueOrDie();
@@ -484,40 +477,36 @@ XLA_TEST_F(TupleTest, ComplexTuples) {
std::unique_ptr<GlobalData> arg0 =
client_
- ->TransferToServer(*LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<complex64>({1, 2}).get(),
- LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR1<complex64>({{10, 20}, {30, 40}})
- .get(),
+ ->TransferToServer(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<complex64>({1, 2}),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<complex64>({{10, 20}, {30, 40}}),
LiteralUtil::CreateR2<complex64>(
{{{100, 200}, {300, 400}},
{{1000, 2000}, {3000, 4000}},
- {{10000, 20000}, {30000, 40000}}})
- .get()})
- .get()}))
+ {{10000, 20000}, {30000, 40000}}})})}))
.ConsumeValueOrDie();
std::unique_ptr<GlobalData> arg1 =
client_
->TransferToServer(
- *LiteralUtil::CreateR1<complex64>({{1, 2}, {1, -2}}))
+ LiteralUtil::CreateR1<complex64>({{1, 2}, {1, -2}}))
.ConsumeValueOrDie();
auto sum =
LiteralUtil::CreateR2<complex64>({{{111, 222}, {331, 442}},
{{1011, 2022}, {3031, 4042}},
{{10011, 20022}, {30031, 40042}}});
- auto prod = absl::make_unique<Literal>(sum->shape());
- ASSERT_TRUE(prod->Populate<complex64>(
- [&sum](absl::Span<const int64> indexes) {
- return sum->Get<complex64>(indexes) *
- (indexes[indexes.size() - 1] == 0
- ? complex64(1, 2)
- : complex64(1, -2));
- })
+ Literal prod(sum.shape());
+ ASSERT_TRUE(prod.Populate<complex64>([&sum](absl::Span<const int64> indexes) {
+ return sum.Get<complex64>(indexes) *
+ (indexes[indexes.size() - 1] == 0
+ ? complex64(1, 2)
+ : complex64(1, -2));
+ })
.ok());
- auto expected = LiteralUtil::MakeTuple(
- {LiteralUtil::MakeTuple({prod.get(), sum.get()}).get(),
- LiteralUtil::CreateR0<complex64>({123, 456}).get()});
- ComputeAndCompareTuple(&builder, *expected, {arg0.get(), arg1.get()},
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::MakeTupleFromSlices({prod, sum}),
+ LiteralUtil::CreateR0<complex64>({123, 456})});
+ ComputeAndCompareTuple(&builder, expected, {arg0.get(), arg1.get()},
error_spec_);
}
@@ -541,10 +530,10 @@ XLA_TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) {
.ValueOrDie();
auto param =
LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({1, 2, 3}));
- auto result = ExecuteNoHloPasses(std::move(module), {param.get()});
+ auto result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR2<float>({{1, 2, 3}})),
- *result));
+ LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR2<float>({{1, 2, 3}})),
+ result));
}
// Disabled on interpreter due to lack of outfeed.
@@ -581,16 +570,15 @@ XLA_TEST_F(TupleHloTest,
tensorflow::Env::Default()->StartThread(
tensorflow::ThreadOptions(), "execute_thread", [&] {
TF_EXPECT_OK(Execute(std::move(module),
- {param0.get(), param1.get(), param1.get(),
- param0.get(), param4.get()})
+ {&param0, &param1, &param1, &param0, &param4})
.status());
}));
auto expected =
LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({2, 3}));
- auto literal = Literal::CreateFromShape(expected->shape());
+ auto literal = Literal::CreateFromShape(expected.shape());
TF_EXPECT_OK(backend().transfer_manager()->TransferLiteralFromOutfeed(
- backend().default_stream_executor(), expected->shape(), *literal));
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *literal));
+ backend().default_stream_executor(), expected.shape(), literal));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, literal));
}
} // namespace
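
tuple_test.cc switches between the two tuple builders that remain after the change: MakeTupleFromSlices consumes Literal values directly, while MakeTuple still takes const Literal* elements. A short sketch of both, mirroring the calls above (names are illustrative):

    xla::Literal scalar = xla::LiteralUtil::CreateR0<float>(42.0f);
    xla::Literal vector = xla::LiteralUtil::CreateR1<float>({1.0f, 2.0f});

    // Value-based builder: no .get() chains on temporaries.
    xla::Literal by_value = xla::LiteralUtil::MakeTupleFromSlices(
        {xla::LiteralUtil::CreateR0<float>(42.0f),
         xla::LiteralUtil::CreateR1<float>({1.0f, 2.0f})});

    // Pointer-based builder: existing literals are passed by address.
    xla::Literal by_pointer = xla::LiteralUtil::MakeTuple({&scalar, &vector});
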
diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc
index 8f80a9f3e4..4fbd7f2fb1 100644
--- a/tensorflow/compiler/xla/tests/unary_op_test.cc
+++ b/tensorflow/compiler/xla/tests/unary_op_test.cc
@@ -100,9 +100,9 @@ void UnaryOpTest::AbsTestHelper<complex64>() {
{-inf<float>(), 0}});
Abs(arg);
- std::unique_ptr<Literal> expected =
+ Literal expected =
LiteralUtil::CreateR1<float>({2, 25, 0, 0.5, inf<float>(), inf<float>()});
- ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
+ ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f));
}
template <>
@@ -113,9 +113,9 @@ void UnaryOpTest::SignTestHelper<complex64>() {
{{-2, 0}, {0, 25}, {0, 0}, {static_cast<float>(-0.0), 0}, {-1, 1}});
Sign(arg);
- std::unique_ptr<Literal> expected = LiteralUtil::CreateR1<complex64>(
+ Literal expected = LiteralUtil::CreateR1<complex64>(
{{-1, 0}, {0, 1}, {0, 0}, {0, 0}, {-std::sqrt(0.5f), std::sqrt(0.5f)}});
- ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
+ ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f));
}
template <>
@@ -127,9 +127,8 @@ void UnaryOpTest::SignAbsTestHelper<complex64>() {
auto abs = Abs(arg);
Sub(Mul(sign, ConvertElementType(abs, C64)), arg);
- std::unique_ptr<Literal> expected =
- LiteralUtil::CreateR1<complex64>({0, 0, 0, 0});
- ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
+ Literal expected = LiteralUtil::CreateR1<complex64>({0, 0, 0, 0});
+ ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f));
}
XLA_TEST_F(UnaryOpTest, AbsTestR1Size0) {
@@ -172,9 +171,8 @@ XLA_TEST_F(UnaryOpTest, SignTestR0) {
Add(sgnc, ConvertElementType(
Add(Add(sgnf0, sgnf), ConvertElementType(sgni, F32)), C64));
- std::unique_ptr<Literal> expected =
- LiteralUtil::CreateR0<complex64>({-2.6f, 0.8f});
- ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
+ Literal expected = LiteralUtil::CreateR0<complex64>({-2.6f, 0.8f});
+ ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f));
}
XLA_TEST_F(UnaryOpTest, SignTestR1) {
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc
index 1bdf1867b9..7abd8651d5 100644
--- a/tensorflow/compiler/xla/tests/while_test.cc
+++ b/tensorflow/compiler/xla/tests/while_test.cc
@@ -348,9 +348,9 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) {
// have all reached 2.0.
auto expected_data =
LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f});
- auto expected = LiteralUtil::MakeTuple({expected_data.get()});
- VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
+ auto expected = LiteralUtil::MakeTuple({&expected_data});
+ VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
}
TEST_F(WhileTest, WhileWithPermutationAndTupleResult) {
@@ -401,11 +401,10 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) {
auto expected_w1 = LiteralUtil::CreateR1<float>({1.0f, 1.0f, 1.0f});
auto expected_w2 = LiteralUtil::CreateR1<float>({2.0f, 2.0f, 2.0f});
auto expected_w3 = LiteralUtil::CreateR1<float>({3.0f, 3.0f, 3.0f});
- auto expected =
- LiteralUtil::MakeTuple({expected_counter.get(), expected_w2.get(),
- expected_w3.get(), expected_w1.get()});
- VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
+ auto expected = LiteralUtil::MakeTuple(
+ {&expected_counter, &expected_w2, &expected_w3, &expected_w1});
+ VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
}
TEST_F(WhileTest, WhileWithPermutationAndVectorResult) {
@@ -510,10 +509,9 @@ TEST_F(WhileTest, WhileWithTupleResult) {
auto expected_counter = LiteralUtil::CreateR0<int32>(5);
auto expected_data = LiteralUtil::CreateR1<float>(
{5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f});
- auto expected =
- LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()});
- VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
+ auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data});
+ VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
}
TEST_F(WhileTest, WhileWithPredicateTupleResult) {
@@ -557,9 +555,9 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) {
auto expected_counter = LiteralUtil::CreateR0<int32>(5);
auto expected_predicate = LiteralUtil::CreateR0<bool>(true);
- auto expected = LiteralUtil::MakeTuple(
- {expected_counter.get(), expected_predicate.get()});
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0));
+ auto expected =
+ LiteralUtil::MakeTuple({&expected_counter, &expected_predicate});
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0));
}
TEST_F(WhileTest, WhileWithTupleConstantScalarResult) {
@@ -602,10 +600,9 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) {
auto expected_counter = LiteralUtil::CreateR0<int32>(5);
auto expected_data = LiteralUtil::CreateR0<int32>(7);
- auto expected =
- LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()});
- VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
+ auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data});
+ VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
}
// Tests two while nodes when the result type T is a Tuple and the second
@@ -886,10 +883,9 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) {
auto expected_counter = LiteralUtil::CreateR0<int32>(5);
auto expected_data = LiteralUtil::CreateR1<float>(
{1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f});
- auto expected =
- LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()});
- VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
+ auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data});
+ VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
}
// Tests a while node when the result type T is a vector of S32.
@@ -977,11 +973,11 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) {
auto expected_element = LiteralUtil::CreateR1<float>({1, 1});
auto expected =
- LiteralUtil::MakeTuple({expected_element.get(), expected_element.get()});
+ LiteralUtil::MakeTuple({&expected_element, &expected_element});
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> parameter_data,
- client_->TransferToServer(*LiteralUtil::CreateR1<float>({42, 42})));
- ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()},
+ client_->TransferToServer(LiteralUtil::CreateR1<float>({42, 42})));
+ ComputeAndCompareTuple(&outer, expected, {parameter_data.get()},
ErrorSpec(1e-6));
}
@@ -1005,7 +1001,7 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> parameter_data,
- client_->TransferToServer(*LiteralUtil::CreateR1<float>({42, 42})));
+ client_->TransferToServer(LiteralUtil::CreateR1<float>({42, 42})));
ComputeAndCompareR1<float>(&outer, {1.0f, 1.0f}, {parameter_data.get()},
ErrorSpec(1e-6));
}
@@ -1031,7 +1027,7 @@ TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> parameter_data,
- client_->TransferToServer(*LiteralUtil::CreateR0<float>(42)));
+ client_->TransferToServer(LiteralUtil::CreateR0<float>(42)));
ComputeAndCompareR0<float>(&outer, 43.0f, {parameter_data.get()},
ErrorSpec(1e-6));
}
@@ -1070,12 +1066,12 @@ TEST_F(WhileTest, WhileWithMixedTupleElements) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> parameter_data,
- client_->TransferToServer(*LiteralUtil::CreateR0<int32>(1)));
+ client_->TransferToServer(LiteralUtil::CreateR0<int32>(1)));
auto add1 = LiteralUtil::CreateR0<int32>(15);
auto add2 = LiteralUtil::CreateR0<int32>(16);
- auto expected = LiteralUtil::MakeTuple({add1.get(), add2.get()});
- ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()},
+ auto expected = LiteralUtil::MakeTuple({&add1, &add2});
+ ComputeAndCompareTuple(&outer, expected, {parameter_data.get()},
ErrorSpec(1e-6));
}
@@ -1228,7 +1224,7 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) {
GetTupleElement(while_instruction, 3);
TF_ASSERT_OK_AND_ASSIGN(
- auto param_value, client_->TransferToServer(*LiteralUtil::CreateR2<float>(
+ auto param_value, client_->TransferToServer(LiteralUtil::CreateR2<float>(
{{1.0, 2.0}, {-1.0, -2.0}})));
ComputeAndCompareR2<float>(
@@ -1258,9 +1254,9 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileInfeedCondition)) {
XlaBuilder builder(TestName());
While(condition, body, ConstantR0<int32>(&builder, 0));
- TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(true)));
- TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(true)));
- TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(false)));
+ TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0<bool>(true)));
+ TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0<bool>(true)));
+ TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0<bool>(false)));
ComputeAndCompareR0<int32>(&builder, 2, {});
}
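
The while tests also show the client-side effect of the change: TransferToServer and TransferToInfeed now take a const Literal&, so temporaries can be passed without dereferencing. A hedged sketch using the client_ member from the test fixture:

    // Upload an argument literal; the returned GlobalData handle is unchanged.
    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<xla::GlobalData> parameter_data,
        client_->TransferToServer(xla::LiteralUtil::CreateR0<float>(42)));
    // Feed a value straight to the infeed queue.
    TF_ASSERT_OK(client_->TransferToInfeed(xla::LiteralUtil::CreateR0<bool>(true)));
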
diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
index 7fd42944de..db5a824de0 100644
--- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
@@ -144,14 +144,14 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client,
transfer_manager->AllocateScopedShapedBuffer(
lhs_arg_shape, allocator, backend->default_device_ordinal()));
TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice(
- stream_ptr.get(), *Literal::CreateFromShape(lhs_arg_shape), lhs_arg));
+ stream_ptr.get(), Literal::CreateFromShape(lhs_arg_shape), lhs_arg));
TF_ASSERT_OK_AND_ASSIGN(
ScopedShapedBuffer rhs_arg,
transfer_manager->AllocateScopedShapedBuffer(
rhs_arg_shape, allocator, backend->default_device_ordinal()));
TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice(
- stream_ptr.get(), *Literal::CreateFromShape(rhs_arg_shape), rhs_arg));
+ stream_ptr.get(), Literal::CreateFromShape(rhs_arg_shape), rhs_arg));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<LocalExecutable> local_executable,
diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc
index 442e66321e..cdde88c135 100644
--- a/tensorflow/compiler/xla/text_literal_reader.cc
+++ b/tensorflow/compiler/xla/text_literal_reader.cc
@@ -39,8 +39,7 @@ limitations under the License.
namespace xla {
-StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadPath(
- absl::string_view path) {
+StatusOr<Literal> TextLiteralReader::ReadPath(absl::string_view path) {
CHECK(!absl::EndsWith(path, ".gz"))
<< "TextLiteralReader no longer supports reading .gz files";
std::unique_ptr<tensorflow::RandomAccessFile> file;
@@ -57,7 +56,7 @@ StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadPath(
TextLiteralReader::TextLiteralReader(tensorflow::RandomAccessFile* file)
: file_(file) {}
-StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadAllLines() {
+StatusOr<Literal> TextLiteralReader::ReadAllLines() {
tensorflow::io::RandomAccessInputStream stream(file_.get());
tensorflow::io::BufferedInputStream buf(&stream, 65536);
string shape_string;
@@ -74,9 +73,9 @@ StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadAllLines() {
ShapeUtil::HumanString(shape));
}
- auto result = absl::make_unique<Literal>(shape);
+ Literal result(shape);
const float fill = std::numeric_limits<float>::quiet_NaN();
- result->PopulateWithValue<float>(fill);
+ result.PopulateWithValue<float>(fill);
std::vector<absl::string_view> pieces;
std::vector<absl::string_view> coordinates;
std::vector<int64> coordinate_values;
@@ -116,7 +115,7 @@ StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadAllLines() {
"\"%s\"",
shape.dimensions_size(), coordinate_values.size(), line);
}
- result->Set<float>(coordinate_values, value);
+ result.Set<float>(coordinate_values, value);
}
return std::move(result);
}
diff --git a/tensorflow/compiler/xla/text_literal_reader.h b/tensorflow/compiler/xla/text_literal_reader.h
index b265640802..c40b43279f 100644
--- a/tensorflow/compiler/xla/text_literal_reader.h
+++ b/tensorflow/compiler/xla/text_literal_reader.h
@@ -41,7 +41,7 @@ class TextLiteralReader {
public:
// See class comment -- reads a file in its entirety (there must be only one
// literal in the text file path provided).
- static StatusOr<std::unique_ptr<Literal>> ReadPath(absl::string_view path);
+ static StatusOr<Literal> ReadPath(absl::string_view path);
private:
// Ownership of file is transferred.
@@ -49,7 +49,7 @@ class TextLiteralReader {
// Parses a shape string on the first line, followed by lines of values to the
// end of the file.
- StatusOr<std::unique_ptr<Literal>> ReadAllLines();
+ StatusOr<Literal> ReadAllLines();
// Owns the file being read
std::unique_ptr<tensorflow::RandomAccessFile> file_;
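
With the header above, reading a text literal no longer goes through a unique_ptr. A usage sketch (the path is a placeholder):

    xla::Literal literal =
        xla::TextLiteralReader::ReadPath("/tmp/example.txt").ConsumeValueOrDie();
    LOG(INFO) << "shape: " << xla::ShapeUtil::HumanString(literal.shape());
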
diff --git a/tensorflow/compiler/xla/text_literal_reader_test.cc b/tensorflow/compiler/xla/text_literal_reader_test.cc
index 92f9b4f9f0..1fab4e3a08 100644
--- a/tensorflow/compiler/xla/text_literal_reader_test.cc
+++ b/tensorflow/compiler/xla/text_literal_reader_test.cc
@@ -42,16 +42,15 @@ TEST(TextLiteralReaderTest, ReadsR3File) {
tensorflow::WriteStringToFile(tensorflow::Env::Default(), fname, contents)
.ok());
- std::unique_ptr<Literal> literal =
- TextLiteralReader::ReadPath(fname).ConsumeValueOrDie();
+ Literal literal = TextLiteralReader::ReadPath(fname).ConsumeValueOrDie();
EXPECT_TRUE(
- ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {1, 2, 3}), literal->shape()));
- EXPECT_EQ(42.5, literal->Get<float>({0, 0, 0}));
- EXPECT_EQ(43.5, literal->Get<float>({0, 0, 1}));
- EXPECT_EQ(44.5, literal->Get<float>({0, 0, 2}));
- EXPECT_EQ(45.5, literal->Get<float>({0, 1, 0}));
- EXPECT_EQ(46.5, literal->Get<float>({0, 1, 1}));
- EXPECT_EQ(47.5, literal->Get<float>({0, 1, 2}));
+ ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {1, 2, 3}), literal.shape()));
+ EXPECT_EQ(42.5, literal.Get<float>({0, 0, 0}));
+ EXPECT_EQ(43.5, literal.Get<float>({0, 0, 1}));
+ EXPECT_EQ(44.5, literal.Get<float>({0, 0, 2}));
+ EXPECT_EQ(45.5, literal.Get<float>({0, 1, 0}));
+ EXPECT_EQ(46.5, literal.Get<float>({0, 1, 1}));
+ EXPECT_EQ(47.5, literal.Get<float>({0, 1, 2}));
}
} // namespace
diff --git a/tensorflow/compiler/xla/text_literal_writer_test.cc b/tensorflow/compiler/xla/text_literal_writer_test.cc
index 4ea02faffc..5cbaf2fcc1 100644
--- a/tensorflow/compiler/xla/text_literal_writer_test.cc
+++ b/tensorflow/compiler/xla/text_literal_writer_test.cc
@@ -37,7 +37,7 @@ TEST(TextLiteralWriterTest, WritesFloatLiteral) {
});
string path =
tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), "/whatever");
- ASSERT_IS_OK(TextLiteralWriter::WriteToPath(*literal, path));
+ ASSERT_IS_OK(TextLiteralWriter::WriteToPath(literal, path));
string contents;
TF_CHECK_OK(tensorflow::ReadFileToString(tensorflow::Env::Default(), path,
&contents));
diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc
index ba814af476..0c41f227b3 100644
--- a/tensorflow/compiler/xla/tools/replay_computation.cc
+++ b/tensorflow/compiler/xla/tools/replay_computation.cc
@@ -121,11 +121,10 @@ StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
}
} else { // use recorded data if available
for (const auto& proto : module.arguments()) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Literal> literal,
- Literal::CreateFromProto(proto));
+ TF_ASSIGN_OR_RETURN(Literal literal, Literal::CreateFromProto(proto));
TF_ASSIGN_OR_RETURN(
ScopedShapedBuffer data,
- client->LiteralToShapedBuffer(*literal, /*device_ordinal=*/0));
+ client->LiteralToShapedBuffer(literal, /*device_ordinal=*/0));
scoped_shaped_buffer_arguments.push_back(std::move(data));
}
for (const auto& argument : scoped_shaped_buffer_arguments) {
@@ -161,12 +160,12 @@ StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
// --generate_fake_infeed is passed and there exists an infeed operation in
// the HloSnapshot.
absl::optional<tensorflow::thread::ThreadPool> pool;
- std::unique_ptr<Literal> data;
+ Literal data;
if (provide_infeed) {
data = std::move(MakeFakeLiteral(infeed_shape)).ValueOrDie();
}
auto transfer_infeed = [&data, client]() {
- TF_CHECK_OK(client->TransferToInfeed(*data));
+ TF_CHECK_OK(client->TransferToInfeed(data));
};
if (provide_infeed) {
pool.emplace(tensorflow::Env::Default(), "infeed",
@@ -214,9 +213,9 @@ StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
<< "s: " << module.hlo().hlo_module().name();
}
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result_literal,
+ TF_ASSIGN_OR_RETURN(Literal result_literal,
client->ShapedBufferToLiteral(*result));
- return std::move(*result_literal);
+ return result_literal;
}
StatusOr<HloSnapshot> ParseInputFile(const string& filename,
@@ -305,11 +304,11 @@ int RealMain(absl::Span<char* const> args, const Options& opts) {
result.ToString().c_str());
auto& snapshot = snapshots[i];
if (snapshot.has_result()) {
- std::unique_ptr<Literal> literal =
+ Literal literal =
Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie();
fprintf(stdout, "was %s:%s\n",
ShapeUtil::HumanString(snapshot.result().shape()).c_str(),
- literal->ToString().c_str());
+ literal.ToString().c_str());
}
}
}
diff --git a/tensorflow/compiler/xla/tools/show_literal.cc b/tensorflow/compiler/xla/tools/show_literal.cc
index 51909190a3..4f8852f8c1 100644
--- a/tensorflow/compiler/xla/tools/show_literal.cc
+++ b/tensorflow/compiler/xla/tools/show_literal.cc
@@ -40,8 +40,8 @@ int main(int argc, char **argv) {
xla::LiteralProto literal_proto;
TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), argv[1],
&literal_proto));
- std::unique_ptr<xla::Literal> literal =
+ xla::Literal literal =
xla::Literal::CreateFromProto(literal_proto).ConsumeValueOrDie();
LOG(INFO) << "literal: " << literal_proto.ShortDebugString();
- fprintf(stderr, "%s\n", literal->ToString().c_str());
+ fprintf(stderr, "%s\n", literal.ToString().c_str());
}
diff --git a/tensorflow/compiler/xla/tools/show_text_literal.cc b/tensorflow/compiler/xla/tools/show_text_literal.cc
index 48c8374811..4b5c276bdf 100644
--- a/tensorflow/compiler/xla/tools/show_text_literal.cc
+++ b/tensorflow/compiler/xla/tools/show_text_literal.cc
@@ -36,16 +36,16 @@ int main(int argc, char **argv) {
LOG(QFATAL) << "Usage: " << argv[0] << " <path-to-serialized-literal-text>";
}
- std::unique_ptr<xla::Literal> literal =
+ xla::Literal literal =
xla::TextLiteralReader::ReadPath(argv[1]).ConsumeValueOrDie();
- LOG(INFO) << "literal: " << *literal;
- fprintf(stderr, "%s\n", literal->ToString().c_str());
- if (literal->shape().element_type() == xla::F32) {
- float min = *std::min_element(literal->data<float>().begin(),
- literal->data<float>().end());
- float max = *std::max_element(literal->data<float>().begin(),
- literal->data<float>().end());
+ LOG(INFO) << "literal: " << literal;
+ fprintf(stderr, "%s\n", literal.ToString().c_str());
+ if (literal.shape().element_type() == xla::F32) {
+ float min = *std::min_element(literal.data<float>().begin(),
+ literal.data<float>().end());
+ float max = *std::max_element(literal.data<float>().begin(),
+ literal.data<float>().end());
fprintf(stderr, "min: %a=%f\n", min, min);
fprintf(stderr, "max: %a=%f\n", max, max);
}
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index dd329f1181..73b3589dbf 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -351,6 +351,7 @@ message DeviceAssignmentProto {
message LiteralProto {
Shape shape = 1;
repeated bool preds = 2;
+ bytes s8s = 15;
bytes u8s = 3;
repeated int32 s32s = 4;
repeated int64 s64s = 5;
@@ -364,7 +365,7 @@ message LiteralProto {
bytes f16s = 11;
bytes bf16s = 13;
repeated int64 sparse_indices = 14;
- // Next = 15
+ // Next = 16
}
message WindowDimension {
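
The proto change reserves field 15 (bytes s8s) for signed 8-bit element data and bumps the "Next" marker accordingly. The serialization round trip itself is unchanged; a sketch with an int32 literal (which still travels in s32s), using the Literal/LiteralProto calls shown elsewhere in this patch:

    xla::Literal literal = xla::LiteralUtil::CreateR1<xla::int32>({1, 2, 3});
    xla::LiteralProto proto = literal.ToProto();
    xla::Literal restored =
        xla::Literal::CreateFromProto(proto).ConsumeValueOrDie();
    CHECK(literal == restored);
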
diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h
index 478c9663a7..54b06558ad 100644
--- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h
+++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h
@@ -49,7 +49,7 @@ class XRTStateHelpers {
// TF_ASSIGN_OR_RETURN macro, which doesn't work within the body of an
// OpKernel::Compute method.
static Status MakeLiteral(const xla::LiteralProto& proto,
- std::unique_ptr<xla::Literal>* literal) {
+ xla::Literal* literal) {
TF_ASSIGN_OR_RETURN(*literal, xla::Literal::CreateFromProto(proto));
return Status::OK();
}
@@ -173,7 +173,7 @@ class XRTAllocateOp : public OpKernel {
errors::InvalidArgument(
"Unable to parse allocation input to XLAAllocation"));
- std::unique_ptr<xla::Literal> literal;
+ xla::Literal literal;
OP_REQUIRES_OK(
ctx, XRTStateHelpers::MakeLiteral(allocation_proto.value(), &literal));
@@ -189,7 +189,7 @@ class XRTAllocateOp : public OpKernel {
XRTTupleAllocation* allocation;
OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateAndTransfer(
- *literal, device_ref.backend(),
+ literal, device_ref.backend(),
device_ref.device_ordinal(), &allocation));
// Intern takes ownership of our reference to allocation.
@@ -381,11 +381,11 @@ class XRTReadLiteralOp : public OpKernel {
OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(
ctx, allocation->device_ordinal(), &device_ref));
- std::unique_ptr<xla::Literal> literal;
+ xla::Literal literal;
OP_REQUIRES_OK(
ctx, allocation->ToLiteral(device_ref.backend(),
device_ref.device_ordinal(), &literal));
- xla::LiteralProto literal_proto = literal->ToProto();
+ xla::LiteralProto literal_proto = literal.ToProto();
Tensor output(DT_STRING, TensorShape({}));
literal_proto.SerializeToString(&output.scalar<string>()());
diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc
index 5b8516bf1d..2952feb16a 100644
--- a/tensorflow/compiler/xrt/tests/raw_api_test.cc
+++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc
@@ -52,44 +52,44 @@ string DeviceFromFlag() {
xla::LiteralProto TwoElementTuple() {
auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
- auto tuple = xla::LiteralUtil::MakeTuple({array.get(), matrix.get()});
- return tuple->ToProto();
+ auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix});
+ return tuple.ToProto();
}
xla::LiteralProto ScalarLiteral() {
auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
- return scalar->ToProto();
+ return scalar.ToProto();
}
xla::LiteralProto NestedTuple() {
auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
- auto tuple = xla::LiteralUtil::MakeTuple({array.get(), matrix.get()});
+ auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix});
auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
- auto nested = xla::LiteralUtil::MakeTuple({tuple.get(), scalar.get()});
- return nested->ToProto();
+ auto nested = xla::LiteralUtil::MakeTuple({&tuple, &scalar});
+ return nested.ToProto();
}
xla::LiteralProto MakeTuple0() {
auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
- auto tuple = xla::LiteralUtil::MakeTuple({array.get(), matrix.get()});
- auto nested0 = xla::LiteralUtil::MakeTuple({scalar.get(), tuple.get()});
- auto nested1 = xla::LiteralUtil::MakeTuple({scalar.get(), nested0.get()});
- return nested1->ToProto();
+ auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix});
+ auto nested0 = xla::LiteralUtil::MakeTuple({&scalar, &tuple});
+ auto nested1 = xla::LiteralUtil::MakeTuple({&scalar, &nested0});
+ return nested1.ToProto();
}
-xla::LiteralProto FloatVector(gtl::ArraySlice<float> v) {
+xla::LiteralProto FloatVector(absl::Span<const float> v) {
auto array = xla::LiteralUtil::CreateR1<float>(v);
- return array->ToProto();
+ return array.ToProto();
}
bool CompareLiteralProtos(const xla::LiteralProto& a,
const xla::LiteralProto& b) {
auto l_a = xla::Literal::CreateFromProto(a).ValueOrDie();
auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie();
- bool equal = *l_a == *l_b;
+ bool equal = l_a == l_b;
if (!equal) {
LOG(INFO) << "LiteralProtos don't match " << a.DebugString()
<< " != " << b.DebugString();
@@ -100,7 +100,7 @@ bool CompareLiteralProtos(const xla::LiteralProto& a,
bool CompareLiteralToLiteralProto(const xla::Literal& a,
const xla::LiteralProto& b) {
auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie();
- bool equal = a == *l_b;
+ bool equal = a == l_b;
if (!equal) {
LOG(INFO) << "Literal and LiteralProto don't match "
<< a.ToProto().DebugString() << " != " << b.DebugString();
@@ -211,7 +211,7 @@ TEST(RawApiTest, SubBuffer) {
TF_EXPECT_OK(session.Run({value_0, value_1, value_00}, &outputs));
auto base_literal = xla::Literal::CreateFromProto(alloc.value()).ValueOrDie();
- auto base_elements = base_literal->DecomposeTuple();
+ auto base_elements = base_literal.DecomposeTuple();
auto nested_0_elements = base_elements[0].Clone().DecomposeTuple();
xla::LiteralProto response_0;
EXPECT_TRUE(response_0.ParseFromString(outputs[0].scalar<string>()()));
@@ -343,7 +343,7 @@ TEST(RawApiTest, CompileAndExecute) {
EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
auto expected = xla::LiteralUtil::CreateR1<float>({27.0f, 21.0f});
- EXPECT_TRUE(CompareLiteralToLiteralProto(*expected, response));
+ EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
TEST(RawApiTest, CompileAndExecuteReturnTuple) {
@@ -392,8 +392,8 @@ TEST(RawApiTest, CompileAndExecuteReturnTuple) {
EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
auto sum = xla::LiteralUtil::CreateR1<float>({9.0f, 7.0f});
- auto expected = xla::LiteralUtil::MakeTuple({sum.get()});
- EXPECT_TRUE(CompareLiteralToLiteralProto(*expected, response));
+ auto expected = xla::LiteralUtil::MakeTuple({&sum});
+ EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
} // namespace
diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc
index 2c3b07da58..d05a1e7dcb 100644
--- a/tensorflow/compiler/xrt/xrt_state.cc
+++ b/tensorflow/compiler/xrt/xrt_state.cc
@@ -174,7 +174,7 @@ XRTTupleAllocation::~XRTTupleAllocation() {
}
Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal,
- std::unique_ptr<xla::Literal>* literal) {
+ xla::Literal* literal) {
auto transfer_manager = backend->transfer_manager();
TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));
TF_ASSIGN_OR_RETURN(*literal, transfer_manager->TransferLiteralFromDevice(
diff --git a/tensorflow/compiler/xrt/xrt_state.h b/tensorflow/compiler/xrt/xrt_state.h
index 42705688dd..73b5584e38 100644
--- a/tensorflow/compiler/xrt/xrt_state.h
+++ b/tensorflow/compiler/xrt/xrt_state.h
@@ -135,7 +135,7 @@ class XRTTupleAllocation : public ResourceBase {
// Copies the allocation from device to host and returns it in literal.
Status ToLiteral(xla::Backend* backend, int device_ordinal,
- std::unique_ptr<xla::Literal>* literal);
+ xla::Literal* literal);
// True if none of the buffers in the allocation are aliased by any other live
// handle.
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 798f499870..d98a24994c 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -166,7 +166,9 @@ cc_library(
"//tensorflow/contrib/kinesis:dataset_kernels",
],
"//conditions:default": [],
- }),
+ }) + if_not_windows([
+ "//tensorflow/contrib/tensorrt:trt_engine_op_kernel",
+ ]),
)
cc_library(
@@ -203,5 +205,7 @@ cc_library(
"//tensorflow/contrib/kinesis:dataset_ops_op_lib",
],
"//conditions:default": [],
- }),
+ }) + if_not_windows([
+ "//tensorflow/contrib/tensorrt:trt_engine_op_op_lib",
+ ]),
)
diff --git a/tensorflow/contrib/autograph/BUILD b/tensorflow/contrib/autograph/BUILD
index ad700ac4a0..e37ad7a758 100644
--- a/tensorflow/contrib/autograph/BUILD
+++ b/tensorflow/contrib/autograph/BUILD
@@ -21,11 +21,9 @@ py_library(
],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
+ # This module is kept for backward compatibility only. To depend on AutoGraph,
+ # use //third_party/tensorflow/python/autograph instead.
deps = [
- "//tensorflow/contrib/autograph/impl",
- "//tensorflow/contrib/autograph/lang",
- "//tensorflow/contrib/autograph/pyct",
- "//tensorflow/contrib/autograph/utils",
- "//tensorflow/python:util",
+ "//tensorflow/python/autograph",
],
)
diff --git a/tensorflow/contrib/autograph/README.md b/tensorflow/contrib/autograph/README.md
index cc54da4daa..6ea2db72c4 100644
--- a/tensorflow/contrib/autograph/README.md
+++ b/tensorflow/contrib/autograph/README.md
@@ -1,5 +1,12 @@
# AutoGraph
+**NOTE: As tensorflow.contrib is being
+[deprecated](https://github.com/tensorflow/community/pull/18), AutoGraph is
+moving into TensorFlow core.**
+
+**The new code location is `tensorflow/python/autograph`.**
+
+
IMPORTANT: AutoGraph is beta software, and under active development. Expect rough edges and bugs, but if you try it, we appreciate early feedback! We'd also love contributions ([please see our contributing guidelines](CONTRIBUTING.md) and our [style guide](STYLE_GUIDE.md)).
AutoGraph is a Python to TensorFlow compiler.
diff --git a/tensorflow/contrib/autograph/__init__.py b/tensorflow/contrib/autograph/__init__.py
index 26e7a4a4d3..137bc59202 100644
--- a/tensorflow/contrib/autograph/__init__.py
+++ b/tensorflow/contrib/autograph/__init__.py
@@ -12,57 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Autograph compiles Python code into equivalent TensorFlow code.
+"""This is the legacy module for AutoGraph, kept for backward compatibility.
-Equivalent here means that they have the same effect when executed.
+New users should instead use `tensorflow.python.autograph`.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-# TODO(mdan): Bring only the relevant symbols to the top level.
-from tensorflow.contrib.autograph import operators
-from tensorflow.contrib.autograph import utils
-from tensorflow.contrib.autograph.core.errors import GraphConstructionError
-from tensorflow.contrib.autograph.core.errors import TfRuntimeError
-from tensorflow.contrib.autograph.core.errors import improved_errors
-from tensorflow.contrib.autograph.impl.api import RunMode
-from tensorflow.contrib.autograph.impl.api import convert
-from tensorflow.contrib.autograph.impl.api import converted_call
-from tensorflow.contrib.autograph.impl.api import do_not_convert
-from tensorflow.contrib.autograph.impl.api import to_code
-from tensorflow.contrib.autograph.impl.api import to_graph
-from tensorflow.contrib.autograph.lang.directives import set_element_type
-from tensorflow.contrib.autograph.lang.directives import set_loop_options
-from tensorflow.contrib.autograph.lang.special_functions import stack
-from tensorflow.contrib.autograph.lang.special_functions import tensor_list
-from tensorflow.contrib.autograph.pyct.transformer import AutographParseError
-from tensorflow.python.util.all_util import remove_undocumented
-
-_allowed_symbols = [
- # Main API
- 'RunMode',
- 'convert',
- 'converted_call',
- 'do_not_convert',
- 'to_code',
- 'to_graph',
- # Overloaded operators
- 'operators',
- # Errors
- 'improved_errors',
- 'GraphConstructionError',
- 'TfRuntimeError',
- # Python language "extensions"
- 'set_element_type',
- 'set_loop_options',
- 'stack',
- 'tensor_list',
- # Exceptions
- 'AutographParseError',
- # Utilities: to be removed
- 'utils',
-]
-
-remove_undocumented(__name__, _allowed_symbols)
+from tensorflow.python.autograph import * # pylint:disable=wildcard-import
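For illustration only (not part of the patch): a minimal sketch of how code that still imports the legacy `tensorflow.contrib.autograph` path keeps working through the wildcard shim above, assuming the core `tensorflow.python.autograph` package re-exports `to_graph` at this point in the tree. The `collatz_steps` function is a made-up example.

from tensorflow.contrib import autograph as ag  # legacy path, now a shim


def collatz_steps(n):
  # Plain-Python control flow that AutoGraph can convert to graph ops.
  steps = 0
  while n > 1:
    if n % 2 == 0:
      n = n // 2
    else:
      n = 3 * n + 1
    steps += 1
  return steps


# `to_graph` is forwarded through the shim and returns a graph-friendly
# version of the function.
graph_collatz = ag.to_graph(collatz_steps)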
diff --git a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc
index 1375fddf2b..606da663dc 100644
--- a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc
@@ -296,8 +296,9 @@ class QuantileAccumulatorAddSummariesOp : public OpKernel {
int64 start, int64 end) {
for (int resource_handle_idx = start; resource_handle_idx < end;
++resource_handle_idx) {
- ResourceHandle handle = resource_handle_list[resource_handle_idx]
- .flat<ResourceHandle>()(0);
+ const ResourceHandle& handle =
+ resource_handle_list[resource_handle_idx]
+ .flat<ResourceHandle>()(0);
QuantileStreamResource* streams_resource;
// Create a reference to the underlying resource using the handle.
OP_REQUIRES_OK(context,
@@ -709,8 +710,9 @@ class QuantileAccumulatorGetBucketsOp : public OpKernel {
&buckets_list, stamp_token](int64 start, int64 end) {
for (int resource_handle_idx = start; resource_handle_idx < end;
++resource_handle_idx) {
- ResourceHandle handle = resource_handle_list[resource_handle_idx]
- .flat<ResourceHandle>()(0);
+ const ResourceHandle& handle =
+ resource_handle_list[resource_handle_idx]
+ .flat<ResourceHandle>()(0);
QuantileStreamResource* streams_resource;
OP_REQUIRES_OK(context,
LookupResource(context, handle, &streams_resource));
diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
index 3b28ed77f3..51e0c2e431 100644
--- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
@@ -862,6 +862,15 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
auto* equality_split = split_info.mutable_split_node()
->mutable_categorical_id_binary_split();
equality_split->set_feature_column(state->feature_column_group_id());
+ CHECK(feature_ids(best_feature_idx, 0) != bias_feature_id)
+ << "Unexpected feature ID selected. "
+ << "Start feature ID: [" << start_index << "] "
+ << feature_ids(start_index, 0) << ", " << feature_ids(start_index, 1)
+ << "\nBest feature ID: [" << best_feature_idx << "] "
+ << feature_ids(best_feature_idx, 0) << ", "
+ << feature_ids(best_feature_idx, 1)
+          << "\nPartition IDs: " << partition_ids(start_index) << " "
+ << partition_ids(best_feature_idx);
equality_split->set_feature_id(feature_ids(best_feature_idx, 0));
auto* left_child = split_info.mutable_left_child();
auto* right_child = split_info.mutable_right_child();
diff --git a/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc b/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc
index 90a0655201..e446c411a8 100644
--- a/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc
@@ -448,8 +448,9 @@ class StatsAccumulatorScalarAddOp : public OpKernel {
stamp_token](int64 start, int64 end) {
for (int resource_handle_idx = start; resource_handle_idx < end;
++resource_handle_idx) {
- ResourceHandle handle = resource_handle_list[resource_handle_idx]
- .flat<ResourceHandle>()(0);
+ const ResourceHandle& handle =
+ resource_handle_list[resource_handle_idx]
+ .flat<ResourceHandle>()(0);
StatsAccumulatorScalarResource* accumulator_resource;
OP_REQUIRES_OK(context, LookupResource(context, handle,
@@ -512,8 +513,9 @@ class StatsAccumulatorTensorAddOp : public OpKernel {
stamp_token](int64 start, int64 end) {
for (int resource_handle_idx = start; resource_handle_idx < end;
++resource_handle_idx) {
- ResourceHandle handle = resource_handle_list[resource_handle_idx]
- .flat<ResourceHandle>()(0);
+ const ResourceHandle& handle =
+ resource_handle_list[resource_handle_idx]
+ .flat<ResourceHandle>()(0);
StatsAccumulatorTensorResource* accumulator_resource;
OP_REQUIRES_OK(context, LookupResource(context, handle,
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py
index 35d727482b..4da25298cb 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py
@@ -29,7 +29,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
-_BIAS_FEATURE_ID = -1
+_BIAS_FEATURE_ID = int(dtypes.int64.min)
class EqualitySplitHandler(base_split_handler.BaseSplitHandler):
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
index d9f03c3840..94ea7bc2eb 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
@@ -47,7 +47,7 @@ def get_empty_tensors(gradient_shape, hessian_shape):
class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
def testGenerateFeatureSplitCandidates(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Feature ID |
# i0 | (0.2, 0.12) | 0 | 1,2 |
@@ -281,7 +281,7 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
gains[0], 0.00001)
def testGenerateFeatureSplitCandidatesSumReduction(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Feature ID |
# i0 | (0.2, 0.12) | 0 | 1,2 |
@@ -404,7 +404,7 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
self.assertEqual(1, split_node.feature_id)
def testGenerateFeatureSplitCandidatesMulticlass(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Batch size is 4, 2 gradients per each instance.
gradients = array_ops.constant(
[[0.2, 0.1], [-0.5, 0.2], [1.2, 3.4], [4.0, -3.5]], shape=[4, 2])
@@ -482,7 +482,7 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
self.assertEqual(1, split_node.feature_id)
def testEmpty(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0])
hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
partition_ids = [0, 0, 0, 1]
@@ -530,7 +530,7 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
self.assertEqual(len(splits), 0)
def testInactive(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0])
hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
partition_ids = [0, 0, 0, 1]
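For context only (not part of the patch): these tests replace `self.test_session()` with `self.cached_session()`, which caches and reuses one session for the duration of a test case. A minimal sketch of the pattern, with a made-up test name:

import tensorflow as tf


class CachedSessionExampleTest(tf.test.TestCase):

  def testAddition(self):
    # cached_session() reuses a single session within this test case.
    with self.cached_session() as sess:
      total = tf.add(2, 3)
      self.assertEqual(5, sess.run(total))


if __name__ == "__main__":
  tf.test.main()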
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
index 5532bd026a..74b0ea6989 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
@@ -50,7 +50,7 @@ def get_empty_tensors(gradient_shape, hessian_shape):
class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
def testGenerateFeatureSplitCandidates(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Dense Quantile |
# i0 | (0.2, 0.12) | 0 | 1 |
@@ -183,7 +183,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.52, split_node.threshold, 0.00001)
def testObliviousFeatureSplitGeneration(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Dense Quantile |
# i0 | (0.2, 0.12) | 1 | 3 |
@@ -320,7 +320,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertEqual(2, oblivious_split_info.children_parent_id[1])
def testGenerateFeatureSplitCandidatesLossUsesSumReduction(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Dense Quantile |
# i0 | (0.2, 0.12) | 0 | 1 |
@@ -458,7 +458,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.52, split_node.threshold, 0.00001)
def testGenerateFeatureSplitCandidatesMulticlassFullHessian(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dense_column = array_ops.constant([0.52, 0.52, 0.3, 0.52])
# Batch size is 4, 2 gradients per each instance.
gradients = array_ops.constant(
@@ -546,7 +546,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.3, split_node.threshold, 1e-6)
def testGenerateFeatureSplitCandidatesMulticlassDiagonalHessian(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dense_column = array_ops.constant([0.52, 0.52, 0.3, 0.52])
# Batch size is 4, 2 gradients per each instance.
gradients = array_ops.constant(
@@ -633,7 +633,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.3, split_node.threshold, 1e-6)
def testGenerateFeatureSplitCandidatesInactive(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Dense Quantile |
# i0 | (0.2, 0.12) | 0 | 1 |
@@ -708,7 +708,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertEqual(len(splits), 0)
def testGenerateFeatureSplitCandidatesWithTreeComplexity(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Dense Quantile |
# i0 | (0.2, 0.12) | 0 | 1 |
@@ -842,7 +842,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.52, split_node.threshold, 0.00001)
def testGenerateFeatureSplitCandidatesWithMinNodeWeight(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Dense Quantile |
# i0 | (0.2, 0.12) | 0 | 1 |
@@ -951,7 +951,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
def testGenerateFeatureSplitCandidates(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Sparse Quantile |
# i0 | (0.2, 0.12) | 0 | 1 |
@@ -1074,7 +1074,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.52, split_node.split.threshold)
def testGenerateFeatureSplitCandidatesLossUsesSumReduction(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Sparse Quantile |
# i0 | (0.2, 0.12) | 0 | 1 |
@@ -1207,7 +1207,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.52, split_node.split.threshold)
def testGenerateFeatureSplitCandidatesMulticlassFullHessian(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Batch is 4, 2 classes
gradients = array_ops.constant([[0.2, 1.4], [-0.5, 0.1], [1.2, 3],
[4.0, -3]])
@@ -1302,7 +1302,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.52, split_node.split.threshold)
def testGenerateFeatureSplitCandidatesMulticlassDiagonalHessian(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Batch is 4, 2 classes
gradients = array_ops.constant([[0.2, 1.4], [-0.5, 0.1], [1.2, 3],
[4.0, -3]])
@@ -1397,7 +1397,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.52, split_node.split.threshold)
def testGenerateFeatureSplitCandidatesInactive(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Sparse Quantile |
# i0 | (0.2, 0.12) | 0 | 1 |
@@ -1475,7 +1475,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertEqual(len(splits), 0)
def testEmpty(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
indices = array_ops.constant([], dtype=dtypes.int64, shape=[0, 2])
# No values in this feature column in this mini-batch.
values = array_ops.constant([], dtype=dtypes.float32)
@@ -1545,7 +1545,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
def testEmptyBuckets(self):
"""Test that reproduces the case when quantile buckets were empty."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sparse_column = array_ops.sparse_placeholder(dtypes.float32)
# We have two batches - at first, a sparse feature is empty.
@@ -1638,7 +1638,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertEqual(len(splits), 0)
def testDegenerativeCase(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# One data example only, one leaf and thus one quantile bucket. The same
# situation arises when all examples have the same values. This case
# previously caused a failure.
diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py
index 4278a30ba9..46dfbdefeb 100644
--- a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py
+++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py
@@ -331,7 +331,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
self.assertAllEqual([[], []], dropout_info.eval())
def testObliviousEnsemble(self):
- with self.test_session():
+ with self.cached_session():
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
# Bias tree.
tree1 = tree_ensemble_config.trees.add()
@@ -1399,7 +1399,7 @@ class PartitionExamplesOpsTest(test_util.TensorFlowTestCase):
self.assertAllEqual([0, 0], result.eval())
def testObliviousTreeNonFinalized(self):
- with self.test_session():
+ with self.cached_session():
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
# Depth 3 tree.
tree1 = tree_ensemble_config.trees.add()
diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py
index b3e4c2e5f7..86fd5770a0 100644
--- a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py
+++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py
@@ -411,7 +411,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testGrowEmptyEnsembleObliviousCase(self):
"""Test growing an empty ensemble in the oblivious case."""
- with self.test_session() as session:
+ with self.cached_session() as session:
# Create empty ensemble.
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
tree_ensemble_handle = model_ops.tree_ensemble_variable(
@@ -1620,7 +1620,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testGrowEnsembleTreeLayerByLayerObliviousCase(self):
"""Test growing an existing ensemble with the last tree not finalized."""
- with self.test_session() as session:
+ with self.cached_session() as session:
# Create existing ensemble with one root split
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
text_format.Merge(
@@ -1810,7 +1810,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testGrowEnsembleWithEmptyNodesMiddleCase(self):
"""Test case: The middle existing leaves don't have examples."""
- with self.test_session() as session:
+ with self.cached_session() as session:
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
text_format.Merge(
"""
@@ -2071,7 +2071,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testGrowEnsembleWithEmptyNodesBorderCase(self):
"""Test case: The first and last existing leaves don't have examples."""
- with self.test_session() as session:
+ with self.cached_session() as session:
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
text_format.Merge(
"""
diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py
index 150d734db6..94b7f4f867 100644
--- a/tensorflow/contrib/checkpoint/__init__.py
+++ b/tensorflow/contrib/checkpoint/__init__.py
@@ -37,6 +37,7 @@ Checkpoint management:
Saving and restoring Python state:
@@NumpyState
+@@PythonStateWrapper
"""
from __future__ import absolute_import
@@ -45,6 +46,7 @@ from __future__ import print_function
from tensorflow.contrib.checkpoint.python.containers import UniqueNameTracker
from tensorflow.contrib.checkpoint.python.python_state import NumpyState
+from tensorflow.contrib.checkpoint.python.python_state import PythonStateWrapper
from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency
from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint
from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph
diff --git a/tensorflow/contrib/checkpoint/python/python_state.py b/tensorflow/contrib/checkpoint/python/python_state.py
index 9b11035b6d..302d5cfb79 100644
--- a/tensorflow/contrib/checkpoint/python/python_state.py
+++ b/tensorflow/contrib/checkpoint/python/python_state.py
@@ -17,7 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import abc
import functools
+import six
import numpy
@@ -101,7 +103,7 @@ class NumpyState(base.CheckpointableBase):
# TODO(allenl): Consider supporting lists/tuples, either ad-hoc or by making
# ndarrays checkpointable natively and using standard checkpointable list
# tracking.
- if isinstance(value, numpy.ndarray):
+ if isinstance(value, (numpy.ndarray, numpy.generic)):
try:
existing = super(NumpyState, self).__getattribute__(name)
existing.array = value
@@ -127,7 +129,29 @@ class NumpyState(base.CheckpointableBase):
super(NumpyState, self).__setattr__(name, value)
-class _NumpyWrapper(base.CheckpointableBase):
+@six.add_metaclass(abc.ABCMeta)
+class PythonStateWrapper(base.CheckpointableBase):
+ """Wraps a Python object for storage in an object-based checkpoint."""
+
+ @abc.abstractmethod
+ def _serialize(self):
+ """Callback for `PythonStringStateSaveable` to serialize the object."""
+
+ @abc.abstractmethod
+ def _deserialize(self, string_value):
+ """Callback for `PythonStringStateSaveable` to deserialize the object."""
+
+ def _gather_saveables_for_checkpoint(self):
+    """Specify callbacks for saving and restoring the wrapped object."""
+ return {
+ "py_state": functools.partial(
+ base.PythonStringStateSaveable,
+ state_callback=self._serialize,
+ restore_callback=self._deserialize)
+ }
+
+
+class _NumpyWrapper(PythonStateWrapper):
"""Wraps a NumPy array for storage in an object-based checkpoint."""
def __init__(self, array):
@@ -139,7 +163,7 @@ class _NumpyWrapper(base.CheckpointableBase):
self.array = array
def _serialize(self):
- """Callback for `PythonStringStateSaveable` to serialize the array."""
+ """Callback to serialize the array."""
string_file = BytesIO()
try:
numpy.save(string_file, self.array, allow_pickle=False)
@@ -149,18 +173,10 @@ class _NumpyWrapper(base.CheckpointableBase):
return serialized
def _deserialize(self, string_value):
- """Callback for `PythonStringStateSaveable` to deserialize the array."""
+ """Callback to deserialize the array."""
string_file = BytesIO(string_value)
try:
self.array = numpy.load(string_file, allow_pickle=False)
finally:
string_file.close()
- def _gather_saveables_for_checkpoint(self):
- """Specify callbacks for saving and restoring `array`."""
- return {
- "array": functools.partial(
- base.PythonStringStateSaveable,
- state_callback=self._serialize,
- restore_callback=self._deserialize)
- }
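For illustration only (not part of the patch): a hypothetical `PythonStateWrapper` subclass showing the `_serialize`/`_deserialize` contract that `_NumpyWrapper` follows above. The class name and JSON payload are made up; only the two callbacks need to be supplied, since `_gather_saveables_for_checkpoint` is inherited.

import json

from tensorflow.contrib.checkpoint.python import python_state


class _JsonStateWrapper(python_state.PythonStateWrapper):
  """Wraps a JSON-serializable Python dict for object-based checkpointing."""

  def __init__(self, value=None):
    self._value = dict(value or {})

  def _serialize(self):
    # Called when saving: return a string payload for the checkpoint.
    return json.dumps(self._value, sort_keys=True)

  def _deserialize(self, string_value):
    # Called when restoring: rebuild the Python state from the saved string.
    if isinstance(string_value, bytes):
      string_value = string_value.decode("utf-8")
    self._value = json.loads(string_value)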
diff --git a/tensorflow/contrib/checkpoint/python/python_state_test.py b/tensorflow/contrib/checkpoint/python/python_state_test.py
index 0439a4755e..45494351ff 100644
--- a/tensorflow/contrib/checkpoint/python/python_state_test.py
+++ b/tensorflow/contrib/checkpoint/python/python_state_test.py
@@ -40,10 +40,13 @@ class NumpyStateTests(test.TestCase):
save_state.a = numpy.ones([2, 2])
save_state.b = numpy.ones([2, 2])
save_state.b = numpy.zeros([2, 2])
+ save_state.c = numpy.int64(3)
self.assertAllEqual(numpy.ones([2, 2]), save_state.a)
self.assertAllEqual(numpy.zeros([2, 2]), save_state.b)
+ self.assertEqual(3, save_state.c)
first_save_path = saver.save(prefix)
save_state.a[1, 1] = 2.
+ save_state.c = numpy.int64(4)
second_save_path = saver.save(prefix)
load_state = python_state.NumpyState()
@@ -51,6 +54,7 @@ class NumpyStateTests(test.TestCase):
loader.restore(first_save_path).initialize_or_restore()
self.assertAllEqual(numpy.ones([2, 2]), load_state.a)
self.assertAllEqual(numpy.zeros([2, 2]), load_state.b)
+ self.assertEqual(3, load_state.c)
load_state.a[0, 0] = 42.
self.assertAllEqual([[42., 1.], [1., 1.]], load_state.a)
loader.restore(first_save_path).run_restore_ops()
@@ -58,6 +62,7 @@ class NumpyStateTests(test.TestCase):
loader.restore(second_save_path).run_restore_ops()
self.assertAllEqual([[1., 1.], [1., 2.]], load_state.a)
self.assertAllEqual(numpy.zeros([2, 2]), load_state.b)
+ self.assertEqual(4, load_state.c)
def testNoGraphPollution(self):
graph = ops.Graph()
diff --git a/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py b/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py
index 493b3c6f1b..11e177cd0c 100644
--- a/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py
+++ b/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py
@@ -197,7 +197,7 @@ class BigQueryReaderOpsTest(test.TestCase):
def _ReadAndCheckRowsUsingFeatures(self, num_rows):
self.server.handler.num_rows = num_rows
- with self.test_session() as sess:
+ with self.cached_session() as sess:
feature_configs = {
"int64_col":
parsing_ops.FixedLenFeature(
@@ -254,7 +254,7 @@ class BigQueryReaderOpsTest(test.TestCase):
num_rows = 10
self.server.handler.num_rows = num_rows
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = cloud.BigQueryReader(
project_id=_PROJECT,
dataset_id=_DATASET,
diff --git a/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py b/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py
index 9b6c056d6c..4f2ecbcb17 100644
--- a/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py
+++ b/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py
@@ -26,7 +26,7 @@ class GcsConfigOpsTest(test.TestCase):
def testSetBlockCache(self):
cfg = gcs_config_ops.BlockCacheParams(max_bytes=1024*1024*1024)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
gcs_config_ops.configure_gcs(sess, block_cache=cfg)
def testConfigureGcsHook(self):
@@ -36,7 +36,7 @@ class GcsConfigOpsTest(test.TestCase):
'type': 'authorized_user'}
hook = gcs_config_ops.ConfigureGcsHook(credentials=creds)
hook.begin()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run = lambda _, feed_dict=None, options=None, run_metadata=None: None
hook.after_create_session(sess, None)
diff --git a/tensorflow/contrib/cmake/README.md b/tensorflow/contrib/cmake/README.md
index 0b79f718d4..789dab81ed 100644
--- a/tensorflow/contrib/cmake/README.md
+++ b/tensorflow/contrib/cmake/README.md
@@ -1,6 +1,10 @@
TensorFlow CMake build
======================
+The CMake build of TensorFlow is deprecated. Please use `bazel` to build TF on
+all platforms. For details, see the
+[TensorFlow install guide](https://www.tensorflow.org/install/).
+
This directory contains CMake files for building TensorFlow on Microsoft
Windows. [CMake](https://cmake.org) is a cross-platform tool that can
generate build scripts for multiple build systems, including Microsoft
diff --git a/tensorflow/contrib/cmake/external/png.cmake b/tensorflow/contrib/cmake/external/png.cmake
index ad2af01bc0..1a147e9c8e 100644
--- a/tensorflow/contrib/cmake/external/png.cmake
+++ b/tensorflow/contrib/cmake/external/png.cmake
@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
include (ExternalProject)
+include (GNUInstallDirs)
set(png_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/png_archive)
set(png_URL https://mirror.bazel.build/github.com/glennrp/libpng/archive/v1.6.34.tar.gz)
@@ -35,7 +36,7 @@ if(WIN32)
endif()
endif()
else()
- set(png_STATIC_LIBRARIES ${CMAKE_BINARY_DIR}/png/install/lib/libpng16.a)
+ set(png_STATIC_LIBRARIES ${CMAKE_BINARY_DIR}/png/install/${CMAKE_INSTALL_LIBDIR}/libpng16.a)
endif()
set(png_HEADERS
diff --git a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py
index 9b4bf62710..3e25079e02 100644
--- a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py
+++ b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py
@@ -75,7 +75,7 @@ class ExternalRegretOptimizerTest(test.TestCase):
multipliers3 = standard_ops.constant([0.4, 0.7, -0.2, 0.5, 0.1])
expected_projected_multipliers3 = np.array([0.2, 0.5, 0.0, 0.3, 0.0])
- with self.test_session() as session:
+ with self.cached_session() as session:
projected_multipliers1 = session.run(
external_regret_optimizer._project_multipliers_wrt_euclidean_norm(
multipliers1, 1.0))
@@ -122,7 +122,7 @@ class ExternalRegretOptimizerTest(test.TestCase):
]
multipliers = []
- with self.test_session() as session:
+ with self.cached_session() as session:
session.run(standard_ops.global_variables_initializer())
while len(multipliers) < len(expected_multipliers):
multipliers.append(session.run(optimizer.lagrange_multipliers))
diff --git a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py
index 34c4543dca..df0eced631 100644
--- a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py
+++ b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py
@@ -97,7 +97,7 @@ class SwapRegretOptimizerTest(test.TestCase):
matrix1 = np.matrix([[0.6, 0.1, 0.1], [0.0, 0.6, 0.9], [0.4, 0.3, 0.0]])
matrix2 = np.matrix([[0.4, 0.4, 0.2], [0.2, 0.1, 0.5], [0.4, 0.5, 0.3]])
- with self.test_session() as session:
+ with self.cached_session() as session:
eigenvector1 = session.run(
swap_regret_optimizer._maximal_eigenvector_power_method(
standard_ops.constant(matrix1)))
@@ -119,7 +119,7 @@ class SwapRegretOptimizerTest(test.TestCase):
expected_projected_matrix = np.array([[0.6, 0.1, 0.1], [0.0, 0.6, 0.9],
[0.4, 0.3, 0.0]])
- with self.test_session() as session:
+ with self.cached_session() as session:
projected_matrix = session.run(
swap_regret_optimizer._project_stochastic_matrix_wrt_euclidean_norm(
matrix))
@@ -134,7 +134,7 @@ class SwapRegretOptimizerTest(test.TestCase):
expected_projected_matrix = np.array([[0.4, 0.4, 0.2], [0.2, 0.1, 0.5],
[0.4, 0.5, 0.3]])
- with self.test_session() as session:
+ with self.cached_session() as session:
projected_matrix = session.run(
standard_ops.exp(
swap_regret_optimizer.
@@ -165,7 +165,7 @@ class SwapRegretOptimizerTest(test.TestCase):
]
matrices = []
- with self.test_session() as session:
+ with self.cached_session() as session:
session.run(standard_ops.global_variables_initializer())
while len(matrices) < len(expected_matrices):
matrices.append(session.run(optimizer.stochastic_matrix))
@@ -198,7 +198,7 @@ class SwapRegretOptimizerTest(test.TestCase):
]
matrices = []
- with self.test_session() as session:
+ with self.cached_session() as session:
session.run(standard_ops.global_variables_initializer())
while len(matrices) < len(expected_matrices):
matrices.append(session.run(optimizer.stochastic_matrix))
diff --git a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
index 8cfe142059..556d731840 100644
--- a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
+++ b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
@@ -61,7 +61,7 @@ class CrfTest(test.TestCase):
for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list,
inputs_list,
tag_indices_list):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sequence_score = crf.crf_sequence_score(
inputs=array_ops.expand_dims(inputs, 0),
tag_indices=array_ops.expand_dims(tag_indices, 0),
@@ -96,7 +96,7 @@ class CrfTest(test.TestCase):
]
for sequence_lengths, inputs, tag_bitmap in zip(
sequence_lengths_list, inputs_list, tag_bitmap_list):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sequence_score = crf.crf_multitag_sequence_score(
inputs=array_ops.expand_dims(inputs, 0),
tag_bitmap=array_ops.expand_dims(tag_bitmap, 0),
@@ -124,7 +124,7 @@ class CrfTest(test.TestCase):
for dtype in (np.int32, np.int64):
tag_indices = np.array([1, 2, 1, 0], dtype=dtype)
sequence_lengths = np.array(3, dtype=np.int32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
unary_score = crf.crf_unary_score(
tag_indices=array_ops.expand_dims(tag_indices, 0),
sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
@@ -140,7 +140,7 @@ class CrfTest(test.TestCase):
transition_params = np.array(
[[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
sequence_lengths = np.array(3, dtype=np.int32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
binary_score = crf.crf_binary_score(
tag_indices=array_ops.expand_dims(tag_indices, 0),
sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
@@ -176,7 +176,7 @@ class CrfTest(test.TestCase):
tag_indices_list):
num_words = inputs.shape[0]
num_tags = inputs.shape[1]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
all_sequence_scores = []
# Compare the dynamic program with brute force computation.
@@ -206,7 +206,7 @@ class CrfTest(test.TestCase):
"""
Test `crf_log_norm` when `sequence_lengths` contains one or more zeros.
"""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs = constant_op.constant(np.ones([2, 10, 5],
dtype=np.float32))
transition_params = constant_op.constant(np.ones([5, 5],
@@ -226,7 +226,7 @@ class CrfTest(test.TestCase):
sequence_lengths = np.array(3, dtype=np.int32)
num_words = inputs.shape[0]
num_tags = inputs.shape[1]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
all_sequence_log_likelihoods = []
# Make sure all probabilities sum to 1.
@@ -254,7 +254,7 @@ class CrfTest(test.TestCase):
num_words = inputs.shape[0]
num_tags = inputs.shape[1]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
all_sequence_scores = []
all_sequences = []
@@ -310,7 +310,7 @@ class CrfTest(test.TestCase):
num_words = inputs.shape[0]
num_tags = inputs.shape[1]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
all_sequence_scores = []
all_sequences = []
@@ -351,7 +351,7 @@ class CrfTest(test.TestCase):
"""
Test that crf_decode works when sequence_length contains one or more zeros.
"""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs = constant_op.constant(np.ones([2, 10, 5],
dtype=np.float32))
transition_params = constant_op.constant(np.ones([5, 5],
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index 5e6c1520a2..c378b1ce8d 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -26,6 +26,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@CheckpointInputPipelineHook
@@CsvDataset
@@LMDBDataset
+@@Optional
@@RandomDataset
@@Reducer
@@SqlDataset
@@ -38,7 +39,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@copy_to_device
@@dense_to_sparse_batch
@@enumerate_dataset
-
+@@get_next_as_optional
@@get_single_element
@@group_by_reducer
@@group_by_window
@@ -46,7 +47,6 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@make_batched_features_dataset
@@make_csv_dataset
@@make_saveable_from_iterator
-
@@map_and_batch
@@padded_batch_and_drop_remainder
@@parallel_interleave
@@ -62,6 +62,8 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@sloppy_interleave
@@unbatch
@@unique
+
+@@AUTOTUNE
"""
from __future__ import absolute_import
@@ -91,6 +93,10 @@ from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datase
from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave
from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook
from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator
+
+# Optimization constant that can be used to enable auto-tuning.
+from tensorflow.contrib.data.python.ops.optimization import AUTOTUNE
+
from tensorflow.contrib.data.python.ops.parsing_ops import parse_example_dataset
from tensorflow.contrib.data.python.ops.prefetching_ops import copy_to_device
from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device
@@ -107,10 +113,9 @@ from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat
from tensorflow.contrib.data.python.ops.sliding import sliding_window_batch
from tensorflow.contrib.data.python.ops.unique import unique
from tensorflow.contrib.data.python.ops.writers import TFRecordWriter
+from tensorflow.python.data.ops.iterator_ops import get_next_as_optional
+from tensorflow.python.data.ops.optional_ops import Optional
# pylint: enable=unused-import
from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(__name__)
-
-# A constant that can be used to enable auto-tuning.
-AUTOTUNE = -1
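For illustration only (not part of the patch): `AUTOTUNE` is now re-exported from the optimization module instead of being defined here, but it is consumed the same way. A minimal sketch, assuming the TF 1.x `tf.contrib.data` API in this tree; the dataset itself is made up.

import tensorflow as tf
from tensorflow.contrib import data as contrib_data

dataset = tf.data.Dataset.range(1000)
dataset = dataset.map(lambda x: x * 2)
# Let the tf.data runtime pick the prefetch buffer size dynamically.
dataset = dataset.prefetch(buffer_size=contrib_data.AUTOTUNE)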
diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc
index 74107d5242..21ec50fb6b 100644
--- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc
@@ -49,6 +49,9 @@ class CSVDatasetOp : public DatasetOpKernel {
OP_REQUIRES_OK(ctx,
ctx->input_list("record_defaults", &record_defaults_list));
for (int i = 0; i < record_defaults_list.size(); ++i) {
+ OP_REQUIRES(ctx, record_defaults_list[i].dims() <= 1,
+ errors::InvalidArgument(
+ "Each record default should be at most rank 1"));
OP_REQUIRES(ctx, record_defaults_list[i].NumElements() < 2,
errors::InvalidArgument(
"There should only be 1 default per field but field ", i,
diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc
index ae104d55bd..ad410e17fe 100644
--- a/tensorflow/contrib/data/ops/dataset_ops.cc
+++ b/tensorflow/contrib/data/ops/dataset_ops.cc
@@ -65,7 +65,13 @@ REGISTER_OP("CSVDataset")
TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &unused));
// `record_defaults` must be lists of scalars
for (size_t i = 8; i < c->num_inputs(); ++i) {
- TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &unused));
+ shape_inference::ShapeHandle v;
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v));
+ if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) {
+ return errors::InvalidArgument(
+ "Shape of a default must be a length-0 or length-1 vector, or a "
+ "scalar.");
+ }
}
return shape_inference::ScalarShape(c);
});
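For illustration only (not part of the patch): the new shape checks above accept scalar (or length-0/1 vector) record defaults, one per CSV column. A minimal sketch of passing such defaults, assuming the TF 1.x `tf.contrib.data.CsvDataset` API; `data.csv` is a made-up file name.

import tensorflow as tf
from tensorflow.contrib import data as contrib_data

# One scalar default per CSV column: two float columns and one string column.
record_defaults = [
    tf.constant(0.0),
    tf.constant(0.0),
    tf.constant("", dtype=tf.string),
]
dataset = contrib_data.CsvDataset("data.csv", record_defaults)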
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index b9320e5fef..ba202839b2 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -72,12 +72,13 @@ py_test(
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:platform",
"//tensorflow/python:platform_test",
"//tensorflow/python:session",
"//tensorflow/python/data/ops:readers",
+ "//tensorflow/python/eager:context",
"//third_party/py/numpy",
],
)
@@ -276,6 +277,7 @@ py_test(
"//tensorflow/python:check_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
+ "//tensorflow/python:data_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:function",
@@ -286,21 +288,6 @@ py_test(
)
py_test(
- name = "optimize_dataset_op_test",
- size = "small",
- srcs = ["optimize_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
-py_test(
name = "parsing_ops_test",
size = "small",
srcs = ["parsing_ops_test.py"],
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
index 67242fecfe..8e368bf2bc 100644
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
@@ -57,7 +57,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for start in range(0, len(components), 4):
@@ -85,7 +85,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for start in range(0, len(components), 4):
@@ -123,7 +123,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Initialize with an input tensor of incompatible rank.
sess.run(init_op, feed_dict={input_tensor: [[1]]})
with self.assertRaisesRegexp(errors.InvalidArgumentError,
@@ -148,7 +148,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
op = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
self.assertEqual((i,) * 3, sess.run(op))
@@ -168,7 +168,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
op = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op))
@@ -187,7 +187,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
st_row = sess.run(next_element)
self.assertEqual([i], st_row.indices)
@@ -208,7 +208,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
dense_elem, st_row = sess.run(next_element)
self.assertEqual(i, dense_elem)
@@ -230,7 +230,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
op = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
self.assertEqual(((i,),) * 3, sess.run(op))
@@ -250,7 +250,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
op = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")),
sess.run(op))
@@ -266,7 +266,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@@ -284,7 +284,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Mismatch in the 0th dimension.
sess.run(
iterator.initializer,
@@ -319,7 +319,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for test_batch_size in [1, 3, 7, 10]:
sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size})
num_batches = 7 // test_batch_size
@@ -343,7 +343,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(2):
actual = sess.run(get_next)
@@ -374,7 +374,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for test_batch_size in [1, 3, 7, 10]:
sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size})
num_batches = 7 // test_batch_size
@@ -461,7 +461,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
self.assertEqual([[None] + list(c.shape[1:]) for c in components],
[t.shape.as_list() for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Batch of a finite input, where the batch_size divides the
# total number of elements.
sess.run(init_op, feed_dict={count: 28, batch_size: 14})
@@ -520,7 +520,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
else:
self.assertEqual([None, 1], iterator.output_shapes.as_list())
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
if not drop_remainder:
@@ -535,7 +535,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
.make_one_shot_iterator())
self.assertEqual([None, 1], iterator.output_shapes.as_list())
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
self.assertAllEqual([[64], [81]], sess.run(next_element))
@@ -549,7 +549,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
elements = []
for _ in range(100):
elements.append(iterator.get_next())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(5):
got = sess.run(elements)
got.sort(key=lambda x: x[0])
@@ -569,7 +569,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
elements = []
for _ in range(100):
elements.append(iterator.get_next())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(4):
got = sess.run(elements)
got.sort(key=lambda x: x[0])
@@ -591,7 +591,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(2):
actual = sess.run(get_next)
@@ -614,7 +614,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
.make_initializable_iterator())
init_op = iterator.initializer
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
sess.run(init_op, feed_dict={batch_size: 14})
@@ -635,7 +635,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"number of elements does not match"):
@@ -659,7 +659,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(3):
sess.run(get_next)
@@ -686,7 +686,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
batch_size=10)).make_one_shot_iterator())
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(threshold // 10):
self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next))
if threshold % 10 != 0:
@@ -718,7 +718,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(10):
self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
@@ -784,7 +784,7 @@ class RestructuredDatasetTest(test.TestCase):
iterator = result.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for _ in range(5):
sess.run(get_next)
@@ -908,7 +908,7 @@ class RestructuredDatasetTest(test.TestCase):
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
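(Illustrative aside, not part of the patch.) The hunks above, and most of those that follow, replace the deprecated self.test_session() with self.cached_session(). A minimal TF 1.x test using the new idiom might look like the sketch below; the test name and dataset are made up.

import tensorflow as tf


class CachedSessionExampleTest(tf.test.TestCase):

  def testSquaresInBatches(self):
    dataset = tf.data.Dataset.range(4).map(lambda x: x * x).batch(2)
    get_next = dataset.make_one_shot_iterator().get_next()
    # cached_session() returns a session that is reused across calls within
    # the test; it replaces the deprecated test_session().
    with self.cached_session() as sess:
      self.assertAllEqual([0, 1], sess.run(get_next))
      self.assertAllEqual([4, 9], sess.run(get_next))


if __name__ == "__main__":
  tf.test.main()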
diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
index 2022c1f2bd..48971f2ccc 100644
--- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
@@ -40,7 +40,7 @@ class GroupByReducerTest(test.TestCase):
def checkResults(self, dataset, shapes, values):
self.assertEqual(shapes, dataset.output_shapes)
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for expected in values:
got = sess.run(get_next)
self.assertEqual(got, expected)
@@ -129,7 +129,7 @@ class GroupByReducerTest(test.TestCase):
self.assertIs(None, dataset.output_shapes[1].ndims)
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x, y = sess.run(get_next)
self.assertAllEqual([0] * (2**i), x)
self.assertAllEqual(np.array(1, ndmin=i), y)
@@ -192,7 +192,7 @@ class GroupByReducerTest(test.TestCase):
(dataset_ops.Dataset.range(10), dataset_ops.Dataset.range(10))).apply(
grouping.group_by_reducer(lambda x, y: np.int64(0), reducer))
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x, y = sess.run(get_next)
self.assertAllEqual(x, np.asarray([x for x in range(10)]))
self.assertEqual(y, 45)
@@ -210,7 +210,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
counts = []
with self.assertRaises(errors.OutOfRangeError):
@@ -237,7 +237,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
# The input is infinite, so this test demonstrates that:
# 1. We produce output without having to consume the entire input,
@@ -258,7 +258,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
@@ -275,7 +275,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -301,7 +301,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
@@ -329,7 +329,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
counts = []
with self.assertRaises(errors.OutOfRangeError):
@@ -376,7 +376,7 @@ class BucketTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
which_bucket, bucketed_values = sess.run(get_next)
@@ -411,7 +411,7 @@ class BucketTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
# Get two minibatches (one containing even values, one containing odds)
@@ -482,7 +482,7 @@ class BucketTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
# Get two minibatches ([0, 2, ...] and [64, 66, ...])
@@ -515,7 +515,7 @@ class BucketTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaises(errors.OutOfRangeError):
batches = 0
@@ -531,6 +531,45 @@ class BucketTest(test.TestCase):
self.assertEqual(batches, 15)
+def _element_length_fn(x, y=None):
+ del y
+ return array_ops.shape(x)[0]
+
+
+def _to_sparse_tensor(record):
+ return sparse_tensor.SparseTensor(**record)
+
+
+def _format_record(array, sparse):
+ if sparse:
+ return {
+ "values": array,
+ "indices": [[i] for i in range(len(array))],
+ "dense_shape": (len(array),)
+ }
+ return array
+
+
+def _get_record_type(sparse):
+ if sparse:
+ return {
+ "values": dtypes.int64,
+ "indices": dtypes.int64,
+ "dense_shape": dtypes.int64
+ }
+ return dtypes.int32
+
+
+def _get_record_shape(sparse):
+ if sparse:
+ return {
+ "values": tensor_shape.TensorShape([None,]),
+ "indices": tensor_shape.TensorShape([None, 1]),
+ "dense_shape": tensor_shape.TensorShape([1,])
+ }
+ return tensor_shape.TensorShape([None])
+
+
class BucketBySequenceLength(test.TestCase):
def testBucket(self):
@@ -539,39 +578,58 @@ class BucketBySequenceLength(test.TestCase):
batch_sizes = [10, 8, 4, 2]
lengths = [8, 13, 25, 35]
- def element_gen():
- # Produce 1 batch for each bucket
- elements = []
- for batch_size, length in zip(batch_sizes, lengths):
- for _ in range(batch_size):
- elements.append([1] * length)
- random.shuffle(elements)
- for el in elements:
- yield (el,)
-
- element_len = lambda el: array_ops.shape(el)[0]
- dataset = dataset_ops.Dataset.from_generator(
- element_gen, (dtypes.int64,), ([None],)).apply(
- grouping.bucket_by_sequence_length(
- element_len, boundaries, batch_sizes))
- batch, = dataset.make_one_shot_iterator().get_next()
-
- with self.test_session() as sess:
- batches = []
- for _ in range(4):
- batches.append(sess.run(batch))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(batch)
- batch_sizes_val = []
- lengths_val = []
- for batch in batches:
- batch_size = batch.shape[0]
- length = batch.shape[1]
- batch_sizes_val.append(batch_size)
- lengths_val.append(length)
- self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
- self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
- self.assertEqual(sorted(lengths), sorted(lengths_val))
+ def build_dataset(sparse):
+ def _generator():
+ # Produce 1 batch for each bucket
+ elements = []
+ for batch_size, length in zip(batch_sizes, lengths):
+ record_len = length - 1
+ for _ in range(batch_size):
+ elements.append([1] * record_len)
+ record_len = length
+ random.shuffle(elements)
+ for el in elements:
+ yield (_format_record(el, sparse),)
+ dataset = dataset_ops.Dataset.from_generator(
+ _generator,
+ (_get_record_type(sparse),),
+ (_get_record_shape(sparse),))
+ if sparse:
+ dataset = dataset.map(lambda x: (_to_sparse_tensor(x),))
+ return dataset
+
+ def _test_bucket_by_padding(no_padding):
+ dataset = build_dataset(sparse=no_padding)
+ dataset = dataset.apply(
+ grouping.bucket_by_sequence_length(
+ _element_length_fn,
+ boundaries,
+ batch_sizes,
+ no_padding=no_padding))
+ batch, = dataset.make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ batches = []
+ for _ in range(4):
+ batches.append(sess.run(batch))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(batch)
+ batch_sizes_val = []
+ lengths_val = []
+ for batch in batches:
+ shape = batch.dense_shape if no_padding else batch.shape
+ batch_size = shape[0]
+ length = shape[1]
+ batch_sizes_val.append(batch_size)
+ lengths_val.append(length)
+ sum_check = batch.values.sum() if no_padding else batch.sum()
+ self.assertEqual(sum_check, batch_size * length - 1)
+ self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
+ self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
+ self.assertEqual(sorted(lengths), sorted(lengths_val))
+
+ for no_padding in (True, False):
+ _test_bucket_by_padding(no_padding)
def testPadToBoundary(self):
@@ -600,7 +658,7 @@ class BucketBySequenceLength(test.TestCase):
pad_to_bucket_boundary=True))
batch, = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batches = []
for _ in range(3):
batches.append(sess.run(batch))
@@ -637,7 +695,7 @@ class BucketBySequenceLength(test.TestCase):
pad_to_bucket_boundary=True))
batch, = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batches = []
for _ in range(5):
batches.append(sess.run(batch))
@@ -657,28 +715,108 @@ class BucketBySequenceLength(test.TestCase):
def testTupleElements(self):
- def elements_gen():
- text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]]
- label = [1, 2, 1, 2]
- for x, y in zip(text, label):
- yield (x, y)
-
- def element_length_fn(x, y):
- del y
- return array_ops.shape(x)[0]
-
- dataset = dataset_ops.Dataset.from_generator(
- generator=elements_gen,
- output_shapes=(tensor_shape.TensorShape([None]),
- tensor_shape.TensorShape([])),
- output_types=(dtypes.int32, dtypes.int32))
+ def build_dataset(sparse):
+ def _generator():
+ text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]]
+ label = [1, 2, 1, 2]
+ for x, y in zip(text, label):
+ yield (_format_record(x, sparse), y)
+ dataset = dataset_ops.Dataset.from_generator(
+ generator=_generator,
+ output_types=(_get_record_type(sparse), dtypes.int32),
+ output_shapes=(_get_record_shape(sparse),
+ tensor_shape.TensorShape([])))
+ if sparse:
+ dataset = dataset.map(lambda x, y: (_to_sparse_tensor(x), y))
+ return dataset
+
+ def _test_tuple_elements_by_padding(no_padding):
+ dataset = build_dataset(sparse=no_padding)
+ dataset = dataset.apply(grouping.bucket_by_sequence_length(
+ element_length_func=_element_length_fn,
+ bucket_batch_sizes=[2, 2, 2],
+ bucket_boundaries=[0, 8],
+ no_padding=no_padding))
+ shapes = dataset.output_shapes
+ self.assertEqual([None, None], shapes[0].as_list())
+ self.assertEqual([None], shapes[1].as_list())
+
+ for no_padding in (True, False):
+ _test_tuple_elements_by_padding(no_padding)
+
+ def testBucketSparse(self):
+ """Tests bucketing of sparse tensors (case where `no_padding` == True).
+
+ The test runs on the following dataset:
+ [
+ [0],
+ [0, 1],
+ [0, 1, 2]
+ ...
+ [0, ..., max_len - 1]
+ ]
+ Sequences are bucketed by length and batched with
+ `batch_size` < `bucket_size`.
+ """
+
+ min_len = 0
+ max_len = 100
+ batch_size = 7
+ bucket_size = 10
+
+ def _build_dataset():
+ input_data = [range(i+1) for i in range(min_len, max_len)]
+ def generator_fn():
+ for record in input_data:
+ yield _format_record(record, sparse=True)
+ dataset = dataset_ops.Dataset.from_generator(
+ generator=generator_fn,
+ output_types=_get_record_type(sparse=True))
+ dataset = dataset.map(_to_sparse_tensor)
+ return dataset
+
+ def _compute_expected_batches():
+ """Computes expected batch outputs and stores in a set."""
+ all_expected_sparse_tensors = set()
+ for bucket_start_len in range(min_len, max_len, bucket_size):
+ for batch_offset in range(0, bucket_size, batch_size):
+ batch_start_len = bucket_start_len + batch_offset
+ batch_end_len = min(batch_start_len + batch_size,
+ bucket_start_len + bucket_size)
+ expected_indices = []
+ expected_values = []
+ for length in range(batch_start_len, batch_end_len):
+ for val in range(length + 1):
+ expected_indices.append((length - batch_start_len, val))
+ expected_values.append(val)
+ expected_sprs_tensor = (tuple(expected_indices),
+ tuple(expected_values))
+ all_expected_sparse_tensors.add(expected_sprs_tensor)
+ return all_expected_sparse_tensors
+
+ def _compute_batches(dataset):
+ """Computes actual batch outputs of dataset and stores in a set."""
+ batch = dataset.make_one_shot_iterator().get_next()
+ all_sparse_tensors = set()
+ with self.cached_session() as sess:
+ with self.assertRaises(errors.OutOfRangeError):
+ while True:
+ output = sess.run(batch)
+ sprs_tensor = (tuple([tuple(idx) for idx in output.indices]),
+ tuple(output.values))
+ all_sparse_tensors.add(sprs_tensor)
+ return all_sparse_tensors
+
+ dataset = _build_dataset()
+ boundaries = range(min_len + bucket_size + 1, max_len, bucket_size)
dataset = dataset.apply(grouping.bucket_by_sequence_length(
- element_length_func=element_length_fn,
- bucket_batch_sizes=[2, 2, 2],
- bucket_boundaries=[0, 8]))
- shapes = dataset.output_shapes
- self.assertEqual([None, None], shapes[0].as_list())
- self.assertEqual([None], shapes[1].as_list())
+ _element_length_fn,
+ boundaries,
+ [batch_size] * (len(boundaries) + 1),
+ no_padding=True))
+ batches = _compute_batches(dataset)
+ expected_batches = _compute_expected_batches()
+ self.assertEqual(batches, expected_batches)
if __name__ == "__main__":
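(Illustrative aside, not part of the patch.) The bucketing tests above exercise grouping.bucket_by_sequence_length; a condensed sketch of the same transformation through the public tf.contrib.data export, assuming TF 1.x, with made-up record lengths and bucket values:

import tensorflow as tf


def build_bucketed_dataset():
  def gen():
    # Variable-length records; lengths chosen only for illustration.
    for n in (2, 5, 9, 13):
      yield [1] * n

  dataset = tf.data.Dataset.from_generator(
      gen, output_types=tf.int64, output_shapes=tf.TensorShape([None]))
  # Records are routed into length buckets ([0, 4), [4, 8), [8, 12), [12, inf))
  # and padded-batched within each bucket; bucket_batch_sizes needs
  # len(bucket_boundaries) + 1 entries, one per bucket.
  return dataset.apply(
      tf.contrib.data.bucket_by_sequence_length(
          element_length_func=lambda x: tf.shape(x)[0],
          bucket_boundaries=[4, 8, 12],
          bucket_batch_sizes=[2, 2, 2, 2]))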
diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
index 63bffd023f..f8e74e4583 100644
--- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
@@ -31,38 +31,49 @@ from tensorflow.contrib.data.python.ops import error_ops
from tensorflow.contrib.data.python.ops import readers
from tensorflow.python.client import session
from tensorflow.python.data.ops import readers as core_readers
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test
+@test_util.run_all_in_graph_and_eager_modes
class CsvDatasetOpTest(test.TestCase):
- def _assert_datasets_equal(self, g, ds1, ds2):
+ def _get_next(self, dataset):
+ # Returns a zero-argument function whose result is fed to self.evaluate
+ # to yield the next element.
+ it = dataset.make_one_shot_iterator()
+ if context.executing_eagerly():
+ return it.get_next
+ else:
+ get_next = it.get_next()
+ return lambda: get_next
+
+ def _assert_datasets_equal(self, ds1, ds2):
assert ds1.output_shapes == ds2.output_shapes, ('output_shapes differ: %s, '
'%s') % (ds1.output_shapes,
ds2.output_shapes)
assert ds1.output_types == ds2.output_types
assert ds1.output_classes == ds2.output_classes
- next1 = ds1.make_one_shot_iterator().get_next()
- next2 = ds2.make_one_shot_iterator().get_next()
- with self.session(graph=g) as sess:
- # Run through datasets and check that outputs match, or errors match.
- while True:
- try:
- op1 = sess.run(next1)
- except (errors.OutOfRangeError, ValueError) as e:
- # If op1 throws an exception, check that op2 throws same exception.
- with self.assertRaises(type(e)):
- sess.run(next2)
- break
- op2 = sess.run(next2)
- self.assertAllEqual(op1, op2)
+ next1 = self._get_next(ds1)
+ next2 = self._get_next(ds2)
+ # Run through datasets and check that outputs match, or errors match.
+ while True:
+ try:
+ op1 = self.evaluate(next1())
+ except (errors.OutOfRangeError, ValueError) as e:
+ # If op1 throws an exception, check that op2 throws the same exception.
+ with self.assertRaises(type(e)):
+ self.evaluate(next2())
+ break
+ op2 = self.evaluate(next2())
+ self.assertAllEqual(op1, op2)
def _setup_files(self, inputs, linebreak='\n', compression_type=None):
filenames = []
@@ -95,33 +106,32 @@ class CsvDatasetOpTest(test.TestCase):
def _test_by_comparison(self, inputs, **kwargs):
"""Checks that CsvDataset is equiv to TextLineDataset->map(decode_csv)."""
- with ops.Graph().as_default() as g:
- dataset_actual, dataset_expected = self._make_test_datasets(
- inputs, **kwargs)
- self._assert_datasets_equal(g, dataset_actual, dataset_expected)
+ dataset_actual, dataset_expected = self._make_test_datasets(
+ inputs, **kwargs)
+ self._assert_datasets_equal(dataset_actual, dataset_expected)
def _verify_output_or_err(self,
- sess,
dataset,
expected_output=None,
expected_err_re=None):
- nxt = dataset.make_one_shot_iterator().get_next()
if expected_err_re is None:
# Verify that output is expected, without errors
+ nxt = self._get_next(dataset)
expected_output = [[
v.encode('utf-8') if isinstance(v, str) else v for v in op
] for op in expected_output]
for value in expected_output:
- op = sess.run(nxt)
+ op = self.evaluate(nxt())
self.assertAllEqual(op, value)
with self.assertRaises(errors.OutOfRangeError):
- sess.run(nxt)
+ self.evaluate(nxt())
else:
# Verify that OpError is produced as expected
with self.assertRaisesOpError(expected_err_re):
+ nxt = self._get_next(dataset)
while True:
try:
- sess.run(nxt)
+ self.evaluate(nxt())
except errors.OutOfRangeError:
break
@@ -137,11 +147,8 @@ class CsvDatasetOpTest(test.TestCase):
# Convert str type because py3 tf strings are bytestrings
filenames = self._setup_files(inputs, linebreak, compression_type)
kwargs['compression_type'] = compression_type
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- dataset = readers.CsvDataset(filenames, **kwargs)
- self._verify_output_or_err(sess, dataset, expected_output,
- expected_err_re)
+ dataset = readers.CsvDataset(filenames, **kwargs)
+ self._verify_output_or_err(dataset, expected_output, expected_err_re)
def testCsvDataset_requiredFields(self):
record_defaults = [[]] * 4
@@ -191,21 +198,17 @@ class CsvDatasetOpTest(test.TestCase):
record_defaults = [['']] * 3
inputs = [['1,"2"3",4', '1,"2"3",4",5,5', 'a,b,"c"d"', 'e,f,g']]
filenames = self._setup_files(inputs)
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
- dataset = dataset.apply(error_ops.ignore_errors())
- self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']])
+ dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
+ dataset = dataset.apply(error_ops.ignore_errors())
+ self._verify_output_or_err(dataset, [['e', 'f', 'g']])
def testCsvDataset_ignoreErrWithUnquotedQuotes(self):
record_defaults = [['']] * 3
inputs = [['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']]
filenames = self._setup_files(inputs)
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
- dataset = dataset.apply(error_ops.ignore_errors())
- self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']])
+ dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
+ dataset = dataset.apply(error_ops.ignore_errors())
+ self._verify_output_or_err(dataset, [['e', 'f', 'g']])
def testCsvDataset_withNoQuoteDelimAndUnquotedQuotes(self):
record_defaults = [['']] * 3
@@ -351,10 +354,9 @@ class CsvDatasetOpTest(test.TestCase):
inputs = [['1,,3,4', '5,6,,8']]
ds_actual, ds_expected = self._make_test_datasets(
inputs, record_defaults=record_defaults)
- with ops.Graph().as_default() as g:
- self._assert_datasets_equal(g,
- ds_actual.repeat(5).prefetch(1),
- ds_expected.repeat(5).prefetch(1))
+ self._assert_datasets_equal(
+ ds_actual.repeat(5).prefetch(1),
+ ds_expected.repeat(5).prefetch(1))
def testCsvDataset_withTypeDefaults(self):
# Testing using dtypes as record_defaults for required fields
@@ -373,13 +375,11 @@ class CsvDatasetOpTest(test.TestCase):
]]
file_path = self._setup_files(data)
- with ops.Graph().as_default() as g:
- ds = readers.make_csv_dataset(
- file_path, batch_size=1, shuffle=False, num_epochs=1)
- next_batch = ds.make_one_shot_iterator().get_next()
+ ds = readers.make_csv_dataset(
+ file_path, batch_size=1, shuffle=False, num_epochs=1)
+ nxt = self._get_next(ds)
- with self.session(graph=g) as sess:
- result = list(sess.run(next_batch).values())
+ result = list(self.evaluate(nxt()).values())
self.assertEqual(result, sorted(result))
@@ -542,6 +542,29 @@ class CsvDatasetOpTest(test.TestCase):
compression_type='ZLIB',
record_defaults=record_defaults)
+ def testCsvDataset_withScalarDefaults(self):
+ record_defaults = [constant_op.constant(0, dtype=dtypes.int64)] * 4
+ inputs = [[',,,', '1,1,1,', ',2,2,2']]
+ self._test_dataset(
+ inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
+ record_defaults=record_defaults)
+
+ def testCsvDataset_with2DDefaults(self):
+ record_defaults = [constant_op.constant([[0]], dtype=dtypes.int64)] * 4
+ inputs = [[',,,', '1,1,1,', ',2,2,2']]
+
+ if context.executing_eagerly():
+ err_spec = errors.InvalidArgumentError, (
+ 'Each record default should be at '
+ 'most rank 1.')
+ else:
+ err_spec = ValueError, 'Shape must be at most rank 1 but is rank 2'
+
+ with self.assertRaisesWithPredicateMatch(*err_spec):
+ self._test_dataset(
+ inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
+ record_defaults=record_defaults)
+
class CsvDatasetBenchmark(test.Benchmark):
"""Benchmarks for the various ways of creating a dataset from CSV files.
diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
index 9020a499c4..eb110324d1 100644
--- a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
@@ -38,7 +38,7 @@ class DirectedInterleaveDatasetTest(test.TestCase):
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for _ in range(100):
for i in range(10):
@@ -67,7 +67,7 @@ class DirectedInterleaveDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
freqs = np.zeros([num_datasets])
for _ in range(num_samples):
freqs[sess.run(next_element)] += 1
@@ -104,7 +104,7 @@ class DirectedInterleaveDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in choice_array:
self.assertEqual(words[i], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
diff --git a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
index e6883d53e0..f3968cdc15 100644
--- a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
@@ -53,7 +53,7 @@ class GetSingleElementTest(test.TestCase, parameterized.TestCase):
lambda x: (x * x, make_sparse(x))).take(take_t)
element = get_single_element.get_single_element(dataset)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if error is None:
dense_val, sparse_val = sess.run(
element, feed_dict={
@@ -90,7 +90,7 @@ class GetSingleElementTest(test.TestCase, parameterized.TestCase):
dataset = dataset_ops.Dataset.range(stop_t)
element = get_single_element.reduce_dataset(dataset, sum_reducer)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
value = sess.run(element, feed_dict={stop_t: stop})
self.assertEqual(stop * (stop - 1) / 2, value)
diff --git a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py
index db2ab815ee..9c508d686d 100644
--- a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py
@@ -44,14 +44,14 @@ class IndexedDatasetOpsTest(test.TestCase):
get_op = gen_dataset_ops.indexed_dataset_get(
handle, index, output_types=[dtypes.uint64], output_shapes=[[]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(materialize)
self.assertEqual([3], sess.run(get_op, feed_dict={index: 3}))
def testIdentityIndexedDataset(self):
ds = indexed_dataset_ops.IdentityIndexedDataset(16)
materialized = ds.materialize()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(materialized.initializer)
placeholder = array_ops.placeholder(dtypes.uint64, shape=[])
for i in range(16):
@@ -66,7 +66,7 @@ class IndexedDatasetOpsTest(test.TestCase):
ds = indexed_dataset_ops.IdentityIndexedDataset(16)
itr = ds.make_initializable_iterator()
n = itr.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(itr.initializer)
for i in range(16):
output = sess.run(n)
diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
index 7a3215f6cc..b9e74dfddb 100644
--- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
@@ -177,7 +177,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def _testSingleThreaded(self, sloppy=False, prefetch_input_elements=0):
# cycle_length=1,block_length=1 acts like `Dataset.interleave()` and
# `Dataset.flat_map()` and is single-threaded. No synchronization required.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
sess.run(
self.init_op,
@@ -212,7 +212,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def testSingleThreadedRagged(self):
# Tests a sequence with wildly different elements per iterator.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
sess.run(
self.init_op,
@@ -242,7 +242,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def _testTwoThreadsNoContention(self, sloppy=False):
# num_threads > 1.
# Explicit coordination should result in `Dataset.interleave()` behavior
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
done_first_event = False
sess.run(
@@ -286,7 +286,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
Args:
sloppy: Whether to be sloppy or not.
"""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
done_first_event = False
sess.run(
@@ -328,7 +328,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def _testTwoThreadsNoContentionBlockLength(self, sloppy=False):
# num_threads > 1.
# Explicit coordination should result in `Dataset.interleave()` behavior
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
done_first_event = False
sess.run(
@@ -373,7 +373,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
Args:
sloppy: Whether to be sloppy or not.
"""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
done_first_event = False
sess.run(
@@ -413,7 +413,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
self._testTwoThreadsNoContentionWithRacesAndBlocking(sloppy=True)
def _testEmptyInput(self, sloppy=False):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Empty input.
self._clear_coordination_events()
sess.run(
@@ -437,7 +437,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def _testNonEmptyInputIntoEmptyOutputs(self, sloppy=False):
# Non-empty input leading to empty output.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
sess.run(
self.init_op,
@@ -461,7 +461,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def _testPartiallyEmptyOutputs(self, sloppy=False, prefetch_input_elements=1):
race_indices = {2, 8, 14} # Sequence points when sloppy mode has race conds
# Mixture of non-empty and empty interleaved datasets.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
done_first_event = False
sess.run(
@@ -500,7 +500,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def testDelayedOutputSloppy(self):
# Explicitly control the sequence of events to ensure we correctly avoid
# head-of-line blocking.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
sess.run(
self.init_op,
@@ -525,7 +525,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
sess.run(self.next_element)
def testBlockLengthWithContentionSloppy(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
done_first_event = False
sess.run(
@@ -560,7 +560,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def _testEarlyExit(self, sloppy=False):
# Exiting without consuming all input should not block
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
sess.run(
self.init_op,
@@ -604,7 +604,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
interleave_fn, cycle_length=16, block_length=2, sloppy=sloppy))
iterator = dataset.make_one_shot_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
output_values = []
for _ in range(30):
output_values.append(sess.run(iterator.get_next()))
@@ -635,7 +635,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(10):
for j in range(2):
@@ -645,7 +645,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
sess.run(get_next)
def testErrorsInOutputFn(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
sess.run(
self.init_op,
@@ -704,7 +704,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
self.init_op = self.iterator.initializer
self.next_element = self.iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
self.init_op,
feed_dict={
@@ -753,7 +753,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
self.init_op = self.iterator.initializer
self.next_element = self.iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
self.init_op,
feed_dict={
@@ -792,7 +792,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
next_element = iterator.get_next()
results = []
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(2):
elements = []
sess.run(iterator.initializer)
diff --git a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py
index 7bc582ebaa..1cc5ddc9a2 100644
--- a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py
@@ -51,7 +51,7 @@ class LMDBDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for _ in range(num_repeats): # Dataset is repeated.
for i in range(10): # 10 records.
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
index dc9d56dd53..e8519381d6 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
@@ -54,7 +54,7 @@ class MapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for x in [1., 2., 3., 5.]:
self.assertEqual(x, sess.run(get_next))
@@ -72,7 +72,7 @@ class MapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for x in [1., 2., 3., 5.]:
self.assertEqual(x, sess.run(get_next))
@@ -99,7 +99,7 @@ class MapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# All of the files are present.
sess.run(init_op)
for filename in filenames:
@@ -209,7 +209,7 @@ class MapDatasetBenchmark(test.Benchmark):
end = time.time()
chained_deltas.append(end - start)
- fused_dataset = dataset = dataset.apply(
+ fused_dataset = dataset.apply(
batching.map_and_batch(
math_ops.matmul,
num_parallel_calls=num_calls,
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
index 61567bc8d7..83b723710c 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
@@ -28,6 +28,7 @@ from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@@ -207,6 +208,31 @@ class MapDefunTest(test.TestCase):
with self.assertRaises(errors.InvalidArgumentError):
sess.run(r, feed_dict={p: 0})
+ def _assert_op_cancelled(self, sess, map_defun_op):
+ with self.assertRaisesRegexp(errors.CancelledError, "was cancelled"):
+ sess.run(map_defun_op)
+
+ def testMapDefunWithParentCancellation(self):
+ # Checks that a cancellation of the parent graph is threaded through to
+ # MapDefunOp correctly.
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ del x
+ queue = data_flow_ops.FIFOQueue(10, dtypes.int32, ())
+ # Blocking: the queue is never fed, so this dequeue never completes.
+ return queue.dequeue_many(5)
+
+ c = constant_op.constant([1, 2, 3, 4, 5])
+ map_defun_op = map_defun.map_defun(simple_fn, [c], [dtypes.int32], [()])[0]
+
+ with self.test_session() as sess:
+ thread = self.checkedThread(
+ self._assert_op_cancelled, args=(sess, map_defun_op))
+ thread.start()
+ time.sleep(0.1)
+ sess.close()
+ thread.join()
+
class MapDefunBenchmark(test.Benchmark):
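(Illustrative aside, not part of the patch.) The cancellation test above leans on a general TF 1.x behavior: closing a session cancels ops blocked inside it, such as a pending queue dequeue, which then fail with CancelledError. A standalone sketch of that behavior, independent of MapDefun, with illustrative names:

import threading
import time

import tensorflow as tf

queue = tf.FIFOQueue(10, tf.int32, ())
blocked_dequeue = queue.dequeue_many(5)  # Blocks: the queue is never fed.

sess = tf.Session()


def run_blocked():
  try:
    sess.run(blocked_dequeue)
  except tf.errors.CancelledError:
    print("pending dequeue was cancelled")


thread = threading.Thread(target=run_blocked)
thread.start()
time.sleep(0.1)
sess.close()  # Cancels the blocking op pending in the worker thread.
thread.join()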
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
index b299e0736f..7e9ea68047 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
@@ -7,6 +7,34 @@ exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "py_test")
py_test(
+ name = "assert_next_dataset_op_test",
+ size = "medium",
+ srcs = ["assert_next_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/data/python/ops:optimization",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "latency_all_edges_test",
+ size = "small",
+ srcs = ["latency_all_edges_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base",
+ "//tensorflow/contrib/data/python/ops:optimization",
+ "//tensorflow/contrib/data/python/ops:stats_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
name = "map_vectorization_test",
size = "small",
srcs = ["map_vectorization_test.py"],
@@ -46,16 +74,34 @@ py_test(
)
py_test(
- name = "latency_all_edges_test",
+ name = "model_dataset_op_test",
+ size = "medium",
+ srcs = ["model_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "optonly",
+ ],
+ deps = [
+ "//tensorflow/contrib/data/python/ops:batching",
+ "//tensorflow/contrib/data/python/ops:interleave_ops",
+ "//tensorflow/contrib/data/python/ops:optimization",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "optimize_dataset_op_test",
size = "small",
- srcs = ["latency_all_edges_test.py"],
+ srcs = ["optimize_dataset_op_test.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base",
"//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/contrib/data/python/ops:stats_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
],
)
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
new file mode 100644
index 0000000000..bd7b50b902
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
@@ -0,0 +1,64 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import test
+
+
+class AssertNextDatasetTest(test.TestCase):
+
+ def testAssertNext(self):
+ dataset = dataset_ops.Dataset.from_tensors(0).apply(
+ optimization.assert_next(["Map"])).map(lambda x: x)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ self.assertEqual(0, sess.run(get_next))
+
+ def testAssertNextInvalid(self):
+ dataset = dataset_ops.Dataset.from_tensors(0).apply(
+ optimization.assert_next(["Whoops"])).map(lambda x: x)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "Asserted Whoops transformation at offset 0 but encountered "
+ "Map transformation instead."):
+ sess.run(get_next)
+
+ def testAssertNextShort(self):
+ dataset = dataset_ops.Dataset.from_tensors(0).apply(
+ optimization.assert_next(["Map", "Whoops"])).map(lambda x: x)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "Asserted next 2 transformations but encountered only 1."):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py
index 1850b6921a..db380c02a9 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py
@@ -40,7 +40,7 @@ class OptimizeStatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
get_next = iterator.get_next()
summary_t = stats_aggregator.get_summary()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
self.assertEqual(1 * 1, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
index 6a7ef877f9..dde115925e 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
@@ -74,7 +74,7 @@ class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase):
dataset = dataset.prefetch(0).apply(optimization.optimize(["map_fusion"]))
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for x in range(5):
result = sess.run(get_next)
r = x
@@ -131,7 +131,7 @@ class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase):
def _testMapAndFilter(self, dataset, function, predicate):
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for x in range(10):
r = function(x)
if isinstance(r, tuple):
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py
new file mode 100644
index 0000000000..0a87d3e905
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py
@@ -0,0 +1,177 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+import numpy as np
+
+from tensorflow.contrib.data.python.ops import batching
+from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class ModelDatasetTest(test.TestCase):
+
+ def testModelMap(self):
+ k = 1024 * 1024
+ dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
+ np.random.rand(4 * k,
+ 1))).repeat()
+ dataset = dataset.map(math_ops.matmul)
+ iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ deltas = []
+ with self.test_session() as sess:
+ for _ in range(5):
+ sess.run(get_next.op)
+ for _ in range(100):
+ start = time.time()
+ sess.run(get_next.op)
+ end = time.time()
+ deltas.append(end - start)
+
+ print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
+ (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
+ np.max(deltas)))
+
+ def testModelParallelMap(self):
+ k = 1024 * 1024
+ dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
+ np.random.rand(4 * k,
+ 1))).repeat()
+ dataset = dataset.map(math_ops.matmul, num_parallel_calls=56)
+ iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ deltas = []
+ with self.test_session() as sess:
+ for _ in range(5):
+ sess.run(get_next.op)
+ for _ in range(1000):
+ start = time.time()
+ sess.run(get_next.op)
+ end = time.time()
+ deltas.append(end - start)
+
+ print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
+ (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
+ np.max(deltas)))
+
+ def testModelMapAndBatch(self):
+ batch_size = 16
+ k = 1024 * 1024
+ dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
+ np.random.rand(4 * k,
+ 1))).repeat()
+ dataset = dataset.apply(
+ batching.map_and_batch(
+ math_ops.matmul, num_parallel_calls=28, batch_size=batch_size))
+ iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ deltas = []
+ with self.test_session() as sess:
+ for _ in range(5):
+ sess.run(get_next.op)
+ for _ in range(10):
+ start = time.time()
+ sess.run(get_next.op)
+ end = time.time()
+ deltas.append(end - start)
+
+ print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
+ (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
+ np.max(deltas)))
+
+ def testModelParallelInterleave(self):
+ k = 1024 * 1024
+ dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
+ np.random.rand(4 * k,
+ 1))).repeat()
+ dataset = dataset.map(math_ops.matmul)
+ dataset = dataset_ops.Dataset.range(1).repeat().interleave(
+ lambda _: dataset, cycle_length=56, num_parallel_calls=56)
+ iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ deltas = []
+ with self.test_session() as sess:
+ for _ in range(5):
+ sess.run(get_next.op)
+ for _ in range(1000):
+ start = time.time()
+ sess.run(get_next.op)
+ end = time.time()
+ deltas.append(end - start)
+
+ print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
+ (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
+ np.max(deltas)))
+
+ def testModelNested(self):
+ k = 1024 * 1024
+ a = (np.random.rand(1, 8 * k), np.random.rand(8 * k, 1))
+ b = (np.random.rand(1, 4 * k), np.random.rand(4 * k, 1))
+ c = (np.random.rand(1, 2 * k), np.random.rand(2 * k, 1))
+ dataset = dataset_ops.Dataset.from_tensors((a, b, c)).repeat()
+
+ def f1(a, b, c):
+ x, y = a
+ return math_ops.matmul(x, y), b, c
+
+ def f2(a, b, c):
+ x, y = b
+ return a, math_ops.matmul(x, y), c
+
+ def f3(a, b, c):
+ x, y = c
+ return a, b, math_ops.matmul(x, y)
+
+ dataset = dataset.map(f1, num_parallel_calls=32)
+ dataset = dataset_ops.Dataset.range(1).repeat().interleave(
+ lambda _: dataset, cycle_length=2)
+
+ dataset = dataset.map(f2, num_parallel_calls=16)
+ dataset = dataset_ops.Dataset.range(1).repeat().interleave(
+ lambda _: dataset, cycle_length=2)
+
+ dataset = dataset.map(f3, num_parallel_calls=10)
+ iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ deltas = []
+ with self.test_session() as sess:
+ for _ in range(5):
+ sess.run(get_next)
+ for _ in range(100):
+ start = time.time()
+ sess.run(get_next)
+ end = time.time()
+ deltas.append(end - start)
+
+ print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
+ (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
+ np.max(deltas)))
+
+
+if __name__ == "__main__":
+ test.main()
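(Illustrative aside, not part of the patch.) Each benchmark above follows the same warm-up-then-time loop around the contrib-internal optimization.model() transformation; condensed into one helper under the same assumptions as the file:

import time

import numpy as np
import tensorflow as tf
from tensorflow.contrib.data.python.ops import optimization


def time_per_element(dataset, warmup=5, iters=100):
  """Returns the median wall-clock time of one get_next call after warm-up."""
  iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
  get_next = iterator.get_next()
  deltas = []
  with tf.Session() as sess:
    for _ in range(warmup):
      sess.run(get_next)
    for _ in range(iters):
      start = time.time()
      sess.run(get_next)
      deltas.append(time.time() - start)
  return np.median(deltas)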
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
index 089717156c..909da5aee0 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
@@ -17,7 +17,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from absl.testing import parameterized
import numpy as np
from tensorflow.contrib.data.python.ops import optimization
@@ -29,41 +28,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
-class OptimizeDatasetTest(test.TestCase, parameterized.TestCase):
-
- def testAssertSuffix(self):
- dataset = dataset_ops.Dataset.from_tensors(0).apply(
- optimization.assert_next(["Map"])).map(lambda x: x)
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.test_session() as sess:
- self.assertEqual(0, sess.run(get_next))
-
- def testAssertSuffixInvalid(self):
- dataset = dataset_ops.Dataset.from_tensors(0).apply(
- optimization.assert_next(["Whoops"])).map(lambda x: x)
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.test_session() as sess:
- with self.assertRaisesRegexp(
- errors.InvalidArgumentError,
- "Asserted Whoops transformation at offset 0 but encountered "
- "Map transformation instead."):
- sess.run(get_next)
-
- def testAssertSuffixShort(self):
- dataset = dataset_ops.Dataset.from_tensors(0).apply(
- optimization.assert_next(["Map", "Whoops"])).map(lambda x: x)
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.test_session() as sess:
- with self.assertRaisesRegexp(
- errors.InvalidArgumentError,
- "Asserted next 2 transformations but encountered only 1."):
- sess.run(get_next)
+class OptimizeDatasetTest(test.TestCase):
def testOptimizationDefault(self):
dataset = dataset_ops.Dataset.range(10).apply(
diff --git a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py
index f6c4a984b8..c4623bca73 100644
--- a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py
@@ -80,7 +80,7 @@ class ParseExampleTest(test.TestCase):
expected_values=None,
expected_err=None):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if expected_err:
with self.assertRaisesWithPredicateMatch(expected_err[0],
expected_err[1]):
diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
index 361fe0dd39..0166ba0d44 100644
--- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
@@ -235,7 +235,7 @@ class PrefetchingKernelsOpsTest(test.TestCase):
destroy_op = resource_variable_ops.destroy_resource_op(
buffer_resource_handle, ignore_lookup_error=True)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual([b"a"], sess.run(prefetch_op))
self.assertEqual([b"b"], sess.run(prefetch_op))
self.assertEqual([b"c"], sess.run(prefetch_op))
@@ -301,7 +301,7 @@ class PrefetchToDeviceTest(test.TestCase):
self.assertEqual(dtypes.int64, next_element.dtype)
self.assertEqual([], next_element.shape)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
@@ -384,7 +384,7 @@ class PrefetchToDeviceTest(test.TestCase):
iterator = device_dataset.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
@@ -435,7 +435,7 @@ class PrefetchToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(5):
self.assertEqual(i, sess.run(next_element))
@@ -683,7 +683,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(10):
self.assertEqual(i, sess.run(next_element))
@@ -702,7 +702,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(10):
self.assertEqual(i, sess.run(next_element))
@@ -721,7 +721,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
@@ -739,7 +739,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
@@ -757,7 +757,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
@@ -775,7 +775,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
@@ -796,7 +796,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = back_to_cpu_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(10):
self.assertEqual(i, sess.run(next_element))
@@ -875,7 +875,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(5):
self.assertEqual(i, sess.run(next_element))
@@ -897,7 +897,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(5):
self.assertEqual(i, sess.run(next_element))
@@ -920,7 +920,7 @@ class CopyToDeviceTest(test.TestCase):
elem_has_value_t = next_elem.has_value()
elem_value_t = next_elem.get_value()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Before initializing the iterator, evaluating the optional fails with
# a FailedPreconditionError.
with self.assertRaises(errors.FailedPreconditionError):
diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
index 592642da0c..db8fe6aa1b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
@@ -43,7 +43,7 @@ class RangeDatasetTest(test.TestCase):
self.assertEqual([tensor_shape.TensorShape([])] * 3,
[t.shape for t in get_next[1]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
self.assertEqual((20, (b"a", 1, 37.0)), sess.run(get_next))
self.assertEqual((21, (b"b", 2, 38.0)), sess.run(get_next))
@@ -63,7 +63,7 @@ class RangeDatasetTest(test.TestCase):
.make_one_shot_iterator())
negative_get_next = negative_iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(3, sess.run(get_next))
self.assertEqual(3 + 4, sess.run(get_next))
self.assertEqual(3 + 2 * 4, sess.run(get_next))
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
index fd00cdc5c6..ed75b27a44 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
@@ -116,7 +116,7 @@ class ReadBatchFeaturesTest(
init_op = iterator.initializer
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for file_batch, _, _, _, record_batch, _ in self._next_expected_batch(
range(self._num_files), 2, 10):
diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py
index c5cfddb72b..16b1441baa 100644
--- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/resample_test.py
@@ -77,7 +77,7 @@ class ResampleTest(test.TestCase, parameterized.TestCase):
class_func=lambda c, _: c,
seed=27)).make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
returned = []
while len(returned) < 4000:
returned.append(sess.run(get_next))
@@ -115,7 +115,7 @@ class ResampleTest(test.TestCase, parameterized.TestCase):
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
returned = []
with self.assertRaises(errors.OutOfRangeError):
while True:
@@ -146,7 +146,7 @@ class ResampleTest(test.TestCase, parameterized.TestCase):
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
returned = []
with self.assertRaises(errors.OutOfRangeError):
while True:
diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
index 42cada0b97..dde678bd54 100644
--- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
@@ -50,7 +50,7 @@ class ScanDatasetTest(test.TestCase):
start, make_scan_fn(step)).take(take).make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10),
(10, 2, 10), (10, -1, 10),
@@ -100,7 +100,7 @@ class ScanDatasetTest(test.TestCase):
make_scan_fn(step)).take(take).make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10),
(10, 2, 10), (10, -1, 10),
@@ -133,7 +133,7 @@ class ScanDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(5):
(longer_vector_val, larger_rank_val), _ = sess.run(next_element)
self.assertAllEqual([0] * (2**i), longer_vector_val)
diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
index 077abd6b30..440e48db30 100644
--- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
@@ -35,7 +35,7 @@ class ShuffleAndRepeatTest(test.TestCase):
def _gen_outputs(self, ds_fn, num_outputs, verify_exhausted=True):
get_next = ds_fn().make_one_shot_iterator().get_next()
outputs = []
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(num_outputs):
outputs.append(sess.run(get_next))
if verify_exhausted:
diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
index 6b3e8e9f6e..90d18dca2a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
@@ -75,7 +75,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
self.assertEqual([[None] + list(c.shape[1:]) for c in components],
[t.shape.as_list() for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -139,7 +139,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
self.assertEqual([[None] + list(c.shape[1:]) for c in components],
[t.shape.as_list() for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -180,7 +180,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
window_stride=window_stride_t)).make_initializable_iterator())
init_op = iterator.initializer
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(
init_op,
@@ -214,7 +214,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
num_batches = (10 - 5) // 3 + 1
for i in range(num_batches):
@@ -243,7 +243,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
num_batches = (10 - 5) // 3 + 1
for i in range(num_batches):
@@ -277,7 +277,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
# Slide: 1st batch.
actual = sess.run(get_next)
@@ -316,7 +316,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
.make_initializable_iterator())
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
index 2c2cfbebff..52823d3fca 100644
--- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
@@ -30,7 +30,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSet(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string), 2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(2): # Run twice to verify statelessness of db operations.
sess.run(
init_op,
@@ -48,7 +48,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetJoinQuery(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -67,7 +67,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetNullTerminator(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -86,7 +86,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetReuseSqlDataset(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -114,7 +114,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadEmptyResultSet(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -128,7 +128,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetWithInvalidDriverName(self):
init_op = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))[0]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(
init_op,
@@ -142,7 +142,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetWithInvalidColumnName(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -157,7 +157,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetOfQueryWithSyntaxError(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -173,7 +173,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetWithMismatchBetweenColumnsAndOutputTypes(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -190,7 +190,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetOfInsertQuery(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -205,7 +205,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# place it in an `int8` tensor.
def testReadResultSetInt8(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -222,7 +222,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetInt8NegativeAndZero(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8,
dtypes.int8))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -238,7 +238,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# a SQLite database table and place it in an `int8` tensor.
def testReadResultSetInt8MaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.int8, dtypes.int8))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -256,7 +256,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# place it in an `int16` tensor.
def testReadResultSetInt16(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -273,7 +273,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetInt16NegativeAndZero(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16,
dtypes.int16))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -289,7 +289,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# a SQLite database table and place it in an `int16` tensor.
def testReadResultSetInt16MaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -307,7 +307,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# place it in an `int32` tensor.
def testReadResultSetInt32(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -321,7 +321,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# SQLite database table and place it in an `int32` tensor.
def testReadResultSetInt32NegativeAndZero(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -337,7 +337,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# a SQLite database table and place it in an `int32` tensor.
def testReadResultSetInt32MaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -355,7 +355,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# table and place it in an `int32` tensor.
def testReadResultSetInt32VarCharColumnAsInt(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -371,7 +371,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# and place it in an `int64` tensor.
def testReadResultSetInt64(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -387,7 +387,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# SQLite database table and place it in an `int64` tensor.
def testReadResultSetInt64NegativeAndZero(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -403,7 +403,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# a SQLite database table and place it in an `int64` tensor.
def testReadResultSetInt64MaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -422,7 +422,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# place it in a `uint8` tensor.
def testReadResultSetUInt8(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -438,7 +438,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# SQLite database table and place them in `uint8` tensors.
def testReadResultSetUInt8MinAndMaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -456,7 +456,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# and place it in a `uint16` tensor.
def testReadResultSetUInt16(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -472,7 +472,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# SQLite database table and place them in `uint16` tensors.
def testReadResultSetUInt16MinAndMaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -491,7 +491,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# in `bool` tensors.
def testReadResultSetBool(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -508,7 +508,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# from a SQLite database table and place it as `True` in a `bool` tensor.
def testReadResultSetBoolNotZeroOrOne(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -525,7 +525,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetFloat64(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.float64))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -544,7 +544,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetFloat64OverlyPrecise(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.float64))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -570,7 +570,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetFloat64LargestConsecutiveWholeNumbersNotEqual(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.float64))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
index 43067b4245..e25570c5ad 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
@@ -75,6 +75,31 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
sess.run(next_element)
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)
+ def testPrefetchBufferUtilization(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.range(100).map(
+ lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch(
+ -1).apply(stats_ops.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+    with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(100):
+ self.assertAllEqual(
+ np.array([i] * i, dtype=np.int64), sess.run(next_element))
+ summary_str = sess.run(summary_t)
+ self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
+ float(i + 1))
+ self._assertSummaryHasRange(summary_str, "Prefetch::buffer_utilization",
+ 0, 1)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+ summary_str = sess.run(summary_t)
+ self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
+ 100)
+
def testReinitialize(self):
stats_aggregator = stats_ops.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
index 9a13acf8f0..2f5a44408f 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
@@ -34,6 +34,16 @@ class StatsDatasetTestBase(test.TestCase):
return
self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
+ def _assertSummaryHasRange(self, summary_str, tag, min_value, max_value):
+ summary_proto = summary_pb2.Summary()
+ summary_proto.ParseFromString(summary_str)
+ for value in summary_proto.value:
+ if tag == value.tag:
+ self.assertLessEqual(min_value, value.histo.min)
+ self.assertGreaterEqual(max_value, value.histo.max)
+ return
+ self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
+
def _assertSummaryHasSum(self, summary_str, tag, expected_value):
summary_proto = summary_pb2.Summary()
summary_proto.ParseFromString(summary_str)
diff --git a/tensorflow/contrib/data/python/kernel_tests/test_utils.py b/tensorflow/contrib/data/python/kernel_tests/test_utils.py
index 1d70b16041..4c3353fe40 100644
--- a/tensorflow/contrib/data/python/kernel_tests/test_utils.py
+++ b/tensorflow/contrib/data/python/kernel_tests/test_utils.py
@@ -31,7 +31,7 @@ class DatasetTestBase(test.TestCase):
# TODO(rachelim): support sparse tensor outputs
next1 = dataset1.make_one_shot_iterator().get_next()
next2 = dataset2.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
while True:
try:
op1 = sess.run(next1)
@@ -52,9 +52,12 @@ class DatasetTestBase(test.TestCase):
dataset2,
exception_class,
replacements=None):
- next1 = dataset1.make_one_shot_iterator().get_next()
- next2 = dataset2.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+    # Define next1 and next2 on the same line so that error messages from both
+    # report an identical file:line_number location.
+    # pylint: disable=line-too-long
+ next1, next2 = dataset1.make_one_shot_iterator().get_next(), dataset2.make_one_shot_iterator().get_next()
+ # pylint: enable=line-too-long
+ with self.cached_session() as sess:
try:
sess.run(next1)
raise ValueError(
diff --git a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
index 4b08ec759d..8d335e87d5 100644
--- a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
@@ -69,7 +69,7 @@ class OverrideThreadpoolDatasetTest(test.TestCase, parameterized.TestCase):
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
thread_ids = []
try:
diff --git a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
index d79a842e7a..f994c8563f 100644
--- a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
@@ -45,7 +45,7 @@ class UniqueDatasetTest(test.TestCase):
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for test_case, expected in test_cases:
current_test_case = test_case
sess.run(iterator.initializer)
diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
index ff4d9b3260..6eaa0b1959 100644
--- a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
@@ -92,7 +92,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
dataset = self._structuredDataset(structure, shape, dtype).apply(
grouping.window_dataset(5)).flat_map(fn)
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
expected = sess.run(self._structuredElement(structure, shape, dtype))
actual = sess.run(get_next)
self._assertEqual(expected, actual)
@@ -128,7 +128,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
dataset = self._structuredDataset(structure, shape, dtype).repeat(5).apply(
grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn))
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
expected = sess.run(
self._structuredElement(structure, np.concatenate(
([5], shape), axis=0), dtype))
@@ -155,7 +155,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, {shape_t: shape})
expected = sess.run(
self._structuredElement(None, np.concatenate(([5], shape), axis=0),
@@ -235,7 +235,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
structure, shape, dtype).repeat(5).apply(
grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn))
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
expected = sess.run(
self._structuredSparseElement(structure,
np.concatenate(([5], shape), axis=0),
@@ -263,7 +263,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, {shape_t: shape})
expected = sess.run(
self._structuredSparseElement(None,
@@ -321,7 +321,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
grouping.window_dataset(len(shapes))).apply(
grouping._map_x_dataset(fn))
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape)
expected = sess.run(
self._structuredElement(
@@ -352,7 +352,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, {shapes_t: shapes})
expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape)
expected = sess.run(
@@ -380,7 +380,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
grouping._map_x_dataset(
lambda x: batching.padded_batch_window(x, padded_shape)))
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
@@ -458,7 +458,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
structure, shapes, dtype).apply(grouping.window_dataset(
len(shapes))).apply(grouping._map_x_dataset(fn))
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
expected = sess.run(
self._structuredRaggedSparseElement(structure, shapes, dtype,
padded_shape))
@@ -489,7 +489,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, {shapes_t: shapes})
expected = sess.run(
self._structuredRaggedSparseElement(None, shapes, dtypes.int32,
@@ -516,7 +516,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
grouping._map_x_dataset(
lambda x: batching.padded_batch_window(x, padded_shape)))
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
diff --git a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py
index c603ecc5ab..867ee2ba37 100644
--- a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py
@@ -61,7 +61,7 @@ class TFRecordWriterTest(test.TestCase):
return os.path.join(self.get_temp_dir(), "tf_record.out.txt")
def testWrite(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
self.writer, feed_dict={
self.filename: self._createFile(),
@@ -71,7 +71,7 @@ class TFRecordWriterTest(test.TestCase):
def testWriteZLIB(self):
options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.ZLIB)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
self.writer,
feed_dict={
@@ -84,7 +84,7 @@ class TFRecordWriterTest(test.TestCase):
def testWriteGZIP(self):
options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.GZIP)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
self.writer,
feed_dict={
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index 4b45cc7e36..a14781cd93 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -80,6 +80,7 @@ py_library(
":batching",
":gen_dataset_ops",
":interleave_ops",
+ ":optimization",
":parsing_ops",
":shuffle_ops",
"//tensorflow/python:constant_op",
diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py
index 6edc1d7990..099e10db92 100644
--- a/tensorflow/contrib/data/python/ops/grouping.py
+++ b/tensorflow/contrib/data/python/ops/grouping.py
@@ -124,7 +124,8 @@ def bucket_by_sequence_length(element_length_func,
bucket_batch_sizes,
padded_shapes=None,
padding_values=None,
- pad_to_bucket_boundary=False):
+ pad_to_bucket_boundary=False,
+ no_padding=False):
"""A transformation that buckets elements in a `Dataset` by length.
Elements of the `Dataset` are grouped together by length and then are padded
@@ -152,6 +153,8 @@ def bucket_by_sequence_length(element_length_func,
unknown size to bucket boundary minus 1 (i.e., the maximum length in each
bucket), and caller must ensure that the source `Dataset` does not contain
any elements with length longer than `max(bucket_boundaries)`.
+    no_padding: `bool`, if `True`, batches features without padding them
+      (features must be of type `tf.SparseTensor` or have the same shape).
Returns:
A `Dataset` transformation function, which can be passed to
@@ -199,7 +202,9 @@ def bucket_by_sequence_length(element_length_func,
def batching_fn(bucket_id, grouped_dataset):
"""Batch elements in dataset."""
- batch_size = batch_sizes[bucket_id]
+ batch_size = window_size_fn(bucket_id)
+ if no_padding:
+ return grouped_dataset.batch(batch_size)
none_filler = None
if pad_to_bucket_boundary:
err_msg = ("When pad_to_bucket_boundary=True, elements must have "
diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py
index fa1b851ad7..73840452df 100644
--- a/tensorflow/contrib/data/python/ops/optimization.py
+++ b/tensorflow/contrib/data/python/ops/optimization.py
@@ -24,6 +24,9 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
+# A constant that can be used to enable auto-tuning.
+AUTOTUNE = -1
+
# TODO(jsimsa): Support RE matching for both individual transformation (e.g. to
# account for indexing) and transformation sequence.
@@ -46,6 +49,21 @@ def assert_next(transformations):
return _apply_fn
+def model():
+ """A transformation that models performance.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ @{tf.data.Dataset.apply}.
+ """
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ return _ModelDataset(dataset)
+
+ return _apply_fn
+
+
def optimize(optimizations=None):
"""A transformation that applies optimizations.
@@ -97,6 +115,32 @@ class _AssertNextDataset(dataset_ops.Dataset):
return self._input_dataset.output_types
+class _ModelDataset(dataset_ops.Dataset):
+ """A `Dataset` that acts as an identity, and models performance."""
+
+ def __init__(self, input_dataset):
+ """See `optimize()` for details."""
+ super(_ModelDataset, self).__init__()
+ self._input_dataset = input_dataset
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.model_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+
class _OptimizeDataset(dataset_ops.Dataset):
"""A `Dataset` that acts as an identity, and applies optimizations."""
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index 4c466781f7..785b395707 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -25,6 +25,7 @@ import numpy as np
from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops
from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.contrib.data.python.ops import optimization
from tensorflow.contrib.data.python.ops import parsing_ops
from tensorflow.contrib.data.python.ops import shuffle_ops
from tensorflow.python.data.ops import dataset_ops
@@ -214,18 +215,17 @@ def _maybe_shuffle_and_repeat(
return dataset
-def make_tf_record_dataset(
- file_pattern,
- batch_size,
- parser_fn=None,
- num_epochs=None,
- shuffle=True,
- shuffle_buffer_size=None,
- shuffle_seed=None,
- prefetch_buffer_size=None,
- num_parallel_reads=None,
- num_parallel_parser_calls=None,
- drop_final_batch=False):
+def make_tf_record_dataset(file_pattern,
+ batch_size,
+ parser_fn=None,
+ num_epochs=None,
+ shuffle=True,
+ shuffle_buffer_size=None,
+ shuffle_seed=None,
+ prefetch_buffer_size=optimization.AUTOTUNE,
+ num_parallel_reads=None,
+ num_parallel_parser_calls=None,
+ drop_final_batch=False):
"""Reads and optionally parses TFRecord files into a dataset.
Provides common functionality such as batching, optional parsing, shuffling,
@@ -300,8 +300,6 @@ def make_tf_record_dataset(
parser_fn, batch_size, num_parallel_calls=num_parallel_parser_calls,
drop_remainder=drop_final_batch))
- if prefetch_buffer_size is None:
- prefetch_buffer_size = -1 # tf.config.data.AUTOTUNE
if prefetch_buffer_size == 0:
return dataset
else:
@@ -323,7 +321,7 @@ def make_csv_dataset(
shuffle=True,
shuffle_buffer_size=10000,
shuffle_seed=None,
- prefetch_buffer_size=1,
+ prefetch_buffer_size=optimization.AUTOTUNE,
num_parallel_reads=1,
sloppy=False,
num_rows_for_inference=100,
@@ -386,9 +384,10 @@ def make_csv_dataset(
shuffle_buffer_size: Buffer size to use for shuffling. A large buffer size
ensures better shuffling, but increases memory usage and startup time.
shuffle_seed: Randomization seed to use for shuffling.
- prefetch_buffer_size: An int specifying the number of feature batches to
- prefetch for performance improvement. Recommended value is the number of
- batches consumed per training step.
+ prefetch_buffer_size: An int specifying the number of feature
+ batches to prefetch for performance improvement. Recommended value is the
+ number of batches consumed per training step. Defaults to auto-tune.
+
num_parallel_reads: Number of threads used to read CSV records from files.
If >1, the results will be interleaved.
sloppy: If `True`, reading performance will be improved at
@@ -666,7 +665,7 @@ def make_batched_features_dataset(file_pattern,
shuffle=True,
shuffle_buffer_size=10000,
shuffle_seed=None,
- prefetch_buffer_size=1,
+ prefetch_buffer_size=optimization.AUTOTUNE,
reader_num_threads=1,
parser_num_threads=2,
sloppy_ordering=False,
@@ -739,7 +738,7 @@ def make_batched_features_dataset(file_pattern,
shuffle_seed: Randomization seed to use for shuffling.
prefetch_buffer_size: Number of feature batches to prefetch in order to
improve performance. Recommended value is the number of batches consumed
- per training step (default is 1).
+ per training step. Defaults to auto-tune.
reader_num_threads: Number of threads used to read `Example` records. If >1,
the results will be interleaved.
parser_num_threads: Number of threads to use for parsing `Example` tensors
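With the defaults changed above, callers no longer need to pass `prefetch_buffer_size` to get auto-tuned prefetching. A hedged sketch, assuming these readers are re-exported as `tf.contrib.data.*` in this release and using a hypothetical file pattern:

```python
import tensorflow as tf

# prefetch_buffer_size now defaults to optimization.AUTOTUNE (-1); pass 0 to
# disable prefetching entirely.
dataset = tf.contrib.data.make_tf_record_dataset(
    file_pattern="/tmp/records-*.tfrecord",  # hypothetical path
    batch_size=32,
    num_epochs=1)
```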
diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md
index 30e1992c01..91a27f97b7 100644
--- a/tensorflow/contrib/distribute/README.md
+++ b/tensorflow/contrib/distribute/README.md
@@ -76,7 +76,7 @@ We then compile the Keras model and pass the `MirroredStrategy` object in the
```python
model.compile(loss='mean_squared_error',
optimizer=tf.train.GradientDescentOptimizer(learning_rate=0.2),
- distribute=strategy)
+ distribute=distribution)
```
To train the model we call Keras `fit` API using the input dataset that we
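The rename above assumes the `MirroredStrategy` instance created earlier in the README is bound to the name `distribution`; a minimal sketch under that assumption, where `train_dataset` stands in for the input dataset built earlier in the guide:

```python
distribution = tf.contrib.distribute.MirroredStrategy()
model.compile(loss='mean_squared_error',
              optimizer=tf.train.GradientDescentOptimizer(learning_rate=0.2),
              distribute=distribution)
model.fit(train_dataset, epochs=5, steps_per_epoch=10)
```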
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index c524d8b394..aaecbb0eb1 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -485,7 +485,6 @@ py_library(
srcs = ["single_loss_example.py"],
deps = [
":step_fn",
- "//tensorflow/contrib/data/python/ops:batching",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:layers",
@@ -708,19 +707,32 @@ cuda_py_test(
],
)
-cuda_py_test(
- name = "keras_test",
+py_library(
+ name = "keras_test_lib",
+ testonly = 1,
srcs = ["keras_test.py"],
- additional_deps = [
- "//third_party/py/numpy",
+ deps = [
+ ":combinations",
"//tensorflow/contrib/distribute/python:mirrored_strategy",
+ "//tensorflow/contrib/distribute/python:tpu_strategy",
"//tensorflow/python:client_testlib",
"//tensorflow/python:training",
"//tensorflow/python/estimator:estimator_py",
"//tensorflow/python/keras",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+cuda_py_test(
+ name = "keras_test",
+ srcs = ["keras_test.py"],
+ additional_deps = [
+ ":keras_test_lib",
],
tags = [
"multi_and_single_gpu",
+ "no_pip",
"no_windows_gpu",
"notsan",
],
diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py
index 2301ba9233..244d1fcec8 100644
--- a/tensorflow/contrib/distribute/python/combinations.py
+++ b/tensorflow/contrib/distribute/python/combinations.py
@@ -50,10 +50,12 @@ from tensorflow.contrib.cluster_resolver import TPUClusterResolver
from tensorflow.contrib.distribute.python import mirrored_strategy as mirrored_lib
from tensorflow.contrib.distribute.python import one_device_strategy as one_device_lib
from tensorflow.contrib.distribute.python import tpu_strategy as tpu_lib
+from tensorflow.contrib.optimizer_v2 import adagrad as adagrad_v2
from tensorflow.contrib.optimizer_v2 import adam as adam_v2
from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent_v2
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
+from tensorflow.python.training import adagrad
from tensorflow.python.training import adam
from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import gradient_descent
@@ -328,6 +330,10 @@ tpu_strategy = NamedDistribution(
"TPU", lambda: tpu_lib.TPUStrategy(
TPUClusterResolver(""), steps_per_run=5),
required_tpu=True)
+tpu_strategy_one_step = NamedDistribution(
+ "TPU", lambda: tpu_lib.TPUStrategy(
+ TPUClusterResolver(""), steps_per_run=1),
+ required_tpu=True)
# Note that we disable prefetching for testing since prefetching makes
# the input non-deterministic.
mirrored_strategy_with_gpu_and_cpu = NamedDistribution(
@@ -343,17 +349,23 @@ mirrored_strategy_with_two_gpus = NamedDistribution(
adam_optimizer_v1_fn = NamedObject(
- "AdamV1", lambda: adam.AdamOptimizer(0.2, epsilon=1))
+ "AdamV1", lambda: adam.AdamOptimizer(0.001, epsilon=1))
gradient_descent_optimizer_v1_fn = NamedObject(
"GradientDescentV1", lambda: gradient_descent.GradientDescentOptimizer(0.2))
-optimizers_v1 = [adam_optimizer_v1_fn, gradient_descent_optimizer_v1_fn]
+adagrad_optimizer_v1_fn = NamedObject(
+ "AdagradV1", lambda: adagrad.AdagradOptimizer(0.001))
+optimizers_v1 = [adam_optimizer_v1_fn, gradient_descent_optimizer_v1_fn,
+ adagrad_optimizer_v1_fn]
adam_optimizer_v2_fn = NamedObject(
- "AdamV2", lambda: adam_v2.AdamOptimizer(0.2, epsilon=1))
+ "AdamV2", lambda: adam_v2.AdamOptimizer(0.001, epsilon=1))
gradient_descent_optimizer_v2_fn = NamedObject(
"GradientDescentV2",
lambda: gradient_descent_v2.GradientDescentOptimizer(0.2))
-optimizers_v2 = [adam_optimizer_v2_fn, gradient_descent_optimizer_v2_fn]
+adagrad_optimizer_v2_fn = NamedObject(
+ "AdagradV2", lambda: adagrad_v2.AdagradOptimizer(0.001))
+optimizers_v2 = [adam_optimizer_v2_fn, gradient_descent_optimizer_v2_fn,
+ adagrad_optimizer_v2_fn]
graph_and_eager_modes = ["graph", "eager"]
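A sketch of how the new Adagrad entries and `tpu_strategy_one_step` are typically consumed, mirroring the `combinations.generate`/`combinations.combine` pattern used by keras_test.py later in this change; the test body is illustrative only:

```python
from absl.testing import parameterized

from tensorflow.contrib.distribute.python import combinations
from tensorflow.python.platform import test


class OptimizerCombinationsTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          distribution=[combinations.tpu_strategy_one_step,
                        combinations.mirrored_strategy_with_gpu_and_cpu],
          optimizer_fn=combinations.optimizers_v1,
          mode=["graph"]))
  def test_optimizer_is_constructible(self, distribution, optimizer_fn):
    # Each generated case receives one named distribution and one named
    # optimizer, now including the AdagradV1 entry added above.
    self.assertIsNotNone(optimizer_fn())


if __name__ == "__main__":
  test.main()
```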
diff --git a/tensorflow/contrib/distribute/python/examples/keras_mnist.py b/tensorflow/contrib/distribute/python/examples/keras_mnist.py
index 0495134636..a84ef04196 100644
--- a/tensorflow/contrib/distribute/python/examples/keras_mnist.py
+++ b/tensorflow/contrib/distribute/python/examples/keras_mnist.py
@@ -63,7 +63,6 @@ def get_input_datasets():
# eval dataset
eval_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
eval_ds = eval_ds.repeat()
- eval_ds = eval_ds.shuffle(100)
eval_ds = eval_ds.batch(64, drop_remainder=True)
return train_ds, eval_ds, input_shape
diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py
index 3cee3e37a7..5f35e38189 100644
--- a/tensorflow/contrib/distribute/python/keras_test.py
+++ b/tensorflow/contrib/distribute/python/keras_test.py
@@ -18,9 +18,12 @@ from __future__ import division
from __future__ import print_function
import os
+from absl.testing import parameterized
import numpy as np
+from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import mirrored_strategy
+from tensorflow.contrib.distribute.python import tpu_strategy
from tensorflow.contrib.distribute.python import values
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
@@ -31,6 +34,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import distributed_training_utils
+from tensorflow.python.ops.parsing_ops import gen_parsing_ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
@@ -63,6 +67,32 @@ def simple_functional_model():
return model
+def multi_inputs_multi_outputs_model():
+ input_a = keras.layers.Input(shape=(16,), name='input_a')
+ input_b = keras.layers.Input(shape=(16,), name='input_b')
+ input_m = keras.layers.Input(shape=(8,), dtype='string', name='input_m')
+ dense = keras.layers.Dense(8, name='dense_1')
+
+ interm_a = dense(input_a)
+ # Read m
+ interm_m = keras.layers.Lambda(gen_parsing_ops.string_to_number)(input_m)
+ interm_s = keras.layers.Lambda(lambda k: k[0] * k[1])([interm_m, interm_a])
+ interm_b = dense(input_b)
+ merged = keras.layers.concatenate([interm_s, interm_b], name='merge')
+ output_c = keras.layers.Dense(3, activation='softmax', name='dense_2')(merged)
+ output_d = keras.layers.Dense(2, activation='softmax', name='dense_3')(merged)
+ model = keras.models.Model(
+ inputs=[input_a, input_b, input_m], outputs=[output_c, output_d])
+ model.compile(
+ loss='categorical_crossentropy',
+ optimizer=gradient_descent.GradientDescentOptimizer(0.001),
+ metrics={
+ 'dense_2': 'categorical_accuracy',
+ 'dense_3': 'categorical_accuracy'
+ })
+ return model
+
+
def get_ds_train_input_fn():
np.random.seed(_RANDOM_SEED)
(x_train, y_train), _ = testing_utils.get_test_data(
@@ -91,6 +121,68 @@ def get_ds_test_input_fn():
return dataset
+def get_multi_inputs_multi_outputs_data():
+ (a_train, c_train), (a_test, c_test) = testing_utils.get_test_data(
+ train_samples=_TRAIN_SIZE,
+ test_samples=50,
+ input_shape=(16,),
+ num_classes=3,
+ random_seed=_RANDOM_SEED)
+ (b_train, d_train), (b_test, d_test) = testing_utils.get_test_data(
+ train_samples=_TRAIN_SIZE,
+ test_samples=50,
+ input_shape=(16,),
+ num_classes=2,
+ random_seed=_RANDOM_SEED)
+ (m_train, _), (m_test, _) = testing_utils.get_test_data(
+ train_samples=_TRAIN_SIZE,
+ test_samples=50,
+ input_shape=(8,),
+ num_classes=2,
+ random_seed=_RANDOM_SEED)
+
+ c_train = keras.utils.to_categorical(c_train)
+ c_test = keras.utils.to_categorical(c_test)
+ d_train = keras.utils.to_categorical(d_train)
+ d_test = keras.utils.to_categorical(d_test)
+
+ train_data = {
+ 'input_a': a_train,
+ 'input_b': b_train,
+ 'input_m': m_train,
+ 'output_c': c_train,
+ 'output_d': d_train
+ }
+ test_data = {
+ 'input_a': a_test,
+ 'input_b': b_test,
+ 'input_m': m_test,
+ 'output_c': c_test,
+ 'output_d': d_test
+ }
+
+ return (train_data, test_data)
+
+
+def batch_wrapper(dataset, batch_size, distribution):
+  # TPUs currently require fully defined input shapes; drop_remainder ensures
+  # that the input will have fully defined shapes.
+ if isinstance(distribution, tpu_strategy.TPUStrategy):
+ return dataset.batch(batch_size, drop_remainder=True)
+ else:
+ return dataset.batch(batch_size)
+
+
+def all_combinations():
+ return combinations.combine(
+ distribution=[combinations.default_strategy,
+ combinations.one_device_strategy,
+ combinations.mirrored_strategy_with_gpu_and_cpu,
+ combinations.mirrored_strategy_with_two_gpus,
+ combinations.tpu_strategy_one_step],
+ mode=['graph'])
+
+
class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase):
def setUp(self):
@@ -99,6 +191,8 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase):
gfile.MakeDirs(self._base_dir)
self._config = run_config_lib.RunConfig(
tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir)
+ self._dist = mirrored_strategy.MirroredStrategy(
+ devices=['/device:GPU:0', '/device:GPU:1'])
def tearDown(self):
writer_cache.FileWriterCache.clear()
@@ -152,6 +246,53 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase):
writer_cache.FileWriterCache.clear()
gfile.DeleteRecursively(self._config.model_dir)
+ def test_multi_inputs_multi_outputs_with_input_fn_as_dict(self):
+ train_data, test_data = get_multi_inputs_multi_outputs_data()
+
+ def train_input_fn():
+ input_dict = {
+ 'input_a': train_data['input_a'],
+ 'input_b': train_data['input_b'],
+ 'input_m': train_data['input_m'].astype(np.str)
+ }
+ output_dict = {
+ 'dense_2': train_data['output_c'],
+ 'dense_3': train_data['output_d']
+ }
+ return dataset_ops.Dataset.from_tensor_slices((input_dict,
+ output_dict)).batch(16)
+
+ def eval_input_fn():
+ input_dict = {
+ 'input_a': test_data['input_a'],
+ 'input_b': test_data['input_b'],
+ 'input_m': test_data['input_m'].astype(np.str)
+ }
+ output_dict = {
+ 'dense_2': test_data['output_c'],
+ 'dense_3': test_data['output_d']
+ }
+ return dataset_ops.Dataset.from_tensor_slices((input_dict,
+ output_dict)).batch(16)
+
+ self.do_test_multi_inputs_multi_outputs_with_input_fn(
+ train_input_fn, eval_input_fn)
+
+ def do_test_multi_inputs_multi_outputs_with_input_fn(self, train_input_fn,
+ eval_input_fn):
+ config = run_config_lib.RunConfig(
+ tf_random_seed=_RANDOM_SEED,
+ model_dir=self._base_dir,
+ train_distribute=self._dist)
+ with self.cached_session():
+ model = multi_inputs_multi_outputs_model()
+ est_keras = keras_lib.model_to_estimator(keras_model=model, config=config)
+ baseline_eval_results = est_keras.evaluate(
+ input_fn=eval_input_fn, steps=1)
+ est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
+ eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)
+ self.assertLess(eval_results['loss'], baseline_eval_results['loss'])
+
def test_keras_optimizer_with_distribution_strategy(self):
dist = mirrored_strategy.MirroredStrategy(
devices=['/device:GPU:0', '/device:GPU:1'])
@@ -175,7 +316,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase):
gfile.DeleteRecursively(self._config.model_dir)
-class TestWithDistributionStrategy(test.TestCase):
+class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
def test_validating_dataset_input_tensors_with_shape_mismatch(self):
with self.cached_session():
@@ -215,7 +356,7 @@ class TestWithDistributionStrategy(test.TestCase):
distributed_training_utils.validate_distributed_dataset_inputs(
strategy, x, y)
- def test_calling_model_on_same_dataset(self):
+ def test_calling_model_with_numpy_arrays(self):
with self.cached_session():
x = keras.layers.Input(shape=(3,), name='input')
y = keras.layers.Dense(4, name='dense')(x)
@@ -228,11 +369,44 @@ class TestWithDistributionStrategy(test.TestCase):
'/device:GPU:0'])
model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
+ inputs = np.zeros((64, 3), dtype=np.float32)
+ targets = np.zeros((64, 4), dtype=np.float32)
+
+ # Call fit with validation data
+ model.fit(inputs, targets, epochs=1, batch_size=2, verbose=0,
+ validation_data=(inputs, targets))
+
+    # TODO(anjalisridhar): We need tests for when the batch size and steps are
+    # small enough to result in a batch_size and steps value of 0.
+ model.evaluate(inputs, targets)
+ # with steps
+ model.evaluate(inputs, targets, steps=2)
+ # with batch_size
+ model.evaluate(inputs, targets, batch_size=8)
+
+ model.predict(inputs)
+ # with steps
+ model.predict(inputs, steps=2)
+ # with batch_size
+ model.predict(inputs, batch_size=8)
+
+ @combinations.generate(all_combinations())
+ def test_calling_model_on_same_dataset(self, distribution):
+ with self.cached_session():
+ x = keras.layers.Input(shape=(3,), name='input')
+ y = keras.layers.Dense(4, name='dense')(x)
+ model = keras.Model(x, y)
+
+ optimizer = gradient_descent.GradientDescentOptimizer(0.001)
+ loss = 'mse'
+ metrics = ['mae']
+ model.compile(optimizer, loss, metrics=metrics, distribute=distribution)
+
inputs = np.zeros((10, 3), dtype=np.float32)
targets = np.zeros((10, 4), dtype=np.float32)
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
- dataset = dataset.batch(10)
+ dataset = batch_wrapper(dataset, 10, distribution)
# Call fit with validation data
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
@@ -241,6 +415,9 @@ class TestWithDistributionStrategy(test.TestCase):
validation_data=dataset, validation_steps=2)
model.predict(dataset, steps=2)
+    # TODO(priyag): Enable this test for TPU. Currently tuples/dicts don't work
+    # because clone_model's input_tensors argument only seems to accept lists,
+    # not tuples or dicts.
def test_fit_with_tuple_and_dict_dataset_inputs(self):
with self.cached_session():
a = keras.layers.Input(shape=(3,), name='input_a')
@@ -282,7 +459,8 @@ class TestWithDistributionStrategy(test.TestCase):
model.fit(dataset_dict, epochs=1, steps_per_epoch=2, verbose=1)
- def test_fit_eval_and_predict_methods_on_dataset(self):
+ @combinations.generate(all_combinations())
+ def test_fit_eval_and_predict_methods_on_dataset(self, distribution):
with self.cached_session():
x = keras.layers.Input(shape=(3,), name='input')
y = keras.layers.Dense(4, name='dense')(x)
@@ -291,16 +469,13 @@ class TestWithDistributionStrategy(test.TestCase):
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
loss = 'mse'
metrics = ['mae']
- strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0',
- '/device:CPU:0'])
-
- model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
+ model.compile(optimizer, loss, metrics=metrics, distribute=distribution)
inputs = np.zeros((10, 3), dtype=np.float32)
targets = np.zeros((10, 4), dtype=np.float32)
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
- dataset = dataset.batch(10)
+ dataset = batch_wrapper(dataset, 10, distribution)
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
model.evaluate(dataset, steps=2, verbose=1)
@@ -496,6 +671,8 @@ class TestWithDistributionStrategy(test.TestCase):
class LossMaskingWithDistributionStrategyTest(test.TestCase):
+ # TODO(priyag): Enable all strategies for this test. Currently it does not
+ # work for TPU due to some invalid datatype.
def test_masking(self):
with self.cached_session():
np.random.seed(1337)
@@ -519,24 +696,25 @@ class LossMaskingWithDistributionStrategyTest(test.TestCase):
self.assertEqual(hist.history['loss'][0], 0)
-class NormalizationLayerWithDistributionStrategyTest(test.TestCase):
+class NormalizationLayerWithDistributionStrategyTest(
+ test.TestCase, parameterized.TestCase):
- def test_batchnorm_correctness(self):
+ @combinations.generate(all_combinations())
+ def test_batchnorm_correctness(self, distribution):
with self.cached_session():
model = keras.models.Sequential()
norm = keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8)
model.add(norm)
- strategy = mirrored_strategy.MirroredStrategy(['/device:CPU:0',
- '/device:GPU:0'])
model.compile(loss='mse',
optimizer=gradient_descent.GradientDescentOptimizer(0.01),
- distribute=strategy)
+ distribute=distribution)
# centered on 5.0, variance 10.0
x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10))
+ x = x.astype('float32')
dataset = dataset_ops.Dataset.from_tensor_slices((x, x))
dataset = dataset.repeat(100)
- dataset = dataset.batch(32)
+ dataset = batch_wrapper(dataset, 32, distribution)
model.fit(dataset, epochs=4, verbose=0, steps_per_epoch=10)
out = model.predict(dataset, steps=2)
@@ -546,9 +724,11 @@ class NormalizationLayerWithDistributionStrategyTest(test.TestCase):
np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)
-class CorrectnessWithDistributionStrategyTest(test.TestCase):
+class CorrectnessWithDistributionStrategyTest(test.TestCase,
+ parameterized.TestCase):
- def test_correctness(self):
+ @combinations.generate(all_combinations())
+ def test_correctness(self, distribution):
with self.cached_session():
keras.backend.set_image_data_format('channels_last')
num_samples = 10000
@@ -557,43 +737,43 @@ class CorrectnessWithDistributionStrategyTest(test.TestCase):
x_train = x_train.astype('float32')
y_train = y_train.astype('float32')
- model = keras.Sequential()
- model.add(keras.layers.Dense(1, input_shape=(1,)))
-
- # With DistributionStrategy
- dataset_with = dataset_ops.Dataset.from_tensor_slices((x_train, y_train))
- dataset_with = dataset_with.batch(32)
- strategy = mirrored_strategy.MirroredStrategy(devices=['/device:CPU:0',
- '/device:GPU:0'])
-
- model.compile(loss=keras.losses.mean_squared_error,
- optimizer=gradient_descent.GradientDescentOptimizer(0.5),
- distribute=strategy)
- model.fit(x=dataset_with, epochs=1, steps_per_epoch=310)
- wts_with_ds = model.get_weights()
-
- x_predict = [[1], [2], [3], [4]]
- predict_dataset_with = dataset_ops.Dataset.from_tensor_slices((x_predict,
- x_predict))
- predict_dataset_with = predict_dataset_with.batch(2)
- predict_with_ds = model.predict(predict_dataset_with, steps=1)
- predict_with_ds = np.reshape(predict_with_ds, (4, 1))
-
- # Without DistributionStrategy
- dataset_without = dataset_ops.Dataset.from_tensor_slices((x_train,
+ def fit_and_predict(with_distribution=None):
+ model = keras.Sequential()
+ model.add(keras.layers.Dense(1, input_shape=(1,)))
+ model.compile(
+ loss=keras.losses.mean_squared_error,
+ optimizer=gradient_descent.GradientDescentOptimizer(0.5),
+ distribute=with_distribution)
+
+ batch_size = 64
+ if with_distribution:
+ batch_size //= with_distribution.num_towers
+ train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train,
y_train))
- dataset_without = dataset_without.batch(64)
-
- model.compile(loss=keras.losses.mean_squared_error,
- optimizer=gradient_descent.GradientDescentOptimizer(0.5))
- model.fit(x=dataset_without, epochs=1, steps_per_epoch=310)
- wts_without_ds = model.get_weights()
-
- x_predict = [[1], [2], [3], [4]]
- predict_dataset_without = dataset_ops.Dataset.from_tensor_slices((
- x_predict, x_predict))
- predict_dataset_without = predict_dataset_without.batch(4)
- predict_without_ds = model.predict(predict_dataset_without, steps=1)
+ train_dataset = batch_wrapper(train_dataset, batch_size, distribution)
+ # Running only 100 steps instead of the full dataset to keep test
+ # duration small.
+ model.fit(x=train_dataset, epochs=1, steps_per_epoch=100)
+
+ weights = model.get_weights()
+
+ x_predict = [[1.], [2.], [3.], [4.]]
+ predict_batch_size = 4
+ if with_distribution:
+ predict_batch_size //= with_distribution.num_towers
+ predict_dataset = dataset_ops.Dataset.from_tensor_slices((x_predict,
+ x_predict))
+ predict_dataset = batch_wrapper(predict_dataset,
+ predict_batch_size, distribution)
+ predict_result = model.predict(predict_dataset, steps=1)
+ predict_result = np.reshape(predict_result, (4, 1))
+
+ return weights, predict_result
+
+ wts_with_ds, predict_with_ds = fit_and_predict(
+ with_distribution=distribution)
+ wts_without_ds, predict_without_ds = fit_and_predict(
+ with_distribution=None)
# Verify that the weights are the same within some limits of tolerance.
np.testing.assert_allclose(wts_with_ds[0], wts_without_ds[0], rtol=1e-3)
@@ -602,5 +782,8 @@ class CorrectnessWithDistributionStrategyTest(test.TestCase):
np.testing.assert_allclose(predict_with_ds, predict_without_ds, rtol=1e-3)
+# TODO(priyag): Add a test for TPUStrategy with steps_per_run > 1.
+
+
if __name__ == '__main__':
test.main()
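A minimal standalone sketch of the per-tower batch arithmetic the refactored test above relies on: the global batch size is divided by the number of towers so that the distributed run and the single-device baseline consume the same number of examples per step. `num_towers` here stands in for the strategy's `num_towers` attribute used in the test.

def per_tower_batch_size(global_batch_size, num_towers):
  # Each tower processes an equal slice of the global batch.
  if global_batch_size % num_towers != 0:
    raise ValueError('global batch size must be divisible by num_towers')
  return global_batch_size // num_towers

assert per_tower_batch_size(64, 2) == 32  # mirrors `batch_size //= with_distribution.num_towers`
assert per_tower_batch_size(4, 2) == 2    # mirrors `predict_batch_size //= ...`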
diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py
index bdac4fb58c..ba147e7824 100644
--- a/tensorflow/contrib/distribute/python/minimize_loss_test.py
+++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py
@@ -183,6 +183,10 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
"dense/kernel", "dense/bias", "beta1_power", "beta2_power",
"dense/kernel/Adam", "dense/kernel/Adam_1", "dense/bias/Adam",
"dense/bias/Adam_1"
+ ],
+ "Adagrad": [
+ "dense/kernel/Adagrad", "dense/kernel",
+ "dense/bias/Adagrad", "dense/bias"
]
}
variables = variables_map[optimizer_fn().get_name()]
diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py
index bb10b546a1..16799104e8 100644
--- a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py
+++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py
@@ -55,14 +55,14 @@ class PrefetchingOpsV2Test(test.TestCase):
next_element = iterator.get_next()
output = []
+ # TODO(rohanj): Modify test to go till the end of the dataset when we
+ # switch to MultiDeviceIterator.
with self.cached_session() as sess:
- for _ in range(5):
+ for _ in range(4):
result = sess.run(next_element)
self.assertEqual(2, len(result))
output.extend(result)
- self.assertEquals(set(range(10)), set(output))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.assertEquals(set(range(8)), set(output))
def testPrefetchToTwoDevicesWithReinit(self):
if not test_util.is_gpu_available():
@@ -75,14 +75,14 @@ class PrefetchingOpsV2Test(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
+ # TODO(rohanj): Modify test to go till the end of the dataset when we
+ # switch to MultiDeviceIterator.
with self.cached_session() as sess:
sess.run(iterator.initializer)
- for _ in range(5):
- sess.run(next_element)
- with self.assertRaises(errors.OutOfRangeError):
+ for _ in range(4):
sess.run(next_element)
sess.run(iterator.initializer)
- for _ in range(5):
+ for _ in range(4):
sess.run(next_element)
diff --git a/tensorflow/contrib/distribute/python/single_loss_example.py b/tensorflow/contrib/distribute/python/single_loss_example.py
index 5aa19cf6a9..09b351ffa4 100644
--- a/tensorflow/contrib/distribute/python/single_loss_example.py
+++ b/tensorflow/contrib/distribute/python/single_loss_example.py
@@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.distribute.python import step_fn
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
@@ -59,10 +58,9 @@ def minimize_loss_example(optimizer_fn,
def dataset_fn():
dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat()
- # TODO(isaprykin): map_and_batch with drop_remainder causes shapes to be
+ # TODO(isaprykin): batch with drop_remainder causes shapes to be
# fully defined for TPU. Remove this when XLA supports dynamic shapes.
- return dataset.apply(
- batching.map_and_batch(lambda x: x, batch_size=1, drop_remainder=True))
+ return dataset.batch(1, drop_remainder=True)
# An Optimizer instance is created either outside or inside model_fn.
outer_optimizer = None
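A minimal sketch, assuming TF 1.x tf.data, of why `batch(1, drop_remainder=True)` is an adequate replacement for the removed `map_and_batch` call: with `drop_remainder=True` the batch dimension is statically known, which keeps shapes fully defined for TPU/XLA as the TODO above requires.

import tensorflow as tf

dataset = tf.data.Dataset.from_tensors([[1.]]).repeat()
batched = dataset.batch(1, drop_remainder=True)
# The batch dimension is 1 (known) rather than None, so the shape is fully defined.
print(batched.output_shapes)  # (1, 1, 1)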
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index 97c53ae2b9..9aadc634da 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -166,6 +166,7 @@ cuda_py_test(
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform_test",
],
+ tags = ["notap"],
)
cuda_py_test(
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py
index a7bd51430e..1e36b7ff9b 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py
@@ -20,8 +20,8 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib import linalg
from tensorflow.contrib.distributions.python.ops.bijectors.affine_linear_operator import AffineLinearOperator
+from tensorflow.python.ops.linalg import linalg
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py
index 196cc41335..13370497ce 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py
@@ -22,7 +22,6 @@ import numpy as np
from scipy import stats
from tensorflow.contrib import distributions
-from tensorflow.contrib import linalg
from tensorflow.contrib.distributions.python.ops import bijectors
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -30,6 +29,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.linalg import linalg
from tensorflow.python.platform import test
bs = bijectors
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py
index 25f29452c3..ba31697c58 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py
@@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib import linalg
from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape
from tensorflow.python.framework import dtypes
@@ -29,6 +28,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import bijector
+from tensorflow.python.ops.linalg import linalg
from tensorflow.python.util import deprecation
diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py
index 6959b3e877..b4ad33cf6d 100644
--- a/tensorflow/contrib/distributions/python/ops/distribution_util.py
+++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py
@@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib import linalg
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond
@@ -27,6 +26,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.linalg import linalg
from tensorflow.python.ops.distributions import distribution as distribution_lib
# The following two lines are redundant, in a sense. The first enables
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
index d8401801f2..74d9d04fc7 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
@@ -18,10 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib import linalg
from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop
from tensorflow.python.framework import ops
+from tensorflow.python.ops.linalg import linalg
from tensorflow.python.util import deprecation
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_tril.py b/tensorflow/contrib/distributions/python/ops/mvn_tril.py
index d9110947ec..c6a23e4336 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_tril.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_tril.py
@@ -18,10 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib import linalg
from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop
from tensorflow.python.framework import ops
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.ops.linalg import linalg
from tensorflow.python.util import deprecation
diff --git a/tensorflow/contrib/distributions/python/ops/wishart.py b/tensorflow/contrib/distributions/python/ops/wishart.py
index f1accaaa4c..49b9de0ab5 100644
--- a/tensorflow/contrib/distributions/python/ops/wishart.py
+++ b/tensorflow/contrib/distributions/python/ops/wishart.py
@@ -21,7 +21,6 @@ from __future__ import print_function
import math
import numpy as np
-from tensorflow.contrib import linalg
from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util
from tensorflow.python.framework import constant_op
@@ -36,6 +35,7 @@ from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
+from tensorflow.python.ops.linalg import linalg
from tensorflow.python.util import deprecation
__all__ = [
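A minimal sketch of the import migration applied across the files above: the `tensorflow.python.ops.linalg.linalg` module exposes the same LinearOperator classes previously reached through the `tensorflow.contrib.linalg` alias, so call sites such as `linalg.LinearOperatorDiag` keep working unchanged.

from tensorflow.python.ops.linalg import linalg

scale = linalg.LinearOperatorDiag(diag=[2., 3.])
dense = scale.to_dense()  # a 2x2 Tensor, [[2., 0.], [0., 3.]]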
diff --git a/tensorflow/contrib/eager/python/evaluator_test.py b/tensorflow/contrib/eager/python/evaluator_test.py
index 7d2274db9b..48d093e075 100644
--- a/tensorflow/contrib/eager/python/evaluator_test.py
+++ b/tensorflow/contrib/eager/python/evaluator_test.py
@@ -117,7 +117,7 @@ class EvaluatorTest(test.TestCase):
self.assertEqual(6.0, results["mean"].numpy())
def testDatasetGraph(self):
- with context.graph_mode(), ops.Graph().as_default(), self.test_session():
+ with context.graph_mode(), ops.Graph().as_default(), self.cached_session():
e = SimpleEvaluator(IdentityModel())
ds = dataset_ops.Dataset.from_tensor_slices([3.0, 5.0, 7.0, 9.0])
init_op, call_op, results_op = e.evaluate_on_dataset(ds)
@@ -126,7 +126,7 @@ class EvaluatorTest(test.TestCase):
self.assertEqual(6.0, results["mean"])
def testWriteSummariesGraph(self):
- with context.graph_mode(), ops.Graph().as_default(), self.test_session():
+ with context.graph_mode(), ops.Graph().as_default(), self.cached_session():
e = SimpleEvaluator(IdentityModel())
ds = dataset_ops.Dataset.from_tensor_slices([3.0, 5.0, 7.0, 9.0])
training_util.get_or_create_global_step()
diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb
index 529c99b37c..3acecd283c 100644
--- a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb
+++ b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb
@@ -1056,7 +1056,7 @@
"\n",
" attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()\n",
"\n",
- " predicted_id = tf.multinomial(predictions, num_samples=1)[0][0].numpy()\n",
+ " predicted_id = tf.argmax(predictions[0]).numpy()\n",
" result.append(index_word[predicted_id])\n",
"\n",
" if index_word[predicted_id] == '<end>':\n",
diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb
index 40bc098724..e0d5e494d4 100644
--- a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb
+++ b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb
@@ -610,7 +610,7 @@
"\n",
" # using a multinomial distribution to predict the word returned by the model\n",
" predictions = predictions / temperature\n",
- " predicted_id = tf.multinomial(predictions, num_samples=1)[0][0].numpy()\n",
+ " predicted_id = tf.argmax(predictions[0]).numpy()\n",
" \n",
" # We pass the predicted word as the next input to the model\n",
" # along with the previous hidden state\n",
diff --git a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
index f1e1f99c57..560fc8c5a2 100644
--- a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
+++ b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
@@ -677,7 +677,7 @@
" attention_weights = tf.reshape(attention_weights, (-1, ))\n",
" attention_plot[t] = attention_weights.numpy()\n",
"\n",
- " predicted_id = tf.multinomial(predictions, num_samples=1)[0][0].numpy()\n",
+ " predicted_id = tf.argmax(predictions[0]).numpy()\n",
"\n",
" result += targ_lang.idx2word[predicted_id] + ' '\n",
"\n",
diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/README.md b/tensorflow/contrib/eager/python/examples/rnn_colorbot/README.md
index fabd7b3e20..750bbc66f3 100644
--- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/README.md
+++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/README.md
@@ -23,4 +23,4 @@ Attribution-ShareAlike License and is available at
https://en.wikipedia.org/wiki/List_of_colors:_N-Z
This example was adapted from
- https://github.com/random-forests/tensorflow-workshop/tree/master/extras/colorbot
+ https://github.com/random-forests/tensorflow-workshop/tree/master/archive/extras/colorbot
diff --git a/tensorflow/contrib/eager/python/examples/scan/BUILD b/tensorflow/contrib/eager/python/examples/scan/BUILD
deleted file mode 100644
index 638c57d1c9..0000000000
--- a/tensorflow/contrib/eager/python/examples/scan/BUILD
+++ /dev/null
@@ -1,25 +0,0 @@
-licenses(["notice"]) # Apache 2.0
-
-package(default_visibility = ["//tensorflow:internal"])
-
-load("//tensorflow:tensorflow.bzl", "cuda_py_test")
-
-cuda_py_test(
- name = "scan_test",
- size = "small",
- srcs = ["scan_test.py"],
- additional_deps = [
- "//third_party/py/numpy",
- "//tensorflow:tensorflow_py",
- ],
-)
-
-cuda_py_test(
- name = "scan_graph_test",
- size = "small",
- srcs = ["scan_graph_test.py"],
- additional_deps = [
- "//third_party/py/numpy",
- "//tensorflow:tensorflow_py",
- ],
-)
diff --git a/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py b/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py
deleted file mode 100644
index d4b8c8941e..0000000000
--- a/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py
+++ /dev/null
@@ -1,54 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Unit test for tf.scan under graph mode execution."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import time
-
-import numpy as np
-import tensorflow as tf
-
-
-class ScanBenchmark(tf.test.Benchmark):
-
- def runScan(self, n):
- elems = np.arange(n)
- start_time = time.time()
- sum_op = tf.scan(lambda a, x: a + x, elems, parallel_iterations=1)
- with tf.Session() as sess:
- sess.run(sum_op)
- wall_time = time.time() - start_time
-
- self.report_benchmark(
- name='scan',
- iters=n,
- wall_time=wall_time)
-
- def benchmarkScan16000(self):
- self.runScan(16000)
-
- def benchmarkScan32000(self):
- self.runScan(32000)
-
- def benchmarkScan64000(self):
- self.runScan(64000)
-
- def benchmarkScan128000(self):
- self.runScan(128000)
-
-if __name__ == '__main__':
- tf.test.main()
diff --git a/tensorflow/contrib/eager/python/examples/scan/scan_test.py b/tensorflow/contrib/eager/python/examples/scan/scan_test.py
deleted file mode 100644
index a02fc24c79..0000000000
--- a/tensorflow/contrib/eager/python/examples/scan/scan_test.py
+++ /dev/null
@@ -1,54 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Unit test for tf.scan under eager execution."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import time
-
-import numpy as np
-import tensorflow as tf
-
-
-class ScanBenchmark(tf.test.Benchmark):
-
- def runScan(self, n):
- elems = np.arange(n)
- start_time = time.time()
- _ = tf.scan(lambda a, x: a + x, elems, parallel_iterations=1)
- wall_time = time.time() - start_time
-
- self.report_benchmark(
- name='scan',
- iters=n,
- wall_time=wall_time)
-
- def benchmarkScan16000(self):
- self.runScan(16000)
-
- def benchmarkScan32000(self):
- self.runScan(32000)
-
- def benchmarkScan64000(self):
- self.runScan(64000)
-
- def benchmarkScan128000(self):
- self.runScan(128000)
-
-
-if __name__ == '__main__':
- tf.enable_eager_execution()
- tf.test.main()
diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py
index dcc7b71d79..9d2d172752 100644
--- a/tensorflow/contrib/eager/python/metrics_test.py
+++ b/tensorflow/contrib/eager/python/metrics_test.py
@@ -216,7 +216,7 @@ class MetricsTest(test.TestCase):
self.assertEqual(m1.numer.name, "has_space/numer:0")
def testGraphWithPlaceholder(self):
- with context.graph_mode(), self.test_session() as sess:
+ with context.graph_mode(), self.cached_session() as sess:
m = metrics.Mean()
p = array_ops.placeholder(dtypes.float32)
accumulate = m(p)
@@ -309,7 +309,7 @@ class MetricsTest(test.TestCase):
self.assertTrue(old_numer is m.numer)
def testMetricsChain(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
m1 = metrics.Mean()
m2 = metrics.Mean(name="m2")
update_m2 = m2(3.0)
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index 437b3d965d..6db311d52d 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -18,6 +18,7 @@ py_library(
":boosted_trees",
":dnn",
":dnn_linear_combined",
+ ":dnn_with_layer_annotations",
":early_stopping",
":export",
":exporter",
@@ -127,6 +128,61 @@ py_test(
)
py_library(
+ name = "dnn_with_layer_annotations",
+ srcs = ["python/estimator/dnn_with_layer_annotations.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:layers",
+ "//tensorflow/python:nn",
+ "//tensorflow/python:partitioned_variables",
+ "//tensorflow/python:summary",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/estimator",
+ "//tensorflow/python/estimator:head",
+ "//tensorflow/python/estimator:model_fn",
+ "//tensorflow/python/estimator:optimizers",
+ "//tensorflow/python/feature_column",
+ "//tensorflow/python/ops/losses",
+ "//tensorflow/python/saved_model:utils",
+ ],
+)
+
+py_test(
+ name = "dnn_with_layer_annotations_test",
+ size = "medium",
+ srcs = ["python/estimator/dnn_with_layer_annotations_test.py"],
+ shard_count = 4,
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_pip",
+ "notsan", # b/67510291
+ ],
+ deps = [
+ ":dnn_with_layer_annotations",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:data_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:summary",
+ "//tensorflow/python:training",
+ "//tensorflow/python/estimator:dnn",
+ "//tensorflow/python/estimator:dnn_testing_utils",
+ "//tensorflow/python/estimator:export_export",
+ "//tensorflow/python/estimator:numpy_io",
+ "//tensorflow/python/estimator:pandas_io",
+ "//tensorflow/python/estimator:prediction_keys",
+ "//tensorflow/python/feature_column",
+ "@six_archive//:six",
+ ],
+)
+
+py_library(
name = "dnn_linear_combined",
srcs = ["python/estimator/dnn_linear_combined.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py
index 258860f263..78914ecaca 100644
--- a/tensorflow/contrib/estimator/__init__.py
+++ b/tensorflow/contrib/estimator/__init__.py
@@ -22,6 +22,7 @@ from __future__ import print_function
from tensorflow.contrib.estimator.python.estimator.baseline import *
from tensorflow.contrib.estimator.python.estimator.boosted_trees import *
from tensorflow.contrib.estimator.python.estimator.dnn import *
+from tensorflow.contrib.estimator.python.estimator.dnn_with_layer_annotations import *
from tensorflow.contrib.estimator.python.estimator.dnn_linear_combined import *
from tensorflow.contrib.estimator.python.estimator.early_stopping import *
from tensorflow.contrib.estimator.python.estimator.export import *
@@ -76,6 +77,8 @@ _allowed_symbols = [
'build_raw_supervised_input_receiver_fn',
'build_supervised_input_receiver_fn_from_input_fn',
-    'SavedModelEstimator'
+    'SavedModelEstimator',
+    'DNNClassifierWithLayerAnnotations',
+ 'DNNRegressorWithLayerAnnotations',
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py
new file mode 100644
index 0000000000..152431d1b2
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py
@@ -0,0 +1,434 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Deep Neural Network estimators with layer annotations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import pickle
+
+from google.protobuf.any_pb2 import Any
+
+from tensorflow.python.estimator import estimator
+from tensorflow.python.estimator import model_fn
+from tensorflow.python.estimator.canned import dnn
+from tensorflow.python.feature_column import feature_column as feature_column_lib
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops.losses import losses
+from tensorflow.python.saved_model import utils as saved_model_utils
+
+
+class LayerAnnotationsCollectionNames(object):
+ """Names for the collections containing the annotations."""
+
+ UNPROCESSED_FEATURES = 'layer_annotations/unprocessed_features'
+  PROCESSED_FEATURES = 'layer_annotations/processed_features'
+ FEATURE_COLUMNS = 'layer_annotations/feature_columns'
+
+ @classmethod
+ def keys(cls, collection_name):
+ return '%s/keys' % collection_name
+
+ @classmethod
+ def values(cls, collection_name):
+ return '%s/values' % collection_name
+
+
+def serialize_feature_column(feature_column):
+ if isinstance(feature_column, feature_column_lib._EmbeddingColumn): # pylint: disable=protected-access
+ # We can't pickle nested functions, and we don't need the value of
+ # layer_creator in most cases anyway, so just discard its value.
+ args = feature_column._asdict()
+ args['layer_creator'] = None
+ temp = type(feature_column)(**args)
+ return pickle.dumps(temp)
+ return pickle.dumps(feature_column)
+
+
+def _to_any_wrapped_tensor_info(tensor):
+ """Converts a `Tensor` to a `TensorInfo` wrapped in a proto `Any`."""
+ any_buf = Any()
+ tensor_info = saved_model_utils.build_tensor_info(tensor)
+ any_buf.Pack(tensor_info)
+ return any_buf
+
+
+def make_input_layer_with_layer_annotations(original_input_layer, mode):
+ """Make an input_layer replacement function that adds layer annotations."""
+
+ def input_layer_with_layer_annotations(features,
+ feature_columns,
+ weight_collections=None,
+ trainable=True,
+ cols_to_vars=None,
+ cols_to_output_tensors=None):
+ """Returns a dense `Tensor` as input layer based on given `feature_columns`.
+
+    Generally a single example in training data is described with
+    FeatureColumns. At the first layer of the model, this column oriented
+    data should be converted to a single `Tensor`.
+
+ This is like tf.feature_column.input_layer, except with added
+ Integrated-Gradient annotations.
+
+ Args:
+ features: A mapping from key to tensors. `_FeatureColumn`s look up via
+ these keys. For example `numeric_column('price')` will look at 'price'
+ key in this dict. Values can be a `SparseTensor` or a `Tensor` depends
+ on corresponding `_FeatureColumn`.
+ feature_columns: An iterable containing the FeatureColumns to use as
+ inputs to your model. All items should be instances of classes derived
+ from `_DenseColumn` such as `numeric_column`, `embedding_column`,
+ `bucketized_column`, `indicator_column`. If you have categorical
+ features, you can wrap them with an `embedding_column` or
+ `indicator_column`.
+ weight_collections: A list of collection names to which the Variable will
+ be added. Note that variables will also be added to collections
+ `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`.
+ trainable: If `True` also add the variable to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ cols_to_vars: If not `None`, must be a dictionary that will be filled with
+ a mapping from `_FeatureColumn` to list of `Variable`s. For example,
+ after the call, we might have cols_to_vars = {_EmbeddingColumn(
+ categorical_column=_HashedCategoricalColumn( key='sparse_feature',
+ hash_bucket_size=5, dtype=tf.string), dimension=10): [<tf.Variable
+ 'some_variable:0' shape=(5, 10), <tf.Variable 'some_variable:1'
+ shape=(5, 10)]} If a column creates no variables, its value will be an
+ empty list.
+ cols_to_output_tensors: If not `None`, must be a dictionary that will be
+ filled with a mapping from '_FeatureColumn' to the associated output
+ `Tensor`s.
+
+ Returns:
+ A `Tensor` which represents input layer of a model. Its shape
+ is (batch_size, first_layer_dimension) and its dtype is `float32`.
+ first_layer_dimension is determined based on given `feature_columns`.
+
+ Raises:
+ ValueError: features and feature_columns have different lengths.
+ """
+
+ local_cols_to_output_tensors = {}
+ input_layer = original_input_layer(
+ features=features,
+ feature_columns=feature_columns,
+ weight_collections=weight_collections,
+ trainable=trainable,
+ cols_to_vars=cols_to_vars,
+ cols_to_output_tensors=local_cols_to_output_tensors)
+
+ if cols_to_output_tensors is not None:
+ cols_to_output_tensors = local_cols_to_output_tensors
+
+ if mode and mode == model_fn.ModeKeys.PREDICT:
+ # Only annotate in PREDICT mode.
+
+ # Annotate features.
+ # These are the parsed Tensors, before embedding.
+
+ # Only annotate features used by FeatureColumns.
+ # We figure which ones are used by FeatureColumns by creating a parsing
+ # spec and looking at the keys.
+ spec = feature_column_lib.make_parse_example_spec(feature_columns)
+ for key in spec.keys():
+ tensor = features[key]
+ ops.add_to_collection(
+ LayerAnnotationsCollectionNames.keys(
+ LayerAnnotationsCollectionNames.UNPROCESSED_FEATURES), key)
+ ops.add_to_collection(
+ LayerAnnotationsCollectionNames.values(
+ LayerAnnotationsCollectionNames.UNPROCESSED_FEATURES),
+ _to_any_wrapped_tensor_info(tensor))
+
+ # Annotate feature columns.
+ for column in feature_columns:
+ # TODO(cyfoo): Find a better way to serialize and deserialize
+ # _FeatureColumn.
+ ops.add_to_collection(LayerAnnotationsCollectionNames.FEATURE_COLUMNS,
+ serialize_feature_column(column))
+
+ for column, tensor in local_cols_to_output_tensors.items():
+ ops.add_to_collection(
+ LayerAnnotationsCollectionNames.keys(
+ LayerAnnotationsCollectionNames.PROCESSED_FEATURES),
+ column.name)
+ ops.add_to_collection(
+ LayerAnnotationsCollectionNames.values(
+ LayerAnnotationsCollectionNames.PROCESSED_FEATURES),
+ _to_any_wrapped_tensor_info(tensor))
+
+ return input_layer
+
+ return input_layer_with_layer_annotations
+
+
+@contextlib.contextmanager
+def _monkey_patch(module, function, replacement):
+  """Temporarily replaces `module.function` with `replacement`."""
+  old_function = getattr(module, function)
+  setattr(module, function, replacement)
+  try:
+    yield
+  finally:
+    setattr(module, function, old_function)
+
+
+def DNNClassifierWithLayerAnnotations( # pylint: disable=invalid-name
+ hidden_units,
+ feature_columns,
+ model_dir=None,
+ n_classes=2,
+ weight_column=None,
+ label_vocabulary=None,
+ optimizer='Adagrad',
+ activation_fn=nn.relu,
+ dropout=None,
+ input_layer_partitioner=None,
+ config=None,
+ warm_start_from=None,
+ loss_reduction=losses.Reduction.SUM):
+ """A classifier for TensorFlow DNN models with layer annotations.
+
+  This classifier is functionally identical to estimator.DNNClassifier as far as
+ training and evaluating models is concerned. The key difference is that this
+ classifier adds additional layer annotations, which can be used for computing
+ Integrated Gradients.
+
+ Integrated Gradients is a method for attributing a classifier's predictions
+ to its input features (https://arxiv.org/pdf/1703.01365.pdf). Given an input
+ instance, the method assigns attribution scores to individual features in
+ proportion to the feature's importance to the classifier's prediction.
+
+  See estimator.DNNClassifier for example code for training and evaluating models
+ using this classifier.
+
+ This classifier is checkpoint-compatible with estimator.DNNClassifier and
+ therefore the following should work seamlessly:
+
+ # Instantiate ordinary estimator as usual.
+ estimator = tf.estimator.DNNClassifier(
+ config, feature_columns, hidden_units, ...)
+
+ # Train estimator, export checkpoint.
+ tf.estimator.train_and_evaluate(estimator, ...)
+
+ # Instantiate estimator with annotations with the same configuration as the
+ # ordinary estimator.
+ estimator_with_annotations = (
+ tf.contrib.estimator.DNNClassifierWithLayerAnnotations(
+ config, feature_columns, hidden_units, ...))
+
+ # Call export_savedmodel with the same arguments as the ordinary estimator,
+ # using the checkpoint produced for the ordinary estimator.
+ estimator_with_annotations.export_saved_model(
+ export_dir_base, serving_input_receiver, ...
+ checkpoint_path='/path/to/ordinary/estimator/checkpoint/model.ckpt-1234')
+
+ Args:
+    hidden_units: Iterable of number of hidden units per layer. All layers are
+ fully connected. Ex. `[64, 32]` means first layer has 64 nodes and second
+ one has 32.
+ feature_columns: An iterable containing all the feature columns used by the
+ model. All items in the set should be instances of classes derived from
+ `_FeatureColumn`.
+    model_dir: Directory to save model parameters, graph, etc. This can also
+      be used to load checkpoints from the directory into an estimator to
+ continue training a previously saved model.
+ n_classes: Number of label classes. Defaults to 2, namely binary
+ classification. Must be > 1.
+ weight_column: A string or a `_NumericColumn` created by
+ `tf.feature_column.numeric_column` defining feature column representing
+ weights. It is used to down weight or boost examples during training. It
+ will be multiplied by the loss of the example. If it is a string, it is
+ used as a key to fetch weight tensor from the `features`. If it is a
+ `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, then
+ weight_column.normalizer_fn is applied on it to get weight tensor.
+    label_vocabulary: A list of strings representing possible label values. If
+ given, labels must be string type and have any value in
+ `label_vocabulary`. If it is not given, that means labels are already
+ encoded as integer or float within [0, 1] for `n_classes=2` and encoded as
+ integer values in {0, 1,..., n_classes-1} for `n_classes`>2 . Also there
+ will be errors if vocabulary is not provided and labels are string.
+ optimizer: An instance of `tf.Optimizer` used to train the model. Defaults
+ to Adagrad optimizer.
+ activation_fn: Activation function applied to each layer. If `None`, will
+ use `tf.nn.relu`.
+ dropout: When not `None`, the probability we will drop out a given
+ coordinate.
+ input_layer_partitioner: Optional. Partitioner for input layer. Defaults to
+ `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
+ config: `RunConfig` object to configure the runtime settings.
+ warm_start_from: A string filepath to a checkpoint to warm-start from, or a
+ `WarmStartSettings` object to fully configure warm-starting. If the
+ string filepath is provided instead of a `WarmStartSettings`, then all
+ weights are warm-started, and it is assumed that vocabularies and Tensor
+ names are unchanged.
+ loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
+ reduce training loss over batch. Defaults to `SUM`.
+
+ Returns:
+ DNNClassifier with layer annotations.
+ """
+
+ original = dnn.DNNClassifier(
+ hidden_units=hidden_units,
+ feature_columns=feature_columns,
+ model_dir=model_dir,
+ n_classes=n_classes,
+ weight_column=weight_column,
+ label_vocabulary=label_vocabulary,
+ optimizer=optimizer,
+ activation_fn=activation_fn,
+ dropout=dropout,
+ input_layer_partitioner=input_layer_partitioner,
+ config=config,
+ warm_start_from=warm_start_from,
+ loss_reduction=loss_reduction)
+
+ def _model_fn(features, labels, mode, config):
+ with _monkey_patch(
+ feature_column_lib, 'input_layer',
+ make_input_layer_with_layer_annotations(feature_column_lib.input_layer,
+ mode)):
+ return original.model_fn(features, labels, mode, config)
+
+ return estimator.Estimator(
+ model_fn=_model_fn,
+ model_dir=model_dir,
+ config=config,
+ warm_start_from=warm_start_from)
+
+
+def DNNRegressorWithLayerAnnotations( # pylint: disable=invalid-name
+ hidden_units,
+ feature_columns,
+ model_dir=None,
+ label_dimension=1,
+ weight_column=None,
+ optimizer='Adagrad',
+ activation_fn=nn.relu,
+ dropout=None,
+ input_layer_partitioner=None,
+ config=None,
+ warm_start_from=None,
+ loss_reduction=losses.Reduction.SUM,
+):
+ """A regressor for TensorFlow DNN models with layer annotations.
+
+  This regressor is functionally identical to estimator.DNNRegressor as far as
+  training and evaluating models is concerned. The key difference is that this
+  regressor adds additional layer annotations, which can be used for computing
+ Integrated Gradients.
+
+  Integrated Gradients is a method for attributing a model's predictions
+  to its input features (https://arxiv.org/pdf/1703.01365.pdf). Given an input
+  instance, the method assigns attribution scores to individual features in
+  proportion to the feature's importance to the model's prediction.
+
+ See estimator.DNNRegressor for example code for training and evaluating models
+ using this regressor.
+
+ This regressor is checkpoint-compatible with estimator.DNNRegressor and
+ therefore the following should work seamlessly:
+
+ # Instantiate ordinary estimator as usual.
+ estimator = tf.estimator.DNNRegressor(
+ config, feature_columns, hidden_units, ...)
+
+ # Train estimator, export checkpoint.
+ tf.estimator.train_and_evaluate(estimator, ...)
+
+ # Instantiate estimator with annotations with the same configuration as the
+ # ordinary estimator.
+ estimator_with_annotations = (
+ tf.contrib.estimator.DNNRegressorWithLayerAnnotations(
+ config, feature_columns, hidden_units, ...))
+
+ # Call export_savedmodel with the same arguments as the ordinary estimator,
+ # using the checkpoint produced for the ordinary estimator.
+ estimator_with_annotations.export_saved_model(
+ export_dir_base, serving_input_receiver, ...
+ checkpoint_path='/path/to/ordinary/estimator/checkpoint/model.ckpt-1234')
+
+ Args:
+    hidden_units: Iterable of number of hidden units per layer. All layers are
+ fully connected. Ex. `[64, 32]` means first layer has 64 nodes and second
+ one has 32.
+ feature_columns: An iterable containing all the feature columns used by the
+ model. All items in the set should be instances of classes derived from
+ `_FeatureColumn`.
+    model_dir: Directory to save model parameters, graph, etc. This can also
+      be used to load checkpoints from the directory into an estimator to
+ continue training a previously saved model.
+ label_dimension: Number of regression targets per example. This is the size
+ of the last dimension of the labels and logits `Tensor` objects
+ (typically, these have shape `[batch_size, label_dimension]`).
+ weight_column: A string or a `_NumericColumn` created by
+ `tf.feature_column.numeric_column` defining feature column representing
+ weights. It is used to down weight or boost examples during training. It
+ will be multiplied by the loss of the example. If it is a string, it is
+ used as a key to fetch weight tensor from the `features`. If it is a
+ `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, then
+ weight_column.normalizer_fn is applied on it to get weight tensor.
+ optimizer: An instance of `tf.Optimizer` used to train the model. Defaults
+ to Adagrad optimizer.
+ activation_fn: Activation function applied to each layer. If `None`, will
+ use `tf.nn.relu`.
+ dropout: When not `None`, the probability we will drop out a given
+ coordinate.
+ input_layer_partitioner: Optional. Partitioner for input layer. Defaults to
+ `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
+ config: `RunConfig` object to configure the runtime settings.
+ warm_start_from: A string filepath to a checkpoint to warm-start from, or a
+ `WarmStartSettings` object to fully configure warm-starting. If the
+ string filepath is provided instead of a `WarmStartSettings`, then all
+ weights are warm-started, and it is assumed that vocabularies and Tensor
+ names are unchanged.
+ loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
+ reduce training loss over batch. Defaults to `SUM`.
+
+ Returns:
+ DNNRegressor with layer annotations.
+ """
+
+ original = dnn.DNNRegressor(
+ hidden_units=hidden_units,
+ feature_columns=feature_columns,
+ model_dir=model_dir,
+ label_dimension=label_dimension,
+ weight_column=weight_column,
+ optimizer=optimizer,
+ activation_fn=activation_fn,
+ dropout=dropout,
+ input_layer_partitioner=input_layer_partitioner,
+ config=config,
+ warm_start_from=warm_start_from,
+ loss_reduction=loss_reduction,
+ )
+
+ def _model_fn(features, labels, mode, config):
+ with _monkey_patch(
+ feature_column_lib, 'input_layer',
+ make_input_layer_with_layer_annotations(feature_column_lib.input_layer,
+ mode)):
+ return original.model_fn(features, labels, mode, config)
+
+ return estimator.Estimator(
+ model_fn=_model_fn,
+ model_dir=model_dir,
+ config=config,
+ warm_start_from=warm_start_from)
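A minimal sketch, assuming the module above is importable as `dnn_with_layer_annotations`, of how a consumer could read the annotations back out of a graph: the keys collections hold plain strings, while the values collections hold `TensorInfo` protos packed into `Any` messages by `_to_any_wrapped_tensor_info`, so they are recovered with `Any.Unpack`.

from tensorflow.contrib.estimator.python.estimator import dnn_with_layer_annotations
from tensorflow.core.protobuf import meta_graph_pb2


def read_annotations(graph, collection_name):
  """Returns a dict mapping annotation keys to unpacked TensorInfo protos."""
  names = dnn_with_layer_annotations.LayerAnnotationsCollectionNames
  keys = graph.get_collection(names.keys(collection_name))
  values = graph.get_collection(names.values(collection_name))
  annotations = {}
  for key, any_proto in zip(keys, values):
    tensor_info = meta_graph_pb2.TensorInfo()
    any_proto.Unpack(tensor_info)  # reverse of the Any.Pack in the module above
    annotations[key] = tensor_info
  return annotations

# Hypothetical usage: inspect the processed-feature tensors recorded by one of
# the annotated estimators.
# processed = read_annotations(
#     graph,
#     dnn_with_layer_annotations.LayerAnnotationsCollectionNames.PROCESSED_FEATURES)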
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations_test.py b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations_test.py
new file mode 100644
index 0000000000..2fe3d4c72e
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations_test.py
@@ -0,0 +1,611 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for dnn_with_layer_annotations.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import shutil
+import tempfile
+
+import numpy as np
+import six
+
+from tensorflow.contrib.estimator.python.estimator import dnn_with_layer_annotations
+from tensorflow.core.example import example_pb2
+from tensorflow.core.example import feature_pb2
+from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.estimator.canned import dnn
+from tensorflow.python.estimator.canned import dnn_testing_utils
+from tensorflow.python.estimator.canned import prediction_keys
+from tensorflow.python.estimator.export import export
+from tensorflow.python.estimator.inputs import numpy_io
+from tensorflow.python.estimator.inputs import pandas_io
+from tensorflow.python.feature_column import feature_column
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.training import input as input_lib
+from tensorflow.python.training import queue_runner
+
+try:
+ # pylint: disable=g-import-not-at-top
+ import pandas as pd
+ HAS_PANDAS = True
+except IOError:
+ # Pandas writes a temporary file during import. If it fails, don't use pandas.
+ HAS_PANDAS = False
+except ImportError:
+ HAS_PANDAS = False
+
+
+def _dnn_classifier_fn(*args, **kwargs):
+ return dnn_with_layer_annotations.DNNClassifierWithLayerAnnotations(
+ *args, **kwargs)
+
+
+class DNNWarmStartingTest(dnn_testing_utils.BaseDNNWarmStartingTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNWarmStartingTest.__init__(self, _dnn_classifier_fn,
+ _dnn_regressor_fn)
+
+
+class DNNWithLayerAnnotationsClassifierEvaluateTest(
+ dnn_testing_utils.BaseDNNClassifierEvaluateTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNClassifierEvaluateTest.__init__(
+ self, _dnn_classifier_fn)
+
+
+class DNNClassifierWithLayerAnnotationsPredictTest(
+ dnn_testing_utils.BaseDNNClassifierPredictTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNClassifierPredictTest.__init__(
+ self, _dnn_classifier_fn)
+
+
+class DNNClassifierWithLayerAnnotationsTrainTest(
+ dnn_testing_utils.BaseDNNClassifierTrainTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNClassifierTrainTest.__init__(
+ self, _dnn_classifier_fn)
+
+
+def _dnn_regressor_fn(*args, **kwargs):
+ return dnn_with_layer_annotations.DNNRegressorWithLayerAnnotations(
+ *args, **kwargs)
+
+
+class DNNWithLayerAnnotationsTest(test.TestCase):
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def _getLayerAnnotationCollection(self, graph, collection_name):
+ keys = graph.get_collection(
+ dnn_with_layer_annotations.LayerAnnotationsCollectionNames.keys(
+ collection_name))
+ values = graph.get_collection(
+ dnn_with_layer_annotations.LayerAnnotationsCollectionNames.values(
+ collection_name))
+ if len(keys) != len(values):
+ raise ValueError('keys and values should have same length. lengths were: '
+ '%d and %d, and elements were %s and %s' %
+ (len(keys), len(values), keys, values))
+ return dict(zip(keys, values))
+
+ def _testAnnotationsPresentForEstimator(self, estimator_class):
+ feature_columns = [
+ feature_column.numeric_column('x', shape=(1,)),
+ feature_column.embedding_column(
+ feature_column.categorical_column_with_vocabulary_list(
+ 'y', vocabulary_list=['a', 'b', 'c']),
+ dimension=3)
+ ]
+ estimator = estimator_class(
+ hidden_units=(2, 2),
+ feature_columns=feature_columns,
+ model_dir=self._model_dir)
+ model_fn = estimator.model_fn
+
+ graph = ops.Graph()
+ with graph.as_default():
+ model_fn({
+ 'x': array_ops.constant([1.0]),
+ 'y': array_ops.constant(['a'])
+ }, {},
+ model_fn_lib.ModeKeys.PREDICT,
+ config=None)
+
+ unprocessed_features = self._getLayerAnnotationCollection(
+ graph, dnn_with_layer_annotations.LayerAnnotationsCollectionNames
+ .UNPROCESSED_FEATURES)
+ processed_features = self._getLayerAnnotationCollection(
+ graph, dnn_with_layer_annotations.LayerAnnotationsCollectionNames
+ .PROCESSED_FEATURES)
+ feature_columns = graph.get_collection(
+ dnn_with_layer_annotations.LayerAnnotationsCollectionNames
+ .FEATURE_COLUMNS)
+
+ self.assertItemsEqual(unprocessed_features.keys(), ['x', 'y'])
+ self.assertEqual(2, len(processed_features.keys()))
+ self.assertEqual(2, len(feature_columns))
+
+ def testAnnotationsPresentForClassifier(self):
+ self._testAnnotationsPresentForEstimator(
+ dnn_with_layer_annotations.DNNClassifierWithLayerAnnotations)
+
+ def testAnnotationsPresentForRegressor(self):
+ self._testAnnotationsPresentForEstimator(
+ dnn_with_layer_annotations.DNNRegressorWithLayerAnnotations)
+
+ def _testCheckpointCompatibleWithNonAnnotatedEstimator(
+ self, train_input_fn, predict_input_fn, non_annotated_class,
+ annotated_class, prediction_key, estimator_args):
+ input_dimension = 2
+ feature_columns = [
+ feature_column.numeric_column('x', shape=(input_dimension,))
+ ]
+ estimator = non_annotated_class(
+ model_dir=self._model_dir,
+ hidden_units=(2, 2),
+ feature_columns=feature_columns,
+ **estimator_args)
+
+ estimator.train(train_input_fn, steps=10)
+
+ predictions = np.array(
+ [x[prediction_key] for x in estimator.predict(predict_input_fn)])
+
+ annotated_estimator = annotated_class(
+ model_dir=self._model_dir,
+ hidden_units=(2, 2),
+ feature_columns=feature_columns,
+ warm_start_from=self._model_dir,
+ **estimator_args)
+
+ annotated_predictions = np.array([
+ x[prediction_key] for x in annotated_estimator.predict(predict_input_fn)
+ ])
+
+ self.assertAllEqual(predictions.shape, annotated_predictions.shape)
+ for i, (a, b) in enumerate(
+ zip(predictions.flatten(), annotated_predictions.flatten())):
+ self.assertAlmostEqual(a, b, msg='index=%d' % i)
+
+ def testCheckpointCompatibleForClassifier(self):
+ n_classes = 2
+ input_dimension = 2
+ batch_size = 10
+ data = np.linspace(
+ 0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
+ x_data = data.reshape(batch_size, input_dimension)
+ y_data = np.reshape(
+ np.rint(data[:batch_size]).astype(np.int64), (batch_size, 1))
+ # learn y = x
+ train_input_fn = numpy_io.numpy_input_fn(
+ x={'x': x_data},
+ y=y_data,
+ batch_size=batch_size,
+ num_epochs=None,
+ shuffle=True)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x={'x': x_data}, batch_size=batch_size, shuffle=False)
+
+ self._testCheckpointCompatibleWithNonAnnotatedEstimator(
+ train_input_fn,
+ predict_input_fn,
+ dnn.DNNClassifier,
+ dnn_with_layer_annotations.DNNClassifierWithLayerAnnotations,
+ prediction_key=prediction_keys.PredictionKeys.PROBABILITIES,
+ estimator_args={'n_classes': n_classes})
+
+ def testCheckpointCompatibleForRegressor(self):
+ label_dimension = 2
+ batch_size = 10
+ data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
+ data = data.reshape(batch_size, label_dimension)
+ # learn y = x
+ train_input_fn = numpy_io.numpy_input_fn(
+ x={'x': data},
+ y=data,
+ batch_size=batch_size,
+ num_epochs=None,
+ shuffle=True)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x={'x': data}, batch_size=batch_size, shuffle=False)
+
+ self._testCheckpointCompatibleWithNonAnnotatedEstimator(
+ train_input_fn,
+ predict_input_fn,
+ dnn.DNNRegressor,
+ dnn_with_layer_annotations.DNNRegressorWithLayerAnnotations,
+ prediction_key=prediction_keys.PredictionKeys.PREDICTIONS,
+ estimator_args={'label_dimension': label_dimension})
+
+
+class DNNRegressorWithLayerAnnotationsEvaluateTest(
+ dnn_testing_utils.BaseDNNRegressorEvaluateTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__(
+ self, _dnn_regressor_fn)
+
+
+class DNNRegressorWithLayerAnnotationsPredictTest(
+ dnn_testing_utils.BaseDNNRegressorPredictTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNRegressorPredictTest.__init__(
+ self, _dnn_regressor_fn)
+
+
+class DNNRegressorWithLayerAnnotationsTrainTest(
+ dnn_testing_utils.BaseDNNRegressorTrainTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNRegressorTrainTest.__init__(
+ self, _dnn_regressor_fn)
+
+
+def _queue_parsed_features(feature_map):
+ tensors_to_enqueue = []
+ keys = []
+ for key, tensor in six.iteritems(feature_map):
+ keys.append(key)
+ tensors_to_enqueue.append(tensor)
+ queue_dtypes = [x.dtype for x in tensors_to_enqueue]
+ input_queue = data_flow_ops.FIFOQueue(capacity=100, dtypes=queue_dtypes)
+ queue_runner.add_queue_runner(
+ queue_runner.QueueRunner(input_queue,
+ [input_queue.enqueue(tensors_to_enqueue)]))
+ dequeued_tensors = input_queue.dequeue()
+ return {keys[i]: dequeued_tensors[i] for i in range(len(dequeued_tensors))}
+
+
+class DNNRegressorWithLayerAnnotationsIntegrationTest(test.TestCase):
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ if self._model_dir:
+ writer_cache.FileWriterCache.clear()
+ shutil.rmtree(self._model_dir)
+
+ def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
+ input_dimension, label_dimension, batch_size):
+ feature_columns = [
+ feature_column.numeric_column('x', shape=(input_dimension,))
+ ]
+ est = dnn_with_layer_annotations.DNNRegressorWithLayerAnnotations(
+ hidden_units=(2, 2),
+ feature_columns=feature_columns,
+ label_dimension=label_dimension,
+ model_dir=self._model_dir)
+
+ # TRAIN
+ num_steps = 10
+ est.train(train_input_fn, steps=num_steps)
+
+    # EVALUATE
+ scores = est.evaluate(eval_input_fn)
+ self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
+ self.assertIn('loss', six.iterkeys(scores))
+
+ # PREDICT
+ predictions = np.array([
+ x[prediction_keys.PredictionKeys.PREDICTIONS]
+ for x in est.predict(predict_input_fn)
+ ])
+ self.assertAllEqual((batch_size, label_dimension), predictions.shape)
+
+ # EXPORT
+ feature_spec = feature_column.make_parse_example_spec(feature_columns)
+ serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
+ feature_spec)
+ export_dir = est.export_savedmodel(tempfile.mkdtemp(),
+ serving_input_receiver_fn)
+ self.assertTrue(gfile.Exists(export_dir))
+
+ def test_numpy_input_fn(self):
+ """Tests complete flow with numpy_input_fn."""
+ label_dimension = 2
+ batch_size = 10
+ data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
+ data = data.reshape(batch_size, label_dimension)
+ # learn y = x
+ train_input_fn = numpy_io.numpy_input_fn(
+ x={'x': data},
+ y=data,
+ batch_size=batch_size,
+ num_epochs=None,
+ shuffle=True)
+ eval_input_fn = numpy_io.numpy_input_fn(
+ x={'x': data}, y=data, batch_size=batch_size, shuffle=False)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x={'x': data}, batch_size=batch_size, shuffle=False)
+
+ self._test_complete_flow(
+ train_input_fn=train_input_fn,
+ eval_input_fn=eval_input_fn,
+ predict_input_fn=predict_input_fn,
+ input_dimension=label_dimension,
+ label_dimension=label_dimension,
+ batch_size=batch_size)
+
+ def test_pandas_input_fn(self):
+ """Tests complete flow with pandas_input_fn."""
+ if not HAS_PANDAS:
+ return
+ label_dimension = 1
+ batch_size = 10
+ data = np.linspace(0., 2., batch_size, dtype=np.float32)
+ x = pd.DataFrame({'x': data})
+ y = pd.Series(data)
+ train_input_fn = pandas_io.pandas_input_fn(
+ x=x, y=y, batch_size=batch_size, num_epochs=None, shuffle=True)
+ eval_input_fn = pandas_io.pandas_input_fn(
+ x=x, y=y, batch_size=batch_size, shuffle=False)
+ predict_input_fn = pandas_io.pandas_input_fn(
+ x=x, batch_size=batch_size, shuffle=False)
+
+ self._test_complete_flow(
+ train_input_fn=train_input_fn,
+ eval_input_fn=eval_input_fn,
+ predict_input_fn=predict_input_fn,
+ input_dimension=label_dimension,
+ label_dimension=label_dimension,
+ batch_size=batch_size)
+
+ def test_input_fn_from_parse_example(self):
+ """Tests complete flow with input_fn constructed from parse_example."""
+ label_dimension = 2
+ batch_size = 10
+ data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
+ data = data.reshape(batch_size, label_dimension)
+
+ serialized_examples = []
+ for datum in data:
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'x':
+ feature_pb2.Feature(
+ float_list=feature_pb2.FloatList(value=datum)),
+ 'y':
+ feature_pb2.Feature(
+ float_list=feature_pb2.FloatList(value=datum)),
+ }))
+ serialized_examples.append(example.SerializeToString())
+
+ feature_spec = {
+ 'x': parsing_ops.FixedLenFeature([label_dimension], dtypes.float32),
+ 'y': parsing_ops.FixedLenFeature([label_dimension], dtypes.float32),
+ }
+
+ def _train_input_fn():
+ feature_map = parsing_ops.parse_example(serialized_examples, feature_spec)
+ features = _queue_parsed_features(feature_map)
+ labels = features.pop('y')
+ return features, labels
+
+ def _eval_input_fn():
+ feature_map = parsing_ops.parse_example(
+ input_lib.limit_epochs(serialized_examples, num_epochs=1),
+ feature_spec)
+ features = _queue_parsed_features(feature_map)
+ labels = features.pop('y')
+ return features, labels
+
+ def _predict_input_fn():
+ feature_map = parsing_ops.parse_example(
+ input_lib.limit_epochs(serialized_examples, num_epochs=1),
+ feature_spec)
+ features = _queue_parsed_features(feature_map)
+ features.pop('y')
+ return features, None
+
+ self._test_complete_flow(
+ train_input_fn=_train_input_fn,
+ eval_input_fn=_eval_input_fn,
+ predict_input_fn=_predict_input_fn,
+ input_dimension=label_dimension,
+ label_dimension=label_dimension,
+ batch_size=batch_size)
+
+
+class DNNClassifierWithLayerAnnotationsIntegrationTest(test.TestCase):
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ if self._model_dir:
+ writer_cache.FileWriterCache.clear()
+ shutil.rmtree(self._model_dir)
+
+ def _as_label(self, data_in_float):
+ return np.rint(data_in_float).astype(np.int64)
+
+ def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
+ input_dimension, n_classes, batch_size):
+ feature_columns = [
+ feature_column.numeric_column('x', shape=(input_dimension,))
+ ]
+ est = dnn_with_layer_annotations.DNNClassifierWithLayerAnnotations(
+ hidden_units=(2, 2),
+ feature_columns=feature_columns,
+ n_classes=n_classes,
+ model_dir=self._model_dir)
+
+ # TRAIN
+ num_steps = 10
+ est.train(train_input_fn, steps=num_steps)
+
+    # EVALUATE
+ scores = est.evaluate(eval_input_fn)
+ self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
+ self.assertIn('loss', six.iterkeys(scores))
+
+ # PREDICT
+ predicted_proba = np.array([
+ x[prediction_keys.PredictionKeys.PROBABILITIES]
+ for x in est.predict(predict_input_fn)
+ ])
+ self.assertAllEqual((batch_size, n_classes), predicted_proba.shape)
+
+ # EXPORT
+ feature_spec = feature_column.make_parse_example_spec(feature_columns)
+ serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
+ feature_spec)
+ export_dir = est.export_savedmodel(tempfile.mkdtemp(),
+ serving_input_receiver_fn)
+ self.assertTrue(gfile.Exists(export_dir))
+
+ def test_numpy_input_fn(self):
+ """Tests complete flow with numpy_input_fn."""
+ n_classes = 3
+ input_dimension = 2
+ batch_size = 10
+ data = np.linspace(
+ 0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
+ x_data = data.reshape(batch_size, input_dimension)
+ y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1))
+ # learn y = x
+ train_input_fn = numpy_io.numpy_input_fn(
+ x={'x': x_data},
+ y=y_data,
+ batch_size=batch_size,
+ num_epochs=None,
+ shuffle=True)
+ eval_input_fn = numpy_io.numpy_input_fn(
+ x={'x': x_data}, y=y_data, batch_size=batch_size, shuffle=False)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x={'x': x_data}, batch_size=batch_size, shuffle=False)
+
+ self._test_complete_flow(
+ train_input_fn=train_input_fn,
+ eval_input_fn=eval_input_fn,
+ predict_input_fn=predict_input_fn,
+ input_dimension=input_dimension,
+ n_classes=n_classes,
+ batch_size=batch_size)
+
+ def test_pandas_input_fn(self):
+ """Tests complete flow with pandas_input_fn."""
+ if not HAS_PANDAS:
+ return
+ input_dimension = 1
+ n_classes = 3
+ batch_size = 10
+ data = np.linspace(0., n_classes - 1., batch_size, dtype=np.float32)
+ x = pd.DataFrame({'x': data})
+ y = pd.Series(self._as_label(data))
+ train_input_fn = pandas_io.pandas_input_fn(
+ x=x, y=y, batch_size=batch_size, num_epochs=None, shuffle=True)
+ eval_input_fn = pandas_io.pandas_input_fn(
+ x=x, y=y, batch_size=batch_size, shuffle=False)
+ predict_input_fn = pandas_io.pandas_input_fn(
+ x=x, batch_size=batch_size, shuffle=False)
+
+ self._test_complete_flow(
+ train_input_fn=train_input_fn,
+ eval_input_fn=eval_input_fn,
+ predict_input_fn=predict_input_fn,
+ input_dimension=input_dimension,
+ n_classes=n_classes,
+ batch_size=batch_size)
+
+ def test_input_fn_from_parse_example(self):
+ """Tests complete flow with input_fn constructed from parse_example."""
+ input_dimension = 2
+ n_classes = 3
+ batch_size = 10
+ data = np.linspace(
+ 0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
+ data = data.reshape(batch_size, input_dimension)
+
+ serialized_examples = []
+ for datum in data:
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'x':
+ feature_pb2.Feature(
+ float_list=feature_pb2.FloatList(value=datum)),
+ 'y':
+ feature_pb2.Feature(
+ int64_list=feature_pb2.Int64List(
+ value=self._as_label(datum[:1]))),
+ }))
+ serialized_examples.append(example.SerializeToString())
+
+ feature_spec = {
+ 'x': parsing_ops.FixedLenFeature([input_dimension], dtypes.float32),
+ 'y': parsing_ops.FixedLenFeature([1], dtypes.int64),
+ }
+
+ def _train_input_fn():
+ feature_map = parsing_ops.parse_example(serialized_examples, feature_spec)
+ features = _queue_parsed_features(feature_map)
+ labels = features.pop('y')
+ return features, labels
+
+ def _eval_input_fn():
+ feature_map = parsing_ops.parse_example(
+ input_lib.limit_epochs(serialized_examples, num_epochs=1),
+ feature_spec)
+ features = _queue_parsed_features(feature_map)
+ labels = features.pop('y')
+ return features, labels
+
+ def _predict_input_fn():
+ feature_map = parsing_ops.parse_example(
+ input_lib.limit_epochs(serialized_examples, num_epochs=1),
+ feature_spec)
+ features = _queue_parsed_features(feature_map)
+ features.pop('y')
+ return features, None
+
+ self._test_complete_flow(
+ train_input_fn=_train_input_fn,
+ eval_input_fn=_eval_input_fn,
+ predict_input_fn=_predict_input_fn,
+ input_dimension=input_dimension,
+ n_classes=n_classes,
+ batch_size=batch_size)
+
+
+if __name__ == '__main__':
+ test.main()
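
For orientation, the tests added above drive a full train/evaluate/predict/export cycle through numpy_io.numpy_input_fn. A minimal sketch of the same pattern against the public tf.estimator API (using a canned DNNClassifier rather than the layer-annotated contrib estimator under test; the data shapes, hidden_units, and temporary model_dir below are illustrative assumptions, not part of this change):

    import tempfile
    import numpy as np
    import tensorflow as tf

    # Toy data along the lines of the test above: 2-dim float features,
    # integer labels obtained by rounding the first feature.
    x_data = np.linspace(0., 2., 20, dtype=np.float32).reshape(10, 2)
    y_data = np.rint(x_data[:, :1]).astype(np.int64)

    train_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={'x': x_data}, y=y_data, batch_size=10, num_epochs=None, shuffle=True)
    eval_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={'x': x_data}, y=y_data, batch_size=10, shuffle=False)

    est = tf.estimator.DNNClassifier(
        hidden_units=(2, 2),
        feature_columns=[tf.feature_column.numeric_column('x', shape=(2,))],
        n_classes=3,
        model_dir=tempfile.mkdtemp())
    est.train(train_input_fn, steps=10)
    print(est.evaluate(eval_input_fn))
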
diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py b/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py
index bb5140aeb3..6aa62fb82e 100644
--- a/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py
+++ b/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py
@@ -126,7 +126,7 @@ class WalsModelTest(test.TestCase):
observed *= num_rows / 3. if test_rows else num_cols / 2.
want_weight_sum = unobserved + observed
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
wals_model = factorization_ops.WALSModel(
input_rows=num_rows,
input_cols=num_cols,
@@ -161,7 +161,7 @@ class WalsModelTest(test.TestCase):
def _run_test_process_input(self,
use_factors_weights_cache,
compute_loss=False):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
self._wals_inputs = self.sparse_input()
sp_feeder = array_ops.sparse_placeholder(dtypes.float32)
num_rows = 5
@@ -330,7 +330,7 @@ class WalsModelTest(test.TestCase):
def _run_test_process_input_transposed(self,
use_factors_weights_cache,
compute_loss=False):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
self._wals_inputs = self.sparse_input()
sp_feeder = array_ops.sparse_placeholder(dtypes.float32)
num_rows = 5
@@ -505,7 +505,7 @@ class WalsModelTest(test.TestCase):
# trigger the more efficient ALS updates.
# Here we test that those two give identical results.
def _run_test_als(self, use_factors_weights_cache):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
self._wals_inputs = self.sparse_input()
col_init = np.random.rand(7, 3)
als_model = factorization_ops.WALSModel(
@@ -583,7 +583,7 @@ class WalsModelTest(test.TestCase):
atol=1e-2)
def _run_test_als_transposed(self, use_factors_weights_cache):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
self._wals_inputs = self.sparse_input()
col_init = np.random.rand(7, 3)
als_model = factorization_ops.WALSModel(
@@ -673,7 +673,7 @@ class WalsModelTest(test.TestCase):
rows = 15
cols = 11
dims = 3
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
data = np.dot(np.random.rand(rows, 3), np.random.rand(
3, cols)).astype(np.float32) / 3.0
indices = [[i, j] for i in xrange(rows) for j in xrange(cols)]
@@ -703,7 +703,7 @@ class WalsModelTest(test.TestCase):
cols = 11
dims = 3
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
data = np.dot(np.random.rand(rows, 3), np.random.rand(
3, cols)).astype(np.float32) / 3.0
indices = [[i, j] for i in xrange(rows) for j in xrange(cols)]
@@ -736,7 +736,7 @@ class WalsModelTest(test.TestCase):
def keep_index(x):
return not (x[0] + x[1]) % 4
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
row_wts = 0.1 + np.random.rand(rows)
col_wts = 0.1 + np.random.rand(cols)
data = np.dot(np.random.rand(rows, 3), np.random.rand(
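
The hunks in this file, and in the test files that follow, are a mechanical migration from self.test_session() to self.cached_session(). As a rough sketch (assuming a TF 1.x tf.test.TestCase; the class and test names are illustrative), cached_session() hands back one session that is reused across calls within a single test, so repeated `with` blocks no longer construct a fresh session each time:

    import tensorflow as tf

    class CachedSessionExample(tf.test.TestCase):

      def test_runs_against_one_cached_session(self):
        c = tf.constant(3.0)
        # Both blocks below are served by the same per-test cached session,
        # so the second block does not pay session-construction cost again.
        with self.cached_session() as sess:
          self.assertAllClose(3.0, sess.run(c))
        with self.cached_session() as sess:
          self.assertAllClose(3.0, sess.run(c))

    if __name__ == '__main__':
      tf.test.main()
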
diff --git a/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py b/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py
index 888c3c238c..112e4d289b 100644
--- a/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py
+++ b/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py
@@ -99,7 +99,7 @@ class GmmOpsTest(test.TestCase):
logging.info('Numpy took %f', time.time() - start_time)
start_time = time.time()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
op = gmm_ops._covariance(
constant_op.constant(
data.T, dtype=dtypes.float32), False)
@@ -120,7 +120,7 @@ class GmmOpsTest(test.TestCase):
graph = ops.Graph()
with graph.as_default() as g:
g.seed = 5
- with self.test_session() as sess:
+ with self.cached_session() as sess:
data = constant_op.constant(self.data, dtype=dtypes.float32)
loss_op, scores, assignments, training_op, init_op, _ = gmm_ops.gmm(
data, 'random', num_classes, random_seed=self.seed)
@@ -144,7 +144,7 @@ class GmmOpsTest(test.TestCase):
def testParams(self):
"""Tests that the params work as intended."""
num_classes = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Experiment 1. Update weights only.
data = constant_op.constant(self.data, dtype=dtypes.float32)
gmm_tool = gmm_ops.GmmAlgorithm([data], num_classes,
diff --git a/tensorflow/contrib/factorization/python/ops/kmeans_test.py b/tensorflow/contrib/factorization/python/ops/kmeans_test.py
index 88eb9cf692..1ab5418fe4 100644
--- a/tensorflow/contrib/factorization/python/ops/kmeans_test.py
+++ b/tensorflow/contrib/factorization/python/ops/kmeans_test.py
@@ -232,7 +232,7 @@ class KMeansTest(KMeansTestBase):
self.assertEqual(features.shape, parsed_feature_dict.shape)
self.assertEqual(features.dtype, parsed_feature_dict.dtype)
# Then check that running the tensor yields the original list of points.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
parsed_points = sess.run(parsed_feature_dict)
self.assertAllEqual(self.points, parsed_points)
diff --git a/tensorflow/contrib/factorization/python/ops/wals_test.py b/tensorflow/contrib/factorization/python/ops/wals_test.py
index 31820a18b4..9bdbd05015 100644
--- a/tensorflow/contrib/factorization/python/ops/wals_test.py
+++ b/tensorflow/contrib/factorization/python/ops/wals_test.py
@@ -336,7 +336,7 @@ class WALSMatrixFactorizationTest(test.TestCase):
loss = self._model.evaluate(
input_fn=eval_input_fn_row, steps=self._num_rows)['loss']
- with self.test_session():
+ with self.cached_session():
true_loss = self.calculate_loss()
self.assertNear(
@@ -354,7 +354,7 @@ class WALSMatrixFactorizationTest(test.TestCase):
loss = self._model.evaluate(
input_fn=eval_input_fn_col, steps=self._num_cols)['loss']
- with self.test_session():
+ with self.cached_session():
true_loss = self.calculate_loss()
self.assertNear(
@@ -440,7 +440,7 @@ class SweepHookTest(test.TestCase):
math_ops.logical_not(is_row_sweep_var)))
mark_sweep_done = state_ops.assign(is_sweep_done_var, True)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sweep_hook = wals_lib._SweepHook(
is_row_sweep_var,
is_sweep_done_var,
@@ -491,7 +491,7 @@ class StopAtSweepHookTest(test.TestCase):
train_op = state_ops.assign_add(completed_sweeps, 1)
hook.begin()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([variables.global_variables_initializer()])
mon_sess = monitored_session._HookedSession(sess, [hook])
mon_sess.run(train_op)
diff --git a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py
index b1b5126d9e..45a67acb5b 100644
--- a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py
+++ b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py
@@ -24,11 +24,13 @@ from tensorflow.contrib.ffmpeg.ops import gen_encode_audio_op_py
from tensorflow.contrib.util import loader
from tensorflow.python.framework import ops
from tensorflow.python.platform import resource_loader
+from tensorflow.python.util.deprecation import deprecated
_ffmpeg_so = loader.load_op_library(
resource_loader.get_path_to_datafile('ffmpeg.so'))
+@deprecated('2018-09-04', 'This will be deleted and should not be used.')
def decode_audio(contents, file_format=None, samples_per_second=None,
channel_count=None, stream=None):
"""Create an op that decodes the contents of an audio file.
@@ -69,6 +71,7 @@ def decode_audio(contents, file_format=None, samples_per_second=None,
ops.NotDifferentiable('DecodeAudio')
+@deprecated('2018-09-04', 'This will be deleted and should not be used.')
def encode_audio(audio, file_format=None, samples_per_second=None):
"""Creates an op that encodes an audio file using sampled audio from a tensor.
@@ -95,6 +98,7 @@ def encode_audio(audio, file_format=None, samples_per_second=None):
ops.NotDifferentiable('EncodeAudio')
+@deprecated('2018-09-04', 'This will be deleted and should not be used.')
def decode_video(contents):
"""Create an op that decodes the contents of a video file.
diff --git a/tensorflow/contrib/framework/python/framework/checkpoint_utils_test.py b/tensorflow/contrib/framework/python/framework/checkpoint_utils_test.py
index 4f591367fd..77a424145a 100644
--- a/tensorflow/contrib/framework/python/framework/checkpoint_utils_test.py
+++ b/tensorflow/contrib/framework/python/framework/checkpoint_utils_test.py
@@ -82,7 +82,7 @@ class CheckpointsTest(test.TestCase):
def testNoTensor(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
_, _, _, _ = _create_checkpoints(session, checkpoint_dir)
with self.assertRaises(errors_impl.OpError):
self.assertAllEqual(
@@ -90,7 +90,7 @@ class CheckpointsTest(test.TestCase):
def testGetTensor(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
self.assertAllEqual(
checkpoint_utils.load_variable(checkpoint_dir, "var1"), v1)
@@ -103,7 +103,7 @@ class CheckpointsTest(test.TestCase):
def testGetAllVariables(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
_create_checkpoints(session, checkpoint_dir)
self.assertEqual(
checkpoint_utils.list_variables(checkpoint_dir),
@@ -112,7 +112,7 @@ class CheckpointsTest(test.TestCase):
def testInitFromCheckpoint(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
# New graph and session.
@@ -146,7 +146,7 @@ class CheckpointsTest(test.TestCase):
def testInitWithScopeDoesNotCaptureSuffixes(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
_, _, _, v4 = _create_checkpoints(session, checkpoint_dir)
with ops.Graph().as_default() as g:
@@ -165,7 +165,7 @@ class CheckpointsTest(test.TestCase):
def testInitFromRootCheckpoint(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
# New graph and session.
@@ -189,7 +189,7 @@ class CheckpointsTest(test.TestCase):
def testInitToRootCheckpoint(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
# New graph and session.
@@ -212,7 +212,7 @@ class CheckpointsTest(test.TestCase):
def testInitFromPartitionVar(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1 = _create_partition_checkpoints(session, checkpoint_dir)
# New graph and session.
@@ -266,7 +266,7 @@ class CheckpointsTest(test.TestCase):
def testInitFromCheckpointMissing(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
_, _, _, _ = _create_checkpoints(session, checkpoint_dir)
# New graph and session.
diff --git a/tensorflow/contrib/framework/python/framework/tensor_util_test.py b/tensorflow/contrib/framework/python/framework/tensor_util_test.py
index 2479fe5b8d..b1820c10c8 100644
--- a/tensorflow/contrib/framework/python/framework/tensor_util_test.py
+++ b/tensorflow/contrib/framework/python/framework/tensor_util_test.py
@@ -39,7 +39,7 @@ from tensorflow.python.platform import test
class LocalVariabletest(test.TestCase):
def test_local_variable(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEquals([], variables_lib.local_variables())
value0 = 42
variables_lib2.local_variable(value0)
@@ -55,7 +55,7 @@ class LocalVariabletest(test.TestCase):
class ReduceSumNTest(test.TestCase):
def test_reduce_sum_n(self):
- with self.test_session():
+ with self.cached_session():
a = constant_op.constant(1)
b = constant_op.constant([2])
c = constant_op.constant([[3, 4], [5, 6]])
@@ -119,13 +119,13 @@ class WithShapeTest(test.TestCase):
}))
def test_with_shape_invalid_expected_shape(self):
- with self.test_session():
+ with self.cached_session():
self.assertRaisesRegexp(ValueError, "Invalid rank",
tensor_util.with_shape, [[1], [2]],
constant_op.constant(1.0))
def test_with_shape_invalid_type(self):
- with self.test_session():
+ with self.cached_session():
self.assertRaisesRegexp(ValueError, "Invalid dtype",
tensor_util.with_shape, [1.1],
constant_op.constant([1.0]))
@@ -138,7 +138,7 @@ class WithShapeTest(test.TestCase):
constant_op.constant(1.0))
def test_with_shape_0(self):
- with self.test_session():
+ with self.cached_session():
value = 42
shape = [0]
unexpected_shapes = [[1], [2], [1, 1]]
@@ -150,7 +150,7 @@ class WithShapeTest(test.TestCase):
unexpected_shapes)
def test_with_shape_1(self):
- with self.test_session():
+ with self.cached_session():
value = [42]
shape = [1]
unexpected_shapes = [[0], [2], [1, 1]]
@@ -162,7 +162,7 @@ class WithShapeTest(test.TestCase):
unexpected_shapes)
def test_with_shape_2(self):
- with self.test_session():
+ with self.cached_session():
value = [42, 43]
shape = [2]
unexpected_shapes = [[0], [1], [2, 1]]
@@ -174,7 +174,7 @@ class WithShapeTest(test.TestCase):
unexpected_shapes)
def test_with_shape_2x2(self):
- with self.test_session():
+ with self.cached_session():
value = [[42, 43], [44, 45]]
shape = [2, 2]
unexpected_shapes = [[0], [1], [2, 1]]
@@ -196,7 +196,7 @@ class WithShapeTest(test.TestCase):
np.testing.assert_array_equal(value, tensor_with_shape.eval())
def test_with_shape_none(self):
- with self.test_session():
+ with self.cached_session():
tensor_no_shape = array_ops.placeholder(dtypes.float32)
compatible_shape = [2, 2]
@@ -220,7 +220,7 @@ class WithShapeTest(test.TestCase):
@test_util.enable_c_shapes
def test_with_shape_partial(self):
- with self.test_session():
+ with self.cached_session():
tensor_partial_shape = array_ops.placeholder(dtypes.float32)
tensor_partial_shape.set_shape([None, 2])
diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
index 9f5fee4542..e3c780ac1a 100644
--- a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
+++ b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
@@ -51,7 +51,7 @@ class _LossesTest(object):
loss = self._g_loss_fn(self._discriminator_gen_outputs)
self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype)
self.assertEqual(self._generator_loss_name, loss.op.name)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5)
def test_discriminator_all_correct(self):
@@ -59,7 +59,7 @@ class _LossesTest(object):
self._discriminator_real_outputs, self._discriminator_gen_outputs)
self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype)
self.assertEqual(self._discriminator_loss_name, loss.op.name)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5)
def test_generator_loss_collection(self):
@@ -90,7 +90,7 @@ class _LossesTest(object):
loss = self._g_loss_fn(
array_ops.reshape(self._discriminator_gen_outputs, [2, 2]))
self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5)
def test_discriminator_patch(self):
@@ -98,7 +98,7 @@ class _LossesTest(object):
array_ops.reshape(self._discriminator_real_outputs, [2, 2]),
array_ops.reshape(self._discriminator_gen_outputs, [2, 2]))
self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5)
def test_generator_loss_with_placeholder_for_logits(self):
@@ -108,7 +108,7 @@ class _LossesTest(object):
loss = self._g_loss_fn(logits, weights=weights)
self.assertEqual(logits.dtype, loss.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss,
feed_dict={
logits: [[10.0, 4.4, -5.5, 3.6]],
@@ -125,7 +125,7 @@ class _LossesTest(object):
logits, logits2, real_weights=real_weights,
generated_weights=generated_weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss,
feed_dict={
logits: [self._discriminator_real_outputs_np],
@@ -136,7 +136,7 @@ class _LossesTest(object):
def test_generator_with_python_scalar_weight(self):
loss = self._g_loss_fn(
self._discriminator_gen_outputs, weights=self._weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_g_loss * self._weights,
loss.eval(), 4)
@@ -144,14 +144,14 @@ class _LossesTest(object):
loss = self._d_loss_fn(
self._discriminator_real_outputs, self._discriminator_gen_outputs,
real_weights=self._weights, generated_weights=self._weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_d_loss * self._weights,
loss.eval(), 4)
def test_generator_with_scalar_tensor_weight(self):
loss = self._g_loss_fn(self._discriminator_gen_outputs,
weights=constant_op.constant(self._weights))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_g_loss * self._weights,
loss.eval(), 4)
@@ -160,7 +160,7 @@ class _LossesTest(object):
loss = self._d_loss_fn(
self._discriminator_real_outputs, self._discriminator_gen_outputs,
real_weights=weights, generated_weights=weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_d_loss * self._weights,
loss.eval(), 4)
@@ -284,7 +284,7 @@ class ACGANLossTest(test.TestCase):
self.assertEqual(
self._discriminator_gen_classification_logits.dtype, loss.dtype)
self.assertEqual(self._generator_loss_name, loss.op.name)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5)
def test_discriminator_all_correct(self):
@@ -292,7 +292,7 @@ class ACGANLossTest(test.TestCase):
self.assertEqual(
self._discriminator_gen_classification_logits.dtype, loss.dtype)
self.assertEqual(self._discriminator_loss_name, loss.op.name)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5)
def test_generator_loss_collection(self):
@@ -319,14 +319,14 @@ class ACGANLossTest(test.TestCase):
patch_args = {x: array_ops.reshape(y, [2, 2, 4]) for x, y in
self._generator_kwargs.items()}
loss = self._g_loss_fn(**patch_args)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5)
def test_discriminator_patch(self):
patch_args = {x: array_ops.reshape(y, [2, 2, 4]) for x, y in
self._discriminator_kwargs.items()}
loss = self._d_loss_fn(**patch_args)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5)
def test_generator_loss_with_placeholder_for_logits(self):
@@ -334,7 +334,7 @@ class ACGANLossTest(test.TestCase):
one_hot_labels = array_ops.placeholder(dtypes.int32, shape=(None, 4))
loss = self._g_loss_fn(gen_logits, one_hot_labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(
loss, feed_dict={
gen_logits: self._discriminator_gen_classification_logits_np,
@@ -349,7 +349,7 @@ class ACGANLossTest(test.TestCase):
loss = self._d_loss_fn(gen_logits, real_logits, one_hot_labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(
loss, feed_dict={
gen_logits: self._discriminator_gen_classification_logits_np,
@@ -360,7 +360,7 @@ class ACGANLossTest(test.TestCase):
def test_generator_with_python_scalar_weight(self):
loss = self._g_loss_fn(weights=self._weights, **self._generator_kwargs)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_g_loss * self._weights,
loss.eval(), 4)
@@ -368,14 +368,14 @@ class ACGANLossTest(test.TestCase):
loss = self._d_loss_fn(
real_weights=self._weights, generated_weights=self._weights,
**self._discriminator_kwargs)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_d_loss * self._weights,
loss.eval(), 4)
def test_generator_with_scalar_tensor_weight(self):
loss = self._g_loss_fn(
weights=constant_op.constant(self._weights), **self._generator_kwargs)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_g_loss * self._weights,
loss.eval(), 4)
@@ -383,7 +383,7 @@ class ACGANLossTest(test.TestCase):
weights = constant_op.constant(self._weights)
loss = self._d_loss_fn(real_weights=weights, generated_weights=weights,
**self._discriminator_kwargs)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_d_loss * self._weights,
loss.eval(), 4)
@@ -404,7 +404,7 @@ class _PenaltyTest(object):
loss = self._penalty_fn(**self._kwargs)
self.assertEqual(self._expected_dtype, loss.dtype)
self.assertEqual(self._expected_op_name, loss.op.name)
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
self.assertAlmostEqual(self._expected_loss, loss.eval(), 6)
@@ -419,13 +419,13 @@ class _PenaltyTest(object):
def test_python_scalar_weight(self):
loss = self._penalty_fn(weights=2.3, **self._kwargs)
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
self.assertAlmostEqual(self._expected_loss * 2.3, loss.eval(), 3)
def test_scalar_tensor_weight(self):
loss = self._penalty_fn(weights=constant_op.constant(2.3), **self._kwargs)
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
self.assertAlmostEqual(self._expected_loss * 2.3, loss.eval(), 3)
@@ -472,7 +472,7 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest):
self._kwargs['discriminator_scope'])
self.assertEqual(generated_data.dtype, loss.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables.global_variables_initializer().run()
loss = sess.run(loss,
feed_dict={
@@ -494,7 +494,7 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest):
one_sided=True)
self.assertEqual(generated_data.dtype, loss.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables.global_variables_initializer().run()
loss = sess.run(loss,
feed_dict={
@@ -516,7 +516,7 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest):
self._kwargs['discriminator_scope'],
target=2.0)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables.global_variables_initializer().run()
loss = sess.run(
loss,
diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py
index a559bbfa11..25d74a8c23 100644
--- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py
+++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py
@@ -118,7 +118,7 @@ def add_loss_consistency_test(test_class, loss_name_str, loss_args):
def consistency_test(self):
self.assertEqual(arg_loss.__name__, tuple_loss.__name__)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(arg_loss(**loss_args).eval(),
tuple_loss(_tuple_from_dict(loss_args)).eval())
@@ -241,7 +241,7 @@ class StarGANLossWrapperTest(test.TestCase):
self.discriminator_generated_data_source_predication)
wrapped_loss_result_tensor = wrapped_loss_fn(self.model)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
loss_result, wrapped_loss_result = sess.run(
[loss_result_tensor, wrapped_loss_result_tensor])
@@ -257,7 +257,7 @@ class StarGANLossWrapperTest(test.TestCase):
self.discriminator_generated_data_source_predication)
wrapped_loss_result_tensor = wrapped_loss_fn(self.model)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
loss_result, wrapped_loss_result = sess.run(
[loss_result_tensor, wrapped_loss_result_tensor])
@@ -282,7 +282,7 @@ class StarGANLossWrapperTest(test.TestCase):
discriminator_scope=self.discriminator_scope)
wrapped_loss_result_tensor = wrapped_loss_fn(self.model)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
loss_result, wrapped_loss_result = sess.run(
[loss_result_tensor, wrapped_loss_result_tensor])
diff --git a/tensorflow/contrib/integrate/python/ops/odes_test.py b/tensorflow/contrib/integrate/python/ops/odes_test.py
index c7b4e2faa8..be915ef96f 100644
--- a/tensorflow/contrib/integrate/python/ops/odes_test.py
+++ b/tensorflow/contrib/integrate/python/ops/odes_test.py
@@ -49,7 +49,7 @@ class OdeIntTest(test.TestCase):
y_solved = odes.odeint(func, y0, t)
self.assertIn('odeint', y_solved.name)
self.assertEqual(y_solved.get_shape(), tensor_shape.TensorShape([11]))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y_solved = sess.run(y_solved)
y_true = np.exp(t)
self.assertAllClose(y_true, y_solved)
@@ -62,7 +62,7 @@ class OdeIntTest(test.TestCase):
func = lambda y, t: k * y
t = np.linspace(0.0, 1.0, 11)
y_solved = odes.odeint(func, 1.0 + 0.0j, t)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y_solved = sess.run(y_solved)
y_true = np.exp(k * t)
self.assertAllClose(y_true, y_solved)
@@ -74,7 +74,7 @@ class OdeIntTest(test.TestCase):
func = lambda t, y: (y - t)**2 + 1.0
t = np.linspace(0.0, 1.0, 11)
y_solved = odes.odeint(func, np.float64(0.5), t)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y_solved = sess.run(y_solved)
y_true = 1.0 / (2.0 - t) + t
self.assertAllClose(y_true, y_solved)
@@ -96,7 +96,7 @@ class OdeIntTest(test.TestCase):
t = np.linspace(0.0, 1.0, 11)
y_solved = odes.odeint(func, y0, t)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y_solved = sess.run(y_solved)
y_true = np.zeros((len(t), 2, 1))
@@ -113,7 +113,7 @@ class OdeIntTest(test.TestCase):
y_solved = odes.odeint(func, array_ops.reshape(y0, shape), t)
self.assertEqual(y_solved.get_shape(),
tensor_shape.TensorShape(expected_shape))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y_solved = sess.run(y_solved)
self.assertEquals(y_solved.shape, expected_shape)
@@ -126,7 +126,7 @@ class OdeIntTest(test.TestCase):
for t_dtype in [dtypes.float32, dtypes.float64]:
y0 = math_ops.cast(1.0, y0_dtype)
y_solved = odes.odeint(func, y0, math_ops.cast(t, t_dtype))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y_solved = sess.run(y_solved)
expected = np.asarray(np.exp(t))
self.assertAllClose(y_solved, expected, rtol=1e-5)
@@ -148,13 +148,13 @@ class OdeIntTest(test.TestCase):
self.y0, [0, 1],
method='dopri5',
options={'max_num_steps': 0})
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
'max_num_steps'):
sess.run(y)
y = odes.odeint(self.func, self.y0, [1, 0])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
'monotonic increasing'):
sess.run(y)
@@ -164,7 +164,7 @@ class OdeIntTest(test.TestCase):
times0 = np.linspace(0, 10, num=11, dtype=float)
times1 = np.linspace(0, 10, num=101, dtype=float)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y_solved_0, info_0 = sess.run(
odes.odeint(self.func, self.y0, times0, full_output=True))
y_solved_1, info_1 = sess.run(
@@ -179,7 +179,7 @@ class OdeIntTest(test.TestCase):
t = [0, 20]
kwargs = dict(
full_output=True, method='dopri5', options=dict(max_num_steps=2000))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_, info_0 = sess.run(
odes.odeint(self.func, self.y0, t, rtol=0, atol=1e-6, **kwargs))
_, info_1 = sess.run(
@@ -196,7 +196,7 @@ class StepSizeTest(test.TestCase):
new_step = odes._optimal_step_size(
last_step=constant_op.constant(1.0),
error_ratio=constant_op.constant(1.0))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
new_step = sess.run(new_step)
self.assertAllClose(new_step, 0.9)
@@ -204,7 +204,7 @@ class StepSizeTest(test.TestCase):
new_step = odes._optimal_step_size(
last_step=constant_op.constant(1.0),
error_ratio=constant_op.constant(0.0))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
new_step = sess.run(new_step)
self.assertAllClose(new_step, 10.0)
@@ -212,7 +212,7 @@ class StepSizeTest(test.TestCase):
new_step = odes._optimal_step_size(
last_step=constant_op.constant(1.0),
error_ratio=constant_op.constant(1e6))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
new_step = sess.run(new_step)
self.assertAllClose(new_step, 0.2)
@@ -229,13 +229,13 @@ class InterpolationTest(test.TestCase):
y_fit = array_ops.stack(
[odes._interp_evaluate(coeffs, 0.0, 10.0, t) for t in times])
y_expected = f(times)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y_actual = sess.run(y_fit)
self.assertAllClose(y_expected, y_actual)
# attempt interpolation outside bounds
y_invalid = odes._interp_evaluate(coeffs, 0.0, 10.0, 100.0)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors_impl.InvalidArgumentError):
sess.run(y_invalid)
@@ -251,7 +251,7 @@ class OdeIntFixedTest(test.TestCase):
y0 = [0., 1.]
y_grid = odes.odeint_fixed(evol_func, y0, t, dt, method=method)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y_grid_array = sess.run(y_grid)
np.testing.assert_allclose(
@@ -265,7 +265,7 @@ class OdeIntFixedTest(test.TestCase):
y0 = [1.]
y_grid = odes.odeint_fixed(evol_func, y0, t, dt, method=method)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y_grid_array = sess.run(y_grid)
np.testing.assert_allclose(
diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py
index 7ede193029..124515e5a6 100644
--- a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py
+++ b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py
@@ -109,7 +109,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
return sparse_ids, sparse_weights
def test_safe_embedding_lookup_sparse_return_zero_vector(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, sparse_weights = self._ids_and_weights_2d()
@@ -122,7 +122,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
3.0, [0] * 4, [0] * 4, embedding_weights[0][2], [0] * 4])
def test_safe_embedding_lookup_sparse_return_special_vector(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, sparse_weights = self._ids_and_weights_2d()
@@ -136,7 +136,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
embedding_weights[0][2], embedding_weights[0][3]])
def test_safe_embedding_lookup_sparse_no_weights(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, _ = self._ids_and_weights_2d()
@@ -150,7 +150,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
embedding_weights[0][0] + embedding_weights[0][1]) / 2.0])
def test_safe_embedding_lookup_sparse_partitioned(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights(num_shards=3)
sparse_ids, _ = self._ids_and_weights_2d()
@@ -164,7 +164,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
(embedding_weights[0] + embedding_weights[1]) / 2.0])
def test_safe_embedding_lookup_sparse_partitioned_inconsistent_weights(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights(num_shards=3)
sparse_ids, sparse_weights = self._ids_and_weights_2d()
@@ -179,7 +179,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
embedding_weights, sparse_ids, sparse_weights)
def test_safe_embedding_lookup_sparse_3d_return_zero_vector(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, sparse_weights = self._ids_and_weights_3d()
@@ -192,7 +192,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
], [embedding_weights[0][2], [0] * 4, [0] * 4]])
def test_safe_embedding_lookup_sparse_3d_return_special_vector(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, sparse_weights = self._ids_and_weights_3d()
@@ -208,7 +208,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
]])
def test_safe_embedding_lookup_sparse_3d_no_weights(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, _ = self._ids_and_weights_3d()
@@ -224,7 +224,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
]])
def test_safe_embedding_lookup_sparse_3d_partitioned(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights(num_shards=3)
sparse_ids, _ = self._ids_and_weights_3d()
@@ -241,7 +241,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
def test_safe_embedding_lookup_sparse_3d_partitioned_inconsistent_weights(
self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights(num_shards=3)
sparse_ids, sparse_weights = self._ids_and_weights_3d()
@@ -276,7 +276,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
return embedding_weights
def test_scattered_embedding_consistency(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
values = constant_op.constant(["foo", "foo"])
@@ -288,7 +288,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
embedding_lookup_result[1])
def test_scattered_embedding_multiple_partition(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights(num_shards=7)
values = constant_op.constant([4, 4, 5])
@@ -304,7 +304,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
self.assertGreater(embedding_diff, 0)
def test_scattered_embedding_coverage(self):
- with self.test_session():
+ with self.cached_session():
size = 8
embedding_weights = self._random_weights(size=size, num_shards=3)
values = constant_op.constant(["foo"])
@@ -316,7 +316,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
self.assertEqual(len(np.unique(embedding_lookup_result[0])), size)
def test_scattered_embedding_multi_dimension(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
values = constant_op.constant([["foo", "bar", "bar"],
["bar", "bar", "foo"]])
@@ -329,7 +329,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
embedding_lookup_result[1][2])
def test_scattered_embedding_lookup_sparse(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights(num_shards=3)
sparse_tensor = sparse_tensor_lib.SparseTensor(
values=["foo", "bar", "foo", "bar"],
@@ -358,7 +358,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
embeds = np.random.randn(n_embed, d_embed)
idx = np.random.randint(0, n_embed, idx_shape)
- with self.test_session():
+ with self.cached_session():
embedded_np = embeds[idx]
embedded_tf = embedding_ops.embedding_lookup_unique(embeds, idx).eval()
@@ -370,7 +370,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
idx = np.random.randint(0, 5, 10)
idx2d = np.random.randint(0, 5, (10, 2))
- with self.test_session():
+ with self.cached_session():
embedded_np = embeds[idx]
embedded_np2d = embeds[idx2d]
embedded_tf = embedding_ops.embedding_lookup_unique(embeds, idx).eval()
@@ -408,7 +408,7 @@ class SampledScatteredEmbeddingLookupTest(test.TestCase):
return embedding_weights
def test_hashed_embedding_consistency(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
values = constant_op.constant(["foo", "foo"])
# The first three sampled_candidates are equal, so the first three
@@ -429,7 +429,7 @@ class SampledScatteredEmbeddingLookupTest(test.TestCase):
embedding_lookup_result[1][3])
def test_hashed_embedding_multi_dimension(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
values = constant_op.constant([["foo", "bar", "bar"],
["bar", "bar", "foo"]])
@@ -467,7 +467,7 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase):
def test_output_shape(self):
"""Verifies the shape of the output tensor."""
- with self.test_session():
+ with self.cached_session():
sp_values = sparse_tensor_lib.SparseTensor(
values=["a", "a", "b", "c", "d", "e", "f"],
indices=[[1, 0], [2, 0], [2, 1], [2, 2], [2, 3], [2, 4], [2, 5]],
@@ -481,7 +481,7 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase):
def test_output_values(self):
"""Verifies the values in a trivial case."""
- with self.test_session():
+ with self.cached_session():
sp_values = sparse_tensor_lib.SparseTensor(
values=["a"], indices=[[1, 0]], dense_shape=[3, 1])
params = constant_op.constant([.1, .2, .3])
@@ -495,7 +495,7 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase):
def test_output_values_with_sampled_candidates(self):
"""Verifies the values for given sampled_candidates."""
- with self.test_session():
+ with self.cached_session():
sp_values = sparse_tensor_lib.SparseTensor(
values=["a", "a", "b", "c", "d", "e", "f"],
indices=[[1, 0], [2, 0], [2, 1], [2, 2], [2, 3], [2, 4], [2, 5]],
@@ -520,7 +520,7 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase):
def test_output_values_with_sign_hash(self):
"""Verifies the values in a trivial case with hash_signs=True."""
- with self.test_session():
+ with self.cached_session():
sp_values = sparse_tensor_lib.SparseTensor(
values=["a"], indices=[[1, 0]], dense_shape=[3, 1])
params = constant_op.constant([.1, .1, .1])
@@ -537,7 +537,7 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase):
def test_distributive_property(self):
"""Verifies the distributive property of matrix multiplication."""
- with self.test_session():
+ with self.cached_session():
params = constant_op.constant([.1, .2, .3])
sp_values_a = sparse_tensor_lib.SparseTensor(
values=["a"], indices=[[0, 0]], dense_shape=[3, 1])
@@ -710,7 +710,7 @@ class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase):
[1, 5], ["sum", "mean", "sqrtn"], [dtypes.float32,
dtypes.float64], [True, False]):
- with self.test_session():
+ with self.cached_session():
p, params, feed_dict = _EmbeddingParams(
num_shards, vocab_size, shape=param_shape, dtype=dtype)
embedding_sum = \
@@ -749,7 +749,7 @@ class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase):
for num_shards, combiner, dtype, ignore_weights in itertools.product(
[1, 3], ["sum", "mean", "sqrtn"], [dtypes.float32,
dtypes.float64], [True, False]):
- with self.test_session():
+ with self.cached_session():
x, params, _ = _EmbeddingParams(
num_shards, vocab_size, shape=param_shape, dtype=dtype)
@@ -767,7 +767,7 @@ class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase):
self.assertLess(err, 1e-5 if dtype == dtypes.float64 else 2e-3)
def testIncompatibleShapes(self):
- with self.test_session():
+ with self.cached_session():
x, _, _ = _EmbeddingParams(1, 10, dtype=dtypes.float32)
sp_ids = sparse_tensor_lib.SparseTensor(
constant_op.constant([[0, 0], [0, 1], [1, 0]], dtypes.int64),
diff --git a/tensorflow/contrib/layers/python/layers/encoders_test.py b/tensorflow/contrib/layers/python/layers/encoders_test.py
index e8528e9890..1a2aa710d5 100644
--- a/tensorflow/contrib/layers/python/layers/encoders_test.py
+++ b/tensorflow/contrib/layers/python/layers/encoders_test.py
@@ -34,14 +34,14 @@ def _get_const_var(name, shape, value):
class EncodersTest(test.TestCase):
def testBowEncoderSparse(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
docs = [[0, 1], [2, 3]]
enc = encoders.bow_encoder(docs, 4, 3)
sess.run(variables.global_variables_initializer())
self.assertAllEqual([2, 3], enc.eval().shape)
def testBowEncoderSparseTensor(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
docs = [[0, 1], [2, 3]]
sparse_docs = sparse_ops.dense_to_sparse_tensor(docs)
enc = encoders.bow_encoder(sparse_docs, 4, 3)
@@ -49,28 +49,28 @@ class EncodersTest(test.TestCase):
self.assertAllEqual([2, 3], enc.eval().shape)
def testBowEncoderSparseEmptyRow(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
docs = [[0, 1], [2, 3], [0, 0]]
enc = encoders.bow_encoder(docs, 4, 5)
sess.run(variables.global_variables_initializer())
self.assertAllEqual([3, 5], enc.eval().shape)
def testBowEncoderDense(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
docs = [[0, 1], [2, 3], [0, 0], [0, 0]]
enc = encoders.bow_encoder(docs, 4, 3, sparse_lookup=False)
sess.run(variables.global_variables_initializer())
self.assertAllEqual([4, 3], enc.eval().shape)
def testBowEncoderSparseTensorDenseLookup(self):
- with self.test_session():
+ with self.cached_session():
docs = [[0, 1]]
sparse_docs = sparse_ops.dense_to_sparse_tensor(docs)
with self.assertRaises(TypeError):
encoders.bow_encoder(sparse_docs, 4, 3, sparse_lookup=False)
def testBowEncodersSharingEmbeddings(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
docs = [[0, 1], [2, 3]]
enc_1 = encoders.bow_encoder(docs, 4, 3, scope='test')
enc_2 = encoders.bow_encoder(docs, 4, 3, scope='test', reuse=True)
@@ -79,7 +79,7 @@ class EncodersTest(test.TestCase):
self.assertAllEqual(avg_1, avg_2)
def testBowEncodersSharingEmbeddingsInheritedScopes(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
docs = [[0, 1], [2, 3]]
with variable_scope.variable_scope('test'):
enc_1 = encoders.bow_encoder(docs, 4, 3)
@@ -90,7 +90,7 @@ class EncodersTest(test.TestCase):
self.assertAllEqual(avg_1, avg_2)
def testBowEncodersSharingEmbeddingsSharedScope(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
docs = [[0, 1], [2, 3]]
enc_1 = encoders.bow_encoder(docs, 4, 3, scope='bow')
variable_scope.get_variable_scope().reuse_variables()
@@ -100,7 +100,7 @@ class EncodersTest(test.TestCase):
self.assertAllEqual(avg_1, avg_2)
def testBowEncoderReuseEmbeddingsVariable(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
docs = [[1, 1], [2, 3]]
with variable_scope.variable_scope('test'):
v = _get_const_var('embeddings', (4, 3),
@@ -111,7 +111,7 @@ class EncodersTest(test.TestCase):
self.assertAllClose([[3., 4., 5.], [7.5, 8.5, 9.5]], enc.eval())
def testEmbedSequence(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
docs = [[1, 1], [2, 3]]
with variable_scope.variable_scope('test'):
v = _get_const_var('embeddings', (4, 3),
diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
index e6bbd86ab7..6fb4b9ff35 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
@@ -49,7 +49,7 @@ class TransformerTest(test.TestCase):
real_valued = feature_column.real_valued_column("price")
features = {"price": constant_op.constant([[20.], [110], [-3]])}
output = feature_column_ops._Transformer(features).transform(real_valued)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(output.eval(), [[20.], [110], [-3]])
def testSparseRealValuedColumnIdentityTransformation(self):
@@ -60,7 +60,7 @@ class TransformerTest(test.TestCase):
features = {"rating": rating_tensor}
output = feature_column_ops._Transformer(features).transform(
sparse_real_valued)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(output.values.eval(), rating_tensor.values.eval())
self.assertAllEqual(output.indices.eval(), rating_tensor.indices.eval())
self.assertAllEqual(output.dense_shape.eval(),
@@ -80,7 +80,7 @@ class TransformerTest(test.TestCase):
[sparse_real_valued])
self.assertTrue(sparse_real_valued in output_dict)
output = output_dict[sparse_real_valued]
- with self.test_session():
+ with self.cached_session():
self.assertArrayNear(output.values.eval(), [4.0, 25.0], 1e-5)
self.assertAllEqual(output.indices.eval(), rating_tensor.indices.eval())
self.assertAllEqual(output.dense_shape.eval(),
@@ -97,7 +97,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[bucket])
self.assertEqual(len(output), 1)
self.assertIn(bucket, output)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(output[bucket].eval(), [[2], [3], [0]])
def testBucketizedColumnWithMultiDimensions(self):
@@ -109,7 +109,7 @@ class TransformerTest(test.TestCase):
"price": constant_op.constant([[20., 110], [110., 20], [-3, -3]])
}
output = feature_column_ops._Transformer(features).transform(bucket)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(output.eval(), [[2, 3], [3, 2], [0, 0]])
def testCachedTransformation(self):
@@ -118,7 +118,7 @@ class TransformerTest(test.TestCase):
# buckets 2, 3, 0
features = {"price": constant_op.constant([[20.], [110], [-3]])}
transformer = feature_column_ops._Transformer(features)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
transformer.transform(bucket)
num_of_ops = len(sess.graph.get_operations())
# Verify that the second call to transform the same feature
@@ -138,7 +138,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[hashed_sparse])
self.assertEqual(len(output), 1)
self.assertIn(hashed_sparse, output)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(output[hashed_sparse].values.dtype, dtypes.int64)
self.assertTrue(
all(x < 10 and x >= 0 for x in output[hashed_sparse].values.eval()))
@@ -161,7 +161,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[hashed_sparse])
self.assertEqual(len(output), 1)
self.assertIn(hashed_sparse, output)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(output[hashed_sparse].values.dtype, dtypes.int64)
self.assertTrue(
all(x < 10 and x >= 0 for x in output[hashed_sparse].values.eval()))
@@ -177,7 +177,7 @@ class TransformerTest(test.TestCase):
features = {"wire": wire_tensor}
output = feature_column_ops._Transformer(features).transform(hashed_sparse)
- with self.test_session():
+ with self.cached_session():
# While the input is a dense Tensor, the output should be a SparseTensor.
self.assertIsInstance(output, sparse_tensor.SparseTensor)
self.assertEqual(output.values.dtype, dtypes.int64)
@@ -203,7 +203,7 @@ class TransformerTest(test.TestCase):
self.assertEqual(len(output), 2)
self.assertIn(hashed_sparse, output)
self.assertIn(wire_embedding, output)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(output[wire_embedding].indices.eval(),
wire_tensor.indices.eval())
self.assertAllEqual(output[wire_embedding].dense_shape.eval(), [2, 2])
@@ -223,7 +223,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[keys_sparse])
self.assertEqual(len(output), 1)
self.assertIn(keys_sparse, output)
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
self.assertEqual(output[keys_sparse].values.dtype, dtypes.int64)
self.assertAllEqual(output[keys_sparse].values.eval(), [1, 2, 0])
@@ -241,7 +241,7 @@ class TransformerTest(test.TestCase):
features = {"wire": wire_tensor}
output = feature_column_ops._Transformer(features).transform(keys_sparse)
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
# While the input is a dense Tensor, the output should be a SparseTensor.
self.assertIsInstance(output, sparse_tensor.SparseTensor)
@@ -264,7 +264,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[hashed_sparse])
self.assertEqual(len(output), 1)
self.assertIn(hashed_sparse, output)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(output[hashed_sparse].values.dtype, dtypes.int32)
self.assertTrue(
all(x < 10 and x >= 0 for x in output[hashed_sparse].values.eval()))
@@ -282,7 +282,7 @@ class TransformerTest(test.TestCase):
wire_tensor = constant_op.constant([[100, 0], [1, 25]])
features = {"wire": wire_tensor}
output = feature_column_ops._Transformer(features).transform(hashed_sparse)
- with self.test_session():
+ with self.cached_session():
# While the input is a dense Tensor, the output should be a SparseTensor.
self.assertIsInstance(output, sparse_tensor.SparseTensor)
self.assertEqual(output.values.dtype, dtypes.int32)
@@ -310,7 +310,7 @@ class TransformerTest(test.TestCase):
self.assertEqual(len(output), 1)
self.assertIn(weighted_ids, output)
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
self.assertAllEqual(output[weighted_ids][0].dense_shape.eval(),
ids_tensor.dense_shape.eval())
@@ -340,7 +340,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[vocab_sparse])
self.assertEqual(len(output), 1)
self.assertIn(vocab_sparse, output)
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0])
@@ -362,7 +362,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[vocab_sparse])
self.assertEqual(len(output), 1)
self.assertIn(vocab_sparse, output)
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0, 1])
@@ -386,7 +386,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[vocab_sparse])
self.assertEqual(len(output), 1)
self.assertIn(vocab_sparse, output)
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0])
@@ -408,7 +408,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[vocab_sparse])
self.assertEqual(len(output), 1)
self.assertIn(vocab_sparse, output)
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0, 1])
@@ -440,7 +440,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[country_language])
self.assertEqual(len(output), 1)
self.assertIn(country_language, output)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(output[country_language].values.dtype, dtypes.int64)
self.assertTrue(
all(x < 15 and x >= 0 for x in output[country_language].values.eval(
@@ -467,7 +467,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[country_price])
self.assertEqual(len(output), 1)
self.assertIn(country_price, output)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(output[country_price].values.dtype, dtypes.int64)
self.assertTrue(
all(x < 15 and x >= 0 for x in output[country_price].values.eval()))
@@ -498,7 +498,7 @@ class TransformerTest(test.TestCase):
weights = column_to_variable[country_price][0]
grad = array_ops.squeeze(
gradients_impl.gradients(output, weights)[0].values)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
self.assertEqual(len(grad.eval()), 6)
@@ -537,7 +537,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[wire_country_price])
self.assertEqual(len(output), 1)
self.assertIn(wire_country_price, output)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(output[wire_country_price].values.dtype, dtypes.int64)
self.assertTrue(
all(x < 15 and x >= 0 for x in output[wire_country_price].values.eval(
@@ -600,7 +600,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
columns = [one_hot_column, embedding_column, real_valued_column]
output = feature_column_ops.input_from_feature_columns(features, columns)
output_core = fc_core.input_layer(features, columns)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual(output.eval().shape, [3, 2 + 4 + 10])
@@ -626,7 +626,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
cols_to_outs = {}
feature_column_ops.input_from_feature_columns(
features, columns, cols_to_outs=cols_to_outs)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
for column in columns:
@@ -637,7 +637,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
features = {"price": constant_op.constant([[20.], [110], [-3]])}
output = feature_column_ops.input_from_feature_columns(features,
[real_valued])
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(output.eval(), features["price"].eval())
# Verify cross compatibility: Core builder output should equal to contrib.
self.assertAllClose(output.eval(),
@@ -650,7 +650,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
}
output = feature_column_ops.input_from_feature_columns(features,
[real_valued])
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(output.eval(), features["price"].eval())
# Verify cross compatibility: Core builder output should equal to contrib.
self.assertAllClose(output.eval(),
@@ -662,7 +662,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
rating = np.array([[0., 1., 2., -1.],
[3., 4., 5., 6.]])
features = {"rating": constant_op.constant(rating)}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
output = sess.run(feature_column_ops.input_from_feature_columns(
features, [var_len_real_valued]))
self.assertAllClose(rating, output)
@@ -673,7 +673,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
rating = np.array([[0, 1, 2, -1],
[3, 4, 5, 6]])
features = {"rating": constant_op.constant(rating, dtype=dtypes.int64)}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
output = sess.run(feature_column_ops.input_from_feature_columns(
features, [var_len_real_valued]))
self.assertAllClose(rating.astype(np.float32), output)
@@ -684,7 +684,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
features = {"price": constant_op.constant([[20.], [110], [-3]])}
output = feature_column_ops.input_from_feature_columns(features,
[real_valued])
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(output.eval(), features["price"].eval() - 2)
# Verify cross compatibility: Core builder output should equal to contrib.
self.assertAllClose(output.eval(),
@@ -698,7 +698,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
}
output = feature_column_ops.input_from_feature_columns(features,
[real_valued])
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(output.eval(), features["price"].eval() - 2)
# Verify cross compatibility: Core builder output should equal to contrib.
self.assertAllClose(output.eval(),
@@ -713,7 +713,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
features = {"price": constant_op.constant([[20.], [110], [-3]])}
output = feature_column_ops.input_from_feature_columns(features, [bucket])
expected = [[0, 1, 0, 0], [0, 0, 1, 0], [1, 0, 0, 0]]
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(output.eval(), expected)
self.assertAllClose(output.eval(),
fc_core.input_layer(features, [bucket]).eval())
@@ -729,7 +729,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
output = feature_column_ops.input_from_feature_columns(features, [bucket])
expected = [[0, 0, 1, 0, 0, 0, 0, 1], [0, 0, 0, 1, 0, 0, 1, 0],
[1, 0, 0, 0, 1, 0, 0, 0]]
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(output.eval(), expected)
self.assertAllClose(output.eval(),
fc_core.input_layer(features, [bucket]).eval())
@@ -752,7 +752,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
output = feature_column_ops.input_from_feature_columns(features,
[one_hot_column])
output_core = fc_core.input_layer(features, [one_hot_column])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual([[0, 0, 10., 0], [0, 20., 0, 0], [30., 0, 40., 0]],
@@ -773,7 +773,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[one_hot_sparse])
output_core = fc_core.input_layer(features, [one_hot_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0]],
@@ -794,7 +794,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[one_hot_sparse])
output_core = fc_core.input_layer(features, [one_hot_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0]],
@@ -816,7 +816,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
output = feature_column_ops.input_from_feature_columns(features,
[one_hot_sparse])
output_core = fc_core.input_layer(features, [one_hot_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0]],
output.eval())
@@ -834,7 +834,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
output = feature_column_ops.input_from_feature_columns(features,
[one_hot_sparse])
output_core = fc_core.input_layer(features, [one_hot_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual([3, 10], output.eval().shape)
@@ -852,7 +852,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
output = feature_column_ops.input_from_feature_columns(features,
[embeded_sparse])
output_core = fc_core.input_layer(features, [embeded_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
self.assertAllEqual(output.eval().shape, [4, 10])
# Verify cross compatibility: Core builder output should equal to contrib.
@@ -878,7 +878,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
features, [embedded_sparse], weight_collections=["my_collection_core"])
weights_core = ops.get_collection("my_collection_core")
grad_core = gradients_impl.gradients(output_core, weights_core)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
gradient_values = []
gradient_values_core = []
@@ -907,7 +907,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[embeded_sparse])
output_core = fc_core.input_layer(features, [embeded_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
output_eval = output.eval()
self.assertAllEqual(output_eval.shape, [2, 10])
@@ -935,7 +935,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
# Makes sure that trying to use different initializers with the same
# embedding column explicitly fails.
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
ValueError,
"Duplicate feature column key found for column: wire_embedding"):
@@ -961,7 +961,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[embeded_sparse])
output_core = fc_core.input_layer(features, [embeded_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual(output.eval().shape, [2, 10])
@@ -986,7 +986,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
embeded_sparse = feature_column.embedding_column(weighted_ids, 10)
output = feature_column_ops.input_from_feature_columns(features,
[embeded_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual(output.eval().shape, [2, 10])
@@ -1005,7 +1005,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
embeded_sparse = feature_column.embedding_column(crossed, 10)
output = feature_column_ops.input_from_feature_columns(features,
[embeded_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
self.assertAllEqual(output.eval().shape, [2, 10])
@@ -1016,7 +1016,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {"wire": wire_tensor}
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
ValueError, "Error creating input layer for column: wire"):
variables_lib.global_variables_initializer().run()
@@ -1035,7 +1035,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {"ids": ids_tensor, "weights": weights_tensor}
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
ValueError,
"Error creating input layer for column: ids_weighted_by_weights"):
@@ -1053,7 +1053,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {"aaa": wire_tensor, "bbb": wire_tensor}
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
ValueError, "Error creating input layer for column: aaa_X_bbb"):
variables_lib.global_variables_initializer().run()
@@ -1080,7 +1080,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
hashed_sparse, 10, initializer=init_ops.constant_initializer(133.7))
output = feature_column_ops.input_from_feature_columns(
features, [real_valued, bucket, embeded_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
# size of output = 3 (real_valued) + 2 * 4 (bucket) + 10 (embedding) = 21
self.assertAllEqual(output.eval().shape, [3, 21])
@@ -1099,7 +1099,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
initializer=init_ops.ones_initializer())
output = feature_column_ops.input_from_feature_columns(features,
[embeded_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
# score: (number of values)
self.assertAllEqual(output.eval(), [[1.], [2.], [0.]])
@@ -1119,7 +1119,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
max_norm=0.5)
output = feature_column_ops.input_from_feature_columns(features,
[embedded_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
# score: (number of values * 0.5)
self.assertAllClose(output.eval(), [[0.5], [1.], [0.]])
@@ -1144,7 +1144,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
initializer=init_ops.ones_initializer())
output = feature_column_ops.input_from_feature_columns(features,
[embeded_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
# score: (sum of weights)
@@ -1236,7 +1236,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
# There should be one trainable variables for sparse_2
self.assertEqual(1, len(variables_lib.trainable_variables()))
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
output_1_eval = output_1.eval()
output_2_eval = output_2.eval()
@@ -1295,7 +1295,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
model_input_tensor = feature_column_ops.sequence_input_from_feature_columns(
columns_to_tensors, [measurement_column])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
model_inputs = sess.run(model_input_tensor)
self.assertAllClose(measurement_input, model_inputs)
@@ -1305,7 +1305,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
rating = np.array([[0., 1., 2., -1.],
[3., 4., 5., 6.]])
features = {"rating": constant_op.constant(rating)}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
output = sess.run(
feature_column_ops.sequence_input_from_feature_columns(
features, [var_len_real_valued]))
@@ -1329,7 +1329,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
expected_shape = [batch_size, sequence_length, np.prod(dimensions)]
reshaped_measurements = np.reshape(measurement_input, expected_shape)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
model_inputs = sess.run(model_input_tensor)
self.assertAllClose(reshaped_measurements, model_inputs)
@@ -1350,7 +1350,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
model_input_tensor = feature_column_ops.sequence_input_from_feature_columns(
columns_to_tensors, [measurement_column])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
model_inputs = sess.run(model_input_tensor)
self.assertAllClose(normalizer(measurement_input), model_inputs)
@@ -1373,7 +1373,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
expected_shape = [batch_size, sequence_length, np.prod(dimensions)]
reshaped_measurements = np.reshape(measurement_input, expected_shape)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
model_inputs = sess.run(model_input_tensor)
self.assertAllClose(normalizer(reshaped_measurements), model_inputs)
@@ -1395,7 +1395,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
model_input_tensor = feature_column_ops.sequence_input_from_feature_columns(
columns_to_tensors, [one_hot_column])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor)
@@ -1429,7 +1429,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
model_input_tensor = feature_column_ops.sequence_input_from_feature_columns(
columns_to_tensors, [one_hot_column])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor)
@@ -1459,7 +1459,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
model_input_tensor = feature_column_ops.sequence_input_from_feature_columns(
columns_to_tensors, [embedded_column])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor)
@@ -1488,7 +1488,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
model_input_tensor = feature_column_ops.sequence_input_from_feature_columns(
columns_to_tensors, [embedded_column])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor)
@@ -1518,7 +1518,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
embedding_weights = ops.get_collection("my_collection")
gradient_tensor = gradients_impl.gradients(model_input_tensor,
embedding_weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
model_input, gradients = sess.run([model_input_tensor, gradient_tensor])
@@ -1585,7 +1585,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
columns_to_tensors, model_input_columns)
self.assertEqual(dtypes.float32, model_input_tensor.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor)
@@ -1622,7 +1622,7 @@ class WeightedSumTest(test.TestCase):
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
features, [hashed_sparse], num_outputs=5)
logits_core = fc_core.linear_model(features, [hashed_sparse], units=5)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5])
# Verify cross compatibility: Core builder output should equal to contrib.
@@ -1640,7 +1640,7 @@ class WeightedSumTest(test.TestCase):
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
features, [hashed_sparse], num_outputs=5)
logits_core = fc_core.linear_model(features, [hashed_sparse], units=5)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5])
# Verify cross compatibility: Core builder output should equal to contrib.
@@ -1654,7 +1654,7 @@ class WeightedSumTest(test.TestCase):
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
features, [hashed_sparse], num_outputs=5)
logits_core = fc_core.linear_model(features, [hashed_sparse], units=5)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5])
# Verify cross compatibility: Core builder output should equal to contrib.
@@ -1676,7 +1676,7 @@ class WeightedSumTest(test.TestCase):
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
features, [weighted_ids], num_outputs=5)
logits_core = fc_core.linear_model(features, [weighted_ids], units=5)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5])
@@ -1695,7 +1695,7 @@ class WeightedSumTest(test.TestCase):
features, [weighted_ids], num_outputs=5)
logits_core = fc_core.linear_model(features, [weighted_ids], units=5)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5])
@@ -1716,7 +1716,7 @@ class WeightedSumTest(test.TestCase):
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
features, [crossed], num_outputs=5)
logits_core = fc_core.linear_model(features, [crossed], units=5)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5])
# Verify cross compatibility: Core builder output should equal to contrib.
@@ -1730,7 +1730,7 @@ class WeightedSumTest(test.TestCase):
dense_shape=[2, 2])
features = {"wire": wire_tensor}
embeded_sparse = feature_column.embedding_column(hashed_sparse, 10)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
ValueError, "Error creating weighted sum for column: wire_embedding"):
variables_lib.global_variables_initializer().run()
@@ -1756,7 +1756,7 @@ class WeightedSumTest(test.TestCase):
features, [movies], num_outputs=1))
logits_core = fc_core.linear_model(features, [movies])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.initialize_all_variables().run()
lookup_ops.tables_initializer().run()
@@ -1776,7 +1776,7 @@ class WeightedSumTest(test.TestCase):
}
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
features, [real_valued], num_outputs=5)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
self.assertAllEqual(logits.eval().shape, [3, 5])
@@ -1789,7 +1789,7 @@ class WeightedSumTest(test.TestCase):
}
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
features, [bucket], num_outputs=5)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
self.assertAllEqual(logits.eval().shape, [3, 5])
@@ -1814,7 +1814,7 @@ class WeightedSumTest(test.TestCase):
features, [real_valued, bucket, hashed_sparse, crossed], num_outputs=5)
output_core = fc_core.linear_model(
features, [real_valued, bucket, hashed_sparse, crossed], units=5)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
self.assertAllEqual(output.eval().shape, [3, 5])
# Verify cross compatibility: Core builder output should equal to contrib.
@@ -1837,7 +1837,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, bias = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [age, language], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -1877,7 +1877,7 @@ class WeightedSumTest(test.TestCase):
features, [country, language], num_outputs=1))
# Assert that only a single weight is created.
self.assertEqual(len(variables), 1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -1941,7 +1941,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, bias = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [weighted_language], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -1969,7 +1969,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, bias = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [language], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -1992,7 +1992,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [movies], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2026,7 +2026,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [country_language], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2050,7 +2050,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [language_language], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2083,7 +2083,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [country_language], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2124,7 +2124,7 @@ class WeightedSumTest(test.TestCase):
features, [country, language, country_language],
num_outputs=1,
scope=scope))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2161,7 +2161,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [country, age, incomes], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2197,7 +2197,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [country, age, height, incomes], num_outputs=5))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2228,7 +2228,7 @@ class WeightedSumTest(test.TestCase):
feature_column_ops.weighted_sum_from_feature_columns(
features, [bucket], num_outputs=1))
output_core = fc_core.linear_model(features, [bucket])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
# Cross compatibility: Core builder output should equal to contrib.
@@ -2259,7 +2259,7 @@ class WeightedSumTest(test.TestCase):
feature_column_ops.weighted_sum_from_feature_columns(
features, [bucket, country], num_outputs=1))
output_core = fc_core.linear_model(features, [bucket, country])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
# Cross compatibility: Core builder output should equal to contrib.
@@ -2290,7 +2290,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [bucket, country], num_outputs=5))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2326,7 +2326,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [country_price], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2365,7 +2365,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [country_language_price], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2389,7 +2389,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [product], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
product_weights = column_to_variable[product][0]
@@ -2404,7 +2404,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [product], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
product_weights = column_to_variable[product][0]
@@ -2419,7 +2419,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [product], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
product_weights = column_to_variable[product][0]
@@ -2440,7 +2440,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [product], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
product_weights = column_to_variable[product][0]
@@ -2452,7 +2452,7 @@ class WeightedSumTest(test.TestCase):
features = {"age": constant_op.constant([[10.], [20.], [30.], [40.]])}
output, _, bias = feature_column_ops.weighted_sum_from_feature_columns(
features, [feature_column.real_valued_column("age")], num_outputs=3)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
sess.run(bias.assign([0.1, 0.2, 0.3]))
@@ -2466,7 +2466,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [column], num_outputs=3))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
weights = column_to_variable[column][0]
@@ -2490,7 +2490,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [column], num_outputs=3))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
weights = column_to_variable[column][0]
@@ -2516,7 +2516,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [column], num_outputs=3))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2556,7 +2556,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [column], num_outputs=3))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2585,7 +2585,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [column], num_outputs=3))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2651,7 +2651,7 @@ class ParseExampleTest(test.TestCase):
feature_columns=[bucket, wire_cast])
self.assertIn(bucket, output)
self.assertIn(wire_cast, output)
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
self.assertAllEqual(output[bucket].eval(), [[2, 3, 0]])
self.assertAllEqual(output[wire_cast].indices.eval(), [[0, 0], [0, 1]])
@@ -2713,7 +2713,7 @@ class ParseExampleTest(test.TestCase):
self.assertIn("measurements", seq)
self.assertIsInstance(seq["measurements"], ops.Tensor)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
location_val, wire_cast_val, measurement_val = sess.run(
[ctx["location"], seq["wire_cast"], seq["measurements"]])
diff --git a/tensorflow/contrib/layers/python/layers/feature_column_test.py b/tensorflow/contrib/layers/python/layers/feature_column_test.py
index eaaf9f8d5f..d90d6ecf7f 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column_test.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column_test.py
@@ -201,7 +201,7 @@ class FeatureColumnTest(test.TestCase):
b2 = feature_column_ops.input_from_feature_columns({
b[1]: input_tensor_c2
}, [b[1]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
b1_value = b1.eval()
b2_value = b2.eval()
@@ -230,7 +230,7 @@ class FeatureColumnTest(test.TestCase):
e1 = feature_column_ops.input_from_feature_columns({
e[0]: input_tensor_c1
}, [e[0]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
d1_value = d1.eval()
e1_value = e1.eval()
@@ -340,7 +340,7 @@ class FeatureColumnTest(test.TestCase):
with variable_scope.variable_scope("output_rank_{}".format(output_rank)):
one_hot_output = one_hot._to_dnn_input_layer(
id_tensor, output_rank=output_rank)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
one_hot_value = sess.run(one_hot_output)
expected_shape = (id_tensor_shape[:output_rank - 1] + [vocab_size])
self.assertEquals(expected_shape, list(one_hot_value.shape))
@@ -376,7 +376,7 @@ class FeatureColumnTest(test.TestCase):
one_hot_output_shape = one_hot_output.get_shape().as_list()
expected_shape = id_tensor_shape[:-1] + [vocab_size]
self.assertEquals(expected_shape, one_hot_output_shape)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
one_hot_value = sess.run(one_hot_output)
self.assertEquals(expected_shape, list(one_hot_value.shape))
@@ -399,7 +399,7 @@ class FeatureColumnTest(test.TestCase):
expected = np.array([[0., 1., 0., 0., 0., 0., 0., 1., 0.,
0.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 1.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
one_hot_value = sess.run(one_hot_output)
self.assertTrue(np.array_equal(one_hot_value, expected))
@@ -440,7 +440,7 @@ class FeatureColumnTest(test.TestCase):
}
one_hot_tensor = feature_column_ops.input_from_feature_columns(
features, [one_hot])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
sess.run(lookup_ops.tables_initializer())
self.assertAllEqual([[2., 6., 0.]], one_hot_tensor.eval())
@@ -451,7 +451,7 @@ class FeatureColumnTest(test.TestCase):
features = {"ids": constant_op.constant([["marlo", "unknown", "omar"]])}
one_hot_tensor = feature_column_ops.input_from_feature_columns(
features, [one_hot])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
sess.run(lookup_ops.tables_initializer())
self.assertAllEqual([[1., 1., 0.]], one_hot_tensor.eval())
@@ -603,7 +603,7 @@ class FeatureColumnTest(test.TestCase):
real_valued_output = real_valued_column._to_dnn_input_layer(
constant_op.constant(real_valued_input, dtype=dtypes.float32),
output_rank=output_rank)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
real_valued_eval = sess.run(real_valued_output)
expected_shape = (
input_shape[:output_rank - 1] +
@@ -797,7 +797,7 @@ class FeatureColumnTest(test.TestCase):
sparse_column.insert_transformed_feature(features)
sparse_output = features[sparse_column]
expected_shape = [batch_size, 1]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sparse_result = sess.run(sparse_output)
self.assertEquals(expected_shape, list(sparse_result.dense_shape))
@@ -1110,7 +1110,7 @@ class FeatureColumnTest(test.TestCase):
ckpt_dir = tempfile.mkdtemp(prefix=ckpt_dir_prefix)
checkpoint_path = os.path.join(ckpt_dir, "model.ckpt")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
saved_embedding = embeddings.eval()
save.save(sess, checkpoint_path)
@@ -1131,7 +1131,7 @@ class FeatureColumnTest(test.TestCase):
embedding_col_initialized: input_tensor
}, [embedding_col_initialized])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
loaded_embedding = pretrained_embeddings.eval()
@@ -1176,7 +1176,7 @@ class FeatureColumnTest(test.TestCase):
ckpt_dir = tempfile.mkdtemp(prefix=ckpt_dir_prefix)
checkpoint_path = os.path.join(ckpt_dir, "model.ckpt")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
sess.run(assign_op)
saved_col_weights = col_weights[crossed_col][0].eval()
@@ -1201,7 +1201,7 @@ class FeatureColumnTest(test.TestCase):
}, [crossed_col_initialized], 1))
col_weights_from_ckpt = col_weights[crossed_col_initialized][0]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
loaded_col_weights = col_weights_from_ckpt.eval()
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index 04668f112d..a82d4c1951 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -3109,7 +3109,7 @@ def maxout(inputs, num_units, axis=-1, scope=None):
inputs: Tensor input
num_units: Specifies how many features will remain after maxout
in the `axis` dimension (usually channel).
- This must be multiple of number of `axis`.
+ This must be a factor of the number of features.
axis: The dimension where max pooling will be performed. Default is the
last dimension.
scope: Optional scope for variable_scope.
@@ -3128,7 +3128,7 @@ def maxout(inputs, num_units, axis=-1, scope=None):
raise ValueError('number of features({}) is not '
'a multiple of num_units({})'.format(
num_channels, num_units))
- shape[axis] = -1
+ shape[axis] = num_units
shape += [num_channels // num_units]
# Dealing with batches with arbitrary sizes
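
Editorial note: besides the docstring fix, this hunk replaces shape[axis] = -1 with shape[axis] = num_units, so the channel axis is split explicitly into num_units groups before the max is taken over the trailing group dimension. A rough numpy sketch of that reshape-and-reduce step, assuming the channel axis is the last one (the helper name is illustrative, not the contrib API):

# Rough numpy sketch of the maxout reshape fixed in this hunk (assumes the
# channel axis is trailing; maxout_np is an illustrative helper, not TF code).
import numpy as np

def maxout_np(inputs, num_units):
  shape = list(inputs.shape)
  num_channels = shape[-1]
  if num_channels % num_units:
    raise ValueError("number of features({}) is not a multiple of "
                     "num_units({})".format(num_channels, num_units))
  # Split the channel axis into (num_units, num_channels // num_units)
  # groups, then take the max over the trailing group dimension.
  shape[-1] = num_units
  shape = shape + [num_channels // num_units]
  return np.max(np.reshape(inputs, shape), axis=-1)

# e.g. a (2, 6) input with num_units=3 yields a (2, 3) output.
print(maxout_np(np.arange(12.0).reshape(2, 6), num_units=3).shape)
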
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index 52c9c4f3be..85af9de4e4 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -281,7 +281,7 @@ class BiasAddTest(test.TestCase):
def testCreate(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = np.random.uniform(size=(5, height, width, 3))
output = _layers.bias_add(images)
self.assertEqual(output.op.name, 'BiasAdd/BiasAdd')
@@ -289,7 +289,7 @@ class BiasAddTest(test.TestCase):
def testCreateWithActivation(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = _layers.bias_add(images, activation_fn=nn_ops.relu)
self.assertEqual(output.op.name, 'BiasAdd/Relu')
@@ -298,7 +298,7 @@ class BiasAddTest(test.TestCase):
def testCreateDimensions(self):
dims = (2, 3, 4)
shape = [5, 2, 3, 4]
- with self.test_session():
+ with self.cached_session():
for d in dims:
input_shape = shape[:d]
inputs = random_ops.random_uniform(input_shape, seed=1)
@@ -311,7 +311,7 @@ class BiasAddTest(test.TestCase):
class ConvolutionTest(test.TestCase):
def testInvalidShape(self):
- with self.test_session():
+ with self.cached_session():
images_2d = random_ops.random_uniform((5, 7, 9, 3), seed=1)
with self.assertRaisesRegexp(
ValueError, 'Convolution expects input with rank 5, got 4'):
@@ -323,14 +323,14 @@ class ConvolutionTest(test.TestCase):
def testInvalidDataFormat(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
with self.assertRaisesRegexp(ValueError, 'data_format'):
layers_lib.convolution2d(images, 32, 3, data_format='CHWN')
def testCreateConv(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = np.random.uniform(size=(5, height, width, 4)).astype(np.float32)
output = layers_lib.convolution2d(images, 32, [3, 3])
self.assertEqual(output.op.name, 'Conv/Relu')
@@ -342,7 +342,7 @@ class ConvolutionTest(test.TestCase):
def testCreateConvNCHW(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = np.random.uniform(size=(5, 4, height, width)).astype(np.float32)
output = layers_lib.convolution2d(images, 32, [3, 3], data_format='NCHW')
self.assertEqual(output.op.name, 'Conv/Relu')
@@ -354,7 +354,7 @@ class ConvolutionTest(test.TestCase):
def testCreateSquareConv(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.convolution2d(images, 32, 3)
self.assertEqual(output.op.name, 'Conv/Relu')
@@ -362,7 +362,7 @@ class ConvolutionTest(test.TestCase):
def testCreateConvWithTensorShape(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.convolution2d(images, 32, images.get_shape()[1:3])
self.assertEqual(output.op.name, 'Conv/Relu')
@@ -370,7 +370,7 @@ class ConvolutionTest(test.TestCase):
def testCreateFullyConv(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 32), seed=1)
output = layers_lib.convolution2d(
images, 64, images.get_shape()[1:3], padding='VALID')
@@ -381,7 +381,7 @@ class ConvolutionTest(test.TestCase):
def testFullyConvWithCustomGetter(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
called = [0]
def custom_getter(getter, *args, **kwargs):
@@ -395,7 +395,7 @@ class ConvolutionTest(test.TestCase):
def testCreateVerticalConv(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 4), seed=1)
output = layers_lib.convolution2d(images, 32, [3, 1])
self.assertEqual(output.op.name, 'Conv/Relu')
@@ -407,7 +407,7 @@ class ConvolutionTest(test.TestCase):
def testCreateHorizontalConv(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 4), seed=1)
output = layers_lib.convolution2d(images, 32, [1, 3])
self.assertEqual(output.op.name, 'Conv/Relu')
@@ -417,7 +417,7 @@ class ConvolutionTest(test.TestCase):
def testCreateConvWithStride(self):
height, width = 6, 8
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.convolution2d(images, 32, [3, 3], stride=2)
self.assertEqual(output.op.name, 'Conv/Relu')
@@ -427,7 +427,7 @@ class ConvolutionTest(test.TestCase):
def testCreateConvCreatesWeightsAndBiasesVars(self):
height, width = 7, 9
images = random_ops.random_uniform((5, height, width, 3), seed=1)
- with self.test_session():
+ with self.cached_session():
self.assertFalse(variables.get_variables('conv1/weights'))
self.assertFalse(variables.get_variables('conv1/biases'))
layers_lib.convolution2d(images, 32, [3, 3], scope='conv1')
@@ -436,7 +436,7 @@ class ConvolutionTest(test.TestCase):
def testCreateConvWithScope(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.convolution2d(images, 32, [3, 3], scope='conv1')
self.assertEqual(output.op.name, 'conv1/Relu')
@@ -453,14 +453,14 @@ class ConvolutionTest(test.TestCase):
def testCreateConvWithoutActivation(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.convolution2d(images, 32, [3, 3], activation_fn=None)
self.assertEqual(output.op.name, 'Conv/BiasAdd')
def testCreateConvValid(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.convolution2d(images, 32, [3, 3], padding='VALID')
self.assertListEqual(output.get_shape().as_list(), [5, 5, 7, 32])
@@ -468,7 +468,7 @@ class ConvolutionTest(test.TestCase):
def testCreateConvWithWD(self):
height, width = 7, 9
weight_decay = 0.01
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = random_ops.random_uniform((5, height, width, 3), seed=1)
regularizer = regularizers.l2_regularizer(weight_decay)
layers_lib.convolution2d(
@@ -481,7 +481,7 @@ class ConvolutionTest(test.TestCase):
def testCreateConvNoRegularizers(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
layers_lib.convolution2d(images, 32, [3, 3])
self.assertEqual(
@@ -489,7 +489,7 @@ class ConvolutionTest(test.TestCase):
def testReuseVars(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
layers_lib.convolution2d(images, 32, [3, 3], scope='conv1')
self.assertEqual(len(variables.get_variables()), 2)
@@ -498,7 +498,7 @@ class ConvolutionTest(test.TestCase):
def testNonReuseVars(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
layers_lib.convolution2d(images, 32, [3, 3])
self.assertEqual(len(variables.get_variables()), 2)
@@ -507,7 +507,7 @@ class ConvolutionTest(test.TestCase):
def testReuseConvWithWD(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
weight_decay = regularizers.l2_regularizer(0.01)
with arg_scope(
@@ -523,7 +523,7 @@ class ConvolutionTest(test.TestCase):
def testConvWithBatchNorm(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 32), seed=1)
with arg_scope(
[layers_lib.convolution2d],
@@ -539,7 +539,7 @@ class ConvolutionTest(test.TestCase):
def testReuseConvWithBatchNorm(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 32), seed=1)
with arg_scope(
[layers_lib.convolution2d],
@@ -557,7 +557,7 @@ class ConvolutionTest(test.TestCase):
def testCreateConvCreatesWeightsAndBiasesVarsWithRateTwo(self):
height, width = 7, 9
images = random_ops.random_uniform((5, height, width, 3), seed=1)
- with self.test_session():
+ with self.cached_session():
self.assertFalse(variables.get_variables('conv1/weights'))
self.assertFalse(variables.get_variables('conv1/biases'))
layers_lib.convolution2d(images, 32, [3, 3], rate=2, scope='conv1')
@@ -573,7 +573,7 @@ class ConvolutionTest(test.TestCase):
output = layers_lib.convolution2d(
images, num_filters, [3, 3], rate=2, padding='SAME')
self.assertListEqual(list(output.get_shape().as_list()), expected_size)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'Conv/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -587,7 +587,7 @@ class ConvolutionTest(test.TestCase):
output = layers_lib.convolution2d(
images, num_filters, [3, 3], rate=2, padding='VALID')
self.assertListEqual(list(output.get_shape().as_list()), expected_size)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'Conv/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -601,7 +601,7 @@ class ConvolutionTest(test.TestCase):
output = layers_lib.convolution2d(
images, num_filters, [3, 3], rate=[2, 3], padding='VALID')
self.assertListEqual(list(output.get_shape().as_list()), expected_size)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEquals(output.op.name, 'Conv/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -612,7 +612,7 @@ class ConvolutionTest(test.TestCase):
expected_size = [None, None, None, num_filters]
expected_size_dynamic = [5, 7, 9, num_filters]
- with self.test_session():
+ with self.cached_session():
images = array_ops.placeholder(np.float32,
[None, None, None, input_size[3]])
output = layers_lib.convolution2d(
@@ -651,7 +651,7 @@ class ConvolutionTest(test.TestCase):
expected_size = [None, None, None, num_filters]
expected_size_dynamic = [5, 5, 7, num_filters]
- with self.test_session():
+ with self.cached_session():
images = array_ops.placeholder(np.float32,
[None, None, None, input_size[3]])
output = layers_lib.convolution2d(
@@ -670,7 +670,7 @@ class ConvolutionTest(test.TestCase):
images = random_ops.random_uniform(input_size, seed=1)
output = layers_lib.convolution2d(
images, num_filters, [3, 3], rate=2, padding='VALID', scope='conv7')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'conv7/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -688,7 +688,7 @@ class ConvolutionTest(test.TestCase):
padding='VALID',
activation_fn=None,
scope='conv7')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'conv7/BiasAdd')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -712,7 +712,7 @@ class Convolution2dTransposeTests(test.TestCase):
def testInvalidDataFormat(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
with self.assertRaisesRegexp(
ValueError, 'data_format has to be either NCHW or NHWC.'):
@@ -915,7 +915,7 @@ class Convolution2dTransposeTests(test.TestCase):
images, num_filters, [3, 3], stride=1, padding='SAME')
self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -929,7 +929,7 @@ class Convolution2dTransposeTests(test.TestCase):
images, num_filters, [3, 3], stride=1, padding='VALID')
self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -944,7 +944,7 @@ class Convolution2dTransposeTests(test.TestCase):
self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
self.assertListEqual(list(output.get_shape().as_list()), expected_size)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -958,7 +958,7 @@ class Convolution2dTransposeTests(test.TestCase):
images, num_filters, [2, 2], stride=[2, 2], padding='SAME')
self.assertListEqual(list(output.get_shape().as_list()), expected_size)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -971,7 +971,7 @@ class Convolution2dTransposeTests(test.TestCase):
images = random_ops.random_uniform(input_size, seed=1)
output = layers_lib.conv2d_transpose(
images, num_filters, [2, 2], stride=[2, 2], padding='VALID')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -984,7 +984,7 @@ class Convolution2dTransposeTests(test.TestCase):
images = random_ops.random_uniform(input_size, seed=1)
output = layers_lib.conv2d_transpose(
images, num_filters, [2, 2], stride=[2, 2], padding='SAME')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -997,7 +997,7 @@ class Convolution2dTransposeTests(test.TestCase):
images = random_ops.random_uniform(input_size, seed=1)
output = layers_lib.conv2d_transpose(
images, num_filters, [2, 2], stride=[2, 2], padding='VALID')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -1010,7 +1010,7 @@ class Convolution2dTransposeTests(test.TestCase):
images = random_ops.random_uniform(input_size, seed=1)
output = layers_lib.conv2d_transpose(
images, num_filters, [2, 4], stride=[2, 1], padding='VALID')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -1023,7 +1023,7 @@ class Convolution2dTransposeTests(test.TestCase):
images = random_ops.random_uniform(input_size, seed=1)
output = layers_lib.conv2d_transpose(
images, num_filters, [2, 4], stride=[2, 4], padding='VALID')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -1036,7 +1036,7 @@ class Convolution2dTransposeTests(test.TestCase):
images = random_ops.random_uniform(input_size, seed=1)
output = layers_lib.conv2d_transpose(
images, num_filters, [2, 4], stride=[2, 5], padding='VALID')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -1083,7 +1083,7 @@ class Convolution2dTransposeTests(test.TestCase):
images, num_filters, [3, 3], stride=[2, 2], padding='VALID')
self.assertListEqual(output.get_shape().as_list(), expected_size)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
eval_output = output.eval({images: np.zeros(input_size, np.float32)})
@@ -1095,7 +1095,7 @@ class Convolution2dTransposeTests(test.TestCase):
expected_size = [None, None, None, num_filters]
expected_size_dynamic = [5, 18, 22, num_filters]
- with self.test_session():
+ with self.cached_session():
images = array_ops.placeholder(np.float32,
[None, None, None, input_size[3]])
output = layers_lib.conv2d_transpose(
@@ -1116,7 +1116,7 @@ class Convolution2dTransposeTests(test.TestCase):
images, num_filters, [3, 3], stride=2, padding='VALID', scope='conv7')
self.assertEqual(output.op.name, 'conv7/Relu')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -1135,7 +1135,7 @@ class Convolution2dTransposeTests(test.TestCase):
scope='conv7')
self.assertEqual(output.op.name, 'conv7/BiasAdd')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -1146,7 +1146,7 @@ class Convolution2dTransposeTests(test.TestCase):
stride = 2
padding = 'VALID'
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = random_ops.random_uniform(input_size, seed=1)
output_deconv = layers_lib.conv2d_transpose(
images,
@@ -1184,7 +1184,7 @@ class ConvolutionInPlaneTest(test.TestCase):
activation_fn=None)
init_op = variables_lib.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
result = sess.run(horz_gradients)
expected = np.zeros((1, 10, 9, 1))
@@ -1201,7 +1201,7 @@ class ConvolutionInPlaneTest(test.TestCase):
activation_fn=None)
init_op = variables_lib.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
result = sess.run(
horz_gradients, feed_dict={
@@ -1225,7 +1225,7 @@ class ConvolutionInPlaneTest(test.TestCase):
activation_fn=None)
init_op = variables_lib.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
result = sess.run(horz_gradients)
@@ -1245,7 +1245,7 @@ class ConvolutionInPlaneTest(test.TestCase):
activation_fn=None)
init_op = variables_lib.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
result = sess.run(horz_gradients)
@@ -1267,7 +1267,7 @@ class ConvolutionInPlaneTest(test.TestCase):
activation_fn=None)
init_op = variables_lib.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
result = sess.run(horz_gradients)
@@ -1283,7 +1283,7 @@ class ConvolutionInPlaneTest(test.TestCase):
activation_fn=None)
init_op = variables_lib.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
result = sess.run(vert_gradients)
expected = np.zeros((1, 9, 10, 1))
@@ -1306,7 +1306,7 @@ class ConvolutionInPlaneTest(test.TestCase):
activation_fn=None)
init_op = variables_lib.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
result = sess.run(vert_gradients)
@@ -1314,7 +1314,7 @@ class ConvolutionInPlaneTest(test.TestCase):
def testConv1dShape(self):
width = 7
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, width, 3), seed=1)
output = layers_lib.convolution1d(images, 32, 3)
self.assertEqual(output.op.name, 'Conv/Relu')
@@ -1322,7 +1322,7 @@ class ConvolutionInPlaneTest(test.TestCase):
def testConvInferSpatialDims(self):
depth, height, width = 7, 9, 11
- with self.test_session():
+ with self.cached_session():
images = np.random.uniform(size=(5, width, 4)).astype(np.float32)
output = layers_lib.convolution(images, 32, [3])
self.assertListEqual(output.get_shape().as_list(), [5, width, 32])
@@ -1344,7 +1344,7 @@ class DenseToSparseTest(test.TestCase):
sparse = _layers.dense_to_sparse(tensor)
dense = sparse_ops.sparse_to_dense(sparse.indices, sparse.dense_shape,
sparse.values)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
constant = sess.run(dense)
self.assertAllEqual(expected_constant, constant)
@@ -1353,7 +1353,7 @@ class DropoutTest(test.TestCase):
def testCreateDropout(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = np.random.uniform(size=(5, height, width, 3))
output = _layers.dropout(images)
self.assertEqual(output.op.name, 'Dropout/dropout_1/mul')
@@ -1362,7 +1362,7 @@ class DropoutTest(test.TestCase):
def testCreateDropoutWithConstantTrue(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
is_training = constant_op.constant(True)
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = _layers.dropout(images, is_training=is_training)
@@ -1370,7 +1370,7 @@ class DropoutTest(test.TestCase):
def testCreateDropoutWithConstantFalse(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
is_training = constant_op.constant(False)
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = _layers.dropout(images, is_training=is_training)
@@ -1378,7 +1378,7 @@ class DropoutTest(test.TestCase):
def testCreateDropoutWithPlaceholder(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
is_training = array_ops.placeholder(dtype=dtypes.bool, shape=[])
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = _layers.dropout(images, is_training=is_training)
@@ -1387,7 +1387,7 @@ class DropoutTest(test.TestCase):
def testCollectOutputs(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = _layers.dropout(images, outputs_collections='outputs')
c_output = ops.get_collection('outputs')[0]
@@ -1396,7 +1396,7 @@ class DropoutTest(test.TestCase):
def testDropout(self):
height, width = 10, 10
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, name='images')
num_elem_initial = math_ops.reduce_mean(math_ops.to_float(images > 0))
@@ -1409,7 +1409,7 @@ class DropoutTest(test.TestCase):
def testDropoutSeed(self):
"""Test that providing the same seed produces the same result."""
height, width = 10, 10
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, name='images')
output1 = _layers.dropout(images, seed=1)
@@ -1418,7 +1418,7 @@ class DropoutTest(test.TestCase):
def testCreateDropoutNoTraining(self):
height, width = 3, 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, name='images')
num_elem_initial = math_ops.reduce_mean(math_ops.to_float(images > 0))
@@ -1431,7 +1431,7 @@ class DropoutTest(test.TestCase):
def testCreateFCFollowByDropout(self):
height, width = 3, 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, name='images')
output = _layers.fully_connected(images, 50)
@@ -1445,7 +1445,7 @@ class DropoutTest(test.TestCase):
def testCreateFCWithDropout(self):
height, width = 3, 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, name='images')
output = _layers.fully_connected(
@@ -1475,7 +1475,7 @@ class FlattenTest(test.TestCase):
def testCollectOutputs(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = np.random.uniform(size=(5, height, width, 3))
output = _layers.flatten(images, outputs_collections='outputs')
c_output = ops.get_collection('outputs')[0]
@@ -1484,7 +1484,7 @@ class FlattenTest(test.TestCase):
def testFlatten4D(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, name='images')
output = _layers.flatten(images)
@@ -1494,7 +1494,7 @@ class FlattenTest(test.TestCase):
def testFlatten3D(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform(
(5, height, width), seed=1, name='images')
output = _layers.flatten(images)
@@ -1504,7 +1504,7 @@ class FlattenTest(test.TestCase):
def testFlattenBatchSize(self):
height, width = 3, 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, name='images')
inputs = array_ops.placeholder(dtypes.int32, (None, height, width, 3))
@@ -1516,7 +1516,7 @@ class FlattenTest(test.TestCase):
def testUnknownDims(self):
height = width = depth = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = random_ops.random_uniform(
(5, height, width, depth), seed=1, name='images')
inputs = array_ops.placeholder(dtypes.int32, (None, None, None, None))
@@ -1551,7 +1551,7 @@ class PartialFlattenTest(test.TestCase):
flattened_t = _layers._inner_flatten(inputs, new_rank)
static_shape = flattened_t.get_shape().as_list()
self.assertEqual(static_shape, expected_new_shape)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
flattened = sess.run(flattened_t)
np.testing.assert_array_equal(expected_flattened, flattened)
@@ -1571,7 +1571,7 @@ class PartialFlattenTest(test.TestCase):
flattened_t = _layers._inner_flatten(inputs_t, new_rank)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
flattened = sess.run(flattened_t)
np.testing.assert_array_equal(expected_indices, flattened.indices)
@@ -1641,7 +1641,7 @@ class FCTest(test.TestCase):
def testCreateFCWithScope(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
output = _layers.fully_connected(inputs, 32, scope='fc1')
self.assertEqual(output.op.name, 'fc1/Relu')
@@ -1659,7 +1659,7 @@ class FCTest(test.TestCase):
def testCreateFcCreatesWeightsAndBiasesVars(self):
height, width = 3, 3
inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
- with self.test_session():
+ with self.cached_session():
self.assertFalse(variables.get_variables('fc1/weights'))
self.assertFalse(variables.get_variables('fc1/biases'))
_layers.fully_connected(inputs, 32, scope='fc1')
@@ -1669,7 +1669,7 @@ class FCTest(test.TestCase):
def testReuseVars(self):
height, width = 3, 3
inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
- with self.test_session():
+ with self.cached_session():
_layers.fully_connected(inputs, 32, scope='fc1')
self.assertEqual(len(variables.get_variables('fc1')), 2)
_layers.fully_connected(inputs, 32, scope='fc1', reuse=True)
@@ -1678,7 +1678,7 @@ class FCTest(test.TestCase):
def testNonReuseVars(self):
height, width = 3, 3
inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
- with self.test_session():
+ with self.cached_session():
_layers.fully_connected(inputs, 32)
self.assertEqual(len(variables.get_variables('fully_connected')), 2)
_layers.fully_connected(inputs, 32)
@@ -1713,14 +1713,14 @@ class FCTest(test.TestCase):
def testCreateFCWithoutActivation(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
output = _layers.fully_connected(inputs, 32, activation_fn=None)
self.assertEqual(output.op.name, 'fully_connected/BiasAdd')
def testCreateFCWithWD(self):
height, width = 3, 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
weight_decay = regularizers.l2_regularizer(0.01)
_layers.fully_connected(inputs, 32, weights_regularizer=weight_decay)
@@ -1732,7 +1732,7 @@ class FCTest(test.TestCase):
def testCreateFCWithBD(self):
height, width = 3, 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
bias_decay = regularizers.l2_regularizer(0.01)
_layers.fully_connected(inputs, 32, biases_regularizer=bias_decay)
@@ -1744,7 +1744,7 @@ class FCTest(test.TestCase):
def testCreateNoRegularizers(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
_layers.fully_connected(inputs, 32)
self.assertEqual(
@@ -1752,7 +1752,7 @@ class FCTest(test.TestCase):
def testReuseFCWithWD(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
weight_decay = regularizers.l2_regularizer(0.01)
_layers.fully_connected(
@@ -1768,7 +1768,7 @@ class FCTest(test.TestCase):
def testFCWithBatchNorm(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height * width * 3), seed=1)
with arg_scope(
[_layers.fully_connected],
@@ -1786,7 +1786,7 @@ class FCTest(test.TestCase):
def testReuseFCWithBatchNorm(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height * width * 3), seed=1)
with arg_scope(
[_layers.fully_connected],
@@ -1844,7 +1844,7 @@ class BatchNormTest(test.TestCase):
if dtype is None:
dtype = dtypes.float32
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = np.random.uniform(size=(5, height, width, 3)).astype(
dtype.as_numpy_dtype)
output = _layers.batch_norm(images, fused=fused)
@@ -1866,7 +1866,7 @@ class BatchNormTest(test.TestCase):
def _testCreateOpBetaRegularizer(self, fused=True):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
reg = lambda x: 0.1 * math_ops.reduce_sum(x)
images = np.random.uniform(size=(5, height, width, 3)).astype('f')
_layers.batch_norm(images, param_regularizers={'beta': reg}, fused=fused)
@@ -1883,7 +1883,7 @@ class BatchNormTest(test.TestCase):
def _testCreateOpGammaRegularizer(self, fused=True):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
reg = lambda x: 0.1 * math_ops.reduce_sum(x)
images = np.random.uniform(size=(5, height, width, 3)).astype('f')
_layers.batch_norm(
@@ -1901,7 +1901,7 @@ class BatchNormTest(test.TestCase):
def testCreateVariables(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
_layers.batch_norm(images, scale=True)
beta = variables.get_variables_by_name('beta')[0]
@@ -1915,7 +1915,7 @@ class BatchNormTest(test.TestCase):
def testMovingAverageVariables(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
_layers.batch_norm(images, scale=True)
self.assertEqual(len(variables.get_model_variables()), 4)
@@ -1926,7 +1926,7 @@ class BatchNormTest(test.TestCase):
def testMovingAverageVariablesZeroDebias(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
_layers.batch_norm(
images, scale=True, zero_debias_moving_mean=True, fused=False)
@@ -1943,7 +1943,7 @@ class BatchNormTest(test.TestCase):
def testUpdatesCollection(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
_layers.batch_norm(images, updates_collections='my_update_ops')
update_layers = ops.get_collection('my_update_ops')
@@ -1971,7 +1971,7 @@ class BatchNormTest(test.TestCase):
def testReuseVariables(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
_layers.batch_norm(images, scale=True, scope='bn')
_layers.batch_norm(images, scale=True, scope='bn', reuse=True)
@@ -1986,7 +1986,7 @@ class BatchNormTest(test.TestCase):
def testReuseUpdateOps(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
with arg_scope([_layers.batch_norm], updates_collections='update_ops'):
_layers.batch_norm(images, scope='bn')
@@ -1996,7 +1996,7 @@ class BatchNormTest(test.TestCase):
def testCreateMovingVars(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
_ = _layers.batch_norm(images)
moving_mean = variables.get_variables('BatchNorm/moving_mean')
@@ -2029,7 +2029,7 @@ class BatchNormTest(test.TestCase):
moving_variance = variables.get_variables_by_name('moving_variance')[0]
biased = variables.get_variables_by_name('biased')[0]
local_step = variables.get_variables_by_name('local_step')[0]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertAllClose(local_step.eval(), 0)
self.assertAllClose(moving_mean.eval(), [0] * channels)
@@ -2213,7 +2213,7 @@ class BatchNormTest(test.TestCase):
def _testEvalMovingVars(self, zero_debias_moving_mean=False):
height, width = 3, 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
image_shape = (10, height, width, 3)
image_values = np.random.rand(*image_shape)
expected_mean = np.mean(image_values, axis=(0, 1, 2))
@@ -2264,7 +2264,7 @@ class BatchNormTest(test.TestCase):
height, width = 3, 3
batch_size = 10
channels = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
image_shape = (batch_size, height, width, channels)
image_values = np.random.rand(*image_shape)
expected_mean = np.mean(image_values, axis=(0, 1, 2))
@@ -2435,7 +2435,7 @@ class BatchNormTest(test.TestCase):
def testNoUpdatesWhenIsTrainingFalse(self):
height, width = 3, 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
image_shape = (10, height, width, 3)
image_values = np.random.rand(*image_shape)
images = constant_op.constant(
@@ -2460,7 +2460,7 @@ class BatchNormTest(test.TestCase):
def testNoneUpdatesCollectionNoTraining(self):
height, width = 3, 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
image_shape = (10, height, width, 3)
image_values = np.random.rand(*image_shape)
images = constant_op.constant(
@@ -2647,7 +2647,7 @@ class BatchNormTest(test.TestCase):
def testCustomInitializer(self):
height, width = 3, 3
channels = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = (np.ones((5, height, width, channels)) * 9.0).astype('f')
beta = init_ops.constant_initializer(
(np.ones(channels) * 5.0).astype('f'))
@@ -2728,7 +2728,7 @@ class BatchNormTest(test.TestCase):
def testBatchNormBeta(self):
# Test case for 11673
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a_32 = array_ops.placeholder(dtypes.float32, shape=(10, 10, 10, 10))
_layers.batch_norm(
a_32, center=False, data_format='NCHW', zero_debias_moving_mean=True)
@@ -2739,7 +2739,7 @@ class BatchNormTest(test.TestCase):
def testVariablesAreFloat32(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, dtype=dtypes.float16)
_layers.batch_norm(images, scale=True)
@@ -2824,7 +2824,7 @@ class LayerNormTest(test.TestCase):
def testCreateOp(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = np.random.uniform(size=(5, height, width, 3))
output = _layers.layer_norm(images)
self.assertTrue(output.op.name.startswith('LayerNorm/batchnorm'))
@@ -2832,7 +2832,7 @@ class LayerNormTest(test.TestCase):
def testCreateVariables(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
_layers.layer_norm(images)
beta = variables.get_variables_by_name('beta')[0]
@@ -2842,7 +2842,7 @@ class LayerNormTest(test.TestCase):
def testReuseVariables(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
_layers.layer_norm(images, scope='ln')
_layers.layer_norm(images, scope='ln', reuse=True)
@@ -2853,7 +2853,7 @@ class LayerNormTest(test.TestCase):
def testReuseVars(self):
height, width = 3, 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
image_shape = (10, height, width, 3)
image_values = np.random.rand(*image_shape)
images = constant_op.constant(
@@ -2940,7 +2940,7 @@ class GDNTest(test.TestCase):
def _runGDN(self, x, shape, inverse, data_format):
inputs = array_ops.placeholder(dtypes.float32, shape)
outputs = _layers.gdn(inputs, inverse=inverse, data_format=data_format)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
y, = sess.run([outputs], {inputs: x})
return y
@@ -3152,14 +3152,14 @@ class MaxPool3DTest(test.TestCase):
class OneHotEncodingTest(test.TestCase):
def testOneHotEncodingCreate(self):
- with self.test_session():
+ with self.cached_session():
labels = np.array([0, 1, 2])
output = _layers.one_hot_encoding(labels, num_classes=3)
self.assertEqual(output.op.name, 'OneHotEncoding/one_hot')
self.assertListEqual(output.get_shape().as_list(), [3, 3])
def testCollectOutputs(self):
- with self.test_session():
+ with self.cached_session():
labels = constant_op.constant([0, 1, 2])
output = _layers.one_hot_encoding(
labels, num_classes=3, outputs_collections='outputs')
@@ -3168,14 +3168,14 @@ class OneHotEncodingTest(test.TestCase):
self.assertEqual(c_output, output)
def testOneHotEncoding(self):
- with self.test_session():
+ with self.cached_session():
labels = constant_op.constant([0, 1, 2])
one_hot_labels = constant_op.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
output = _layers.one_hot_encoding(labels, num_classes=3)
self.assertAllClose(output.eval(), one_hot_labels.eval())
def testOneHotEncodingInt32(self):
- with self.test_session():
+ with self.cached_session():
labels = constant_op.constant([0, 1, 2], dtype=dtypes.int32)
one_hot_labels = constant_op.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
output = _layers.one_hot_encoding(labels, num_classes=3)
@@ -3186,7 +3186,7 @@ class RepeatTests(test.TestCase):
def testRepeat(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = np.random.uniform(size=(5, height, width, 3)).astype(np.float32)
output = _layers.repeat(images, 3, layers_lib.conv2d, 32, [3, 3])
self.assertEqual(output.op.name, 'Repeat/convolution2d_3/Relu')
@@ -3194,7 +3194,7 @@ class RepeatTests(test.TestCase):
def testRepeatWithScope(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, name='images')
output = _layers.repeat(
@@ -3207,7 +3207,7 @@ class SeparableConv2dTest(test.TestCase):
def testCreateConvInt32(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, dtype=dtypes.int32, maxval=12345)
with self.assertRaisesRegexp(TypeError, 'non-floating point type'):
@@ -3215,7 +3215,7 @@ class SeparableConv2dTest(test.TestCase):
def testCreateConvFloat32(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, dtype=dtypes.float32)
output = layers_lib.separable_conv2d(images, 32, [3, 3], 2)
@@ -3224,7 +3224,7 @@ class SeparableConv2dTest(test.TestCase):
def testCreateDepthwiseConv(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.separable_conv2d(images, None, [3, 3], 2)
self.assertEqual(output.op.name, 'SeparableConv2d/Relu')
@@ -3233,7 +3233,7 @@ class SeparableConv2dTest(test.TestCase):
def testCreateConvCreatesWeightsAndBiasesVars(self):
height, width = 3, 3
images = random_ops.random_uniform((5, height, width, 3), seed=1)
- with self.test_session():
+ with self.cached_session():
self.assertFalse(variables.get_variables('conv1/depthwise_weights'))
self.assertFalse(variables.get_variables('conv1/pointwise_weights'))
self.assertFalse(variables.get_variables('conv1/biases'))
@@ -3245,7 +3245,7 @@ class SeparableConv2dTest(test.TestCase):
def testCreateAtrousConvCreatesWeightsAndBiasesVars(self):
height, width = 3, 3
images = random_ops.random_uniform((5, height, width, 3), seed=1)
- with self.test_session():
+ with self.cached_session():
self.assertFalse(variables.get_variables('conv1/depthwise_weights'))
self.assertFalse(variables.get_variables('conv1/pointwise_weights'))
self.assertFalse(variables.get_variables('conv1/biases'))
@@ -3257,7 +3257,7 @@ class SeparableConv2dTest(test.TestCase):
def testCreateDepthwiseConvCreatesWeightsAndBiasesVars(self):
height, width = 3, 3
images = random_ops.random_uniform((5, height, width, 3), seed=1)
- with self.test_session():
+ with self.cached_session():
self.assertFalse(variables.get_variables('conv1/depthwise_weights'))
self.assertFalse(variables.get_variables('conv1/pointwise_weights'))
self.assertFalse(variables.get_variables('conv1/biases'))
@@ -3268,14 +3268,14 @@ class SeparableConv2dTest(test.TestCase):
def testCreateConvWithScope(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.separable_conv2d(images, 32, [3, 3], 6, scope='conv1')
self.assertEqual(output.op.name, 'conv1/Relu')
def testCreateConvWithoutActivation(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.separable_conv2d(
images, 32, [3, 3], 8, activation_fn=None)
@@ -3283,7 +3283,7 @@ class SeparableConv2dTest(test.TestCase):
def testCreateConvValid(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.separable_conv2d(
images, 32, [3, 3], 2, padding='VALID')
@@ -3291,7 +3291,7 @@ class SeparableConv2dTest(test.TestCase):
def testCreateAtrousConvValid(self):
height, width = 5, 5
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.separable_conv2d(
images, 32, [3, 3], 2, padding='VALID', rate=2)
@@ -3299,7 +3299,7 @@ class SeparableConv2dTest(test.TestCase):
def testCreateDepthwiseConvValid(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.separable_conv2d(
images, None, [3, 3], 2, padding='VALID')
@@ -3307,7 +3307,7 @@ class SeparableConv2dTest(test.TestCase):
def testCreateAtrousDepthwiseConvValid(self):
height, width = 5, 5
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.separable_conv2d(
images, None, [3, 3], 2, padding='VALID', rate=2)
@@ -3316,7 +3316,7 @@ class SeparableConv2dTest(test.TestCase):
def testCreateConvWithWeightDecay(self):
random_seed.set_random_seed(0)
height, width = 3, 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = random_ops.random_uniform((5, height, width, 3), seed=1)
regularizer = regularizers.l2_regularizer(0.01)
layers_lib.separable_conv2d(
@@ -3360,7 +3360,7 @@ class SeparableConv2dTest(test.TestCase):
def testReuseConvWithWeightDecay(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
regularizer = regularizers.l2_regularizer(0.01)
layers_lib.separable_conv2d(
@@ -3419,7 +3419,7 @@ class SeparableConv2dTest(test.TestCase):
normalizer_params={},
scope='conv1')
init_op = variables_lib.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = np.random.rand(5, height, width, 3)
sess.run(init_op)
sess.run(net, feed_dict={images_placeholder: images})
@@ -3440,7 +3440,7 @@ class SeparableConv2dTest(test.TestCase):
def testSepConvNCHW(self):
for num_filters, correct_output_filters in zip((None, 5), (6, 5)):
- with self.test_session():
+ with self.cached_session():
batch, height, width = 4, 10, 12
kernel_dim, stride = 3, 2
images = random_ops.random_uniform((batch, 3, height, width), seed=1)
@@ -3462,7 +3462,7 @@ class ScaleGradientTests(test.TestCase):
"""Simple tests of the scale_gradient function."""
def testBasic(self):
- with self.test_session():
+ with self.cached_session():
x = np.array([42], np.float32)
gradient_scale = np.array([2], np.float32)
@@ -3513,7 +3513,7 @@ class SoftmaxTests(test.TestCase):
exp_prediction = np.array([[self.low, self.high], [0.5, 0.5],
[self.high, self.low]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
prediction = sess.run(prediction)
self.assertAllClose(exp_prediction, prediction)
@@ -3529,7 +3529,7 @@ class SoftmaxTests(test.TestCase):
exp_prediction[1, 1, 1] = self.low
prediction = _layers.softmax(logits)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
prediction = sess.run(prediction)
self.assertAllClose(exp_prediction, prediction)
@@ -3547,7 +3547,7 @@ class SoftmaxTests(test.TestCase):
exp_prediction[1, 1, 1] = self.low
prediction = _layers.softmax(logit_placeholder)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
prediction = sess.run(prediction, feed_dict=feed_dict)
self.assertAllClose(exp_prediction, prediction)
@@ -3575,7 +3575,7 @@ class SpatialSoftmaxTests(test.TestCase):
features = array_ops.placeholder(dtypes.float32, shape=batch_shape)
np_features = np.zeros(batch_shape, dtype=np.float32)
spatial_softmax = _layers.spatial_softmax(features)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
feed_dict = {features: np_features}
keypoints = sess.run(spatial_softmax, feed_dict)
@@ -3586,7 +3586,7 @@ class SpatialSoftmaxTests(test.TestCase):
features = array_ops.placeholder(dtypes.float32, shape=batch_shape)
np_features = np.zeros(batch_shape, dtype=np.float32)
spatial_softmax = _layers.spatial_softmax(features, data_format='NCHW')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
feed_dict = {features: np_features}
keypoints = sess.run(spatial_softmax, feed_dict)
@@ -3613,7 +3613,7 @@ class SpatialSoftmaxTests(test.TestCase):
nchannels)
# Make sure expected location keypoints matches actual location keypoints.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
feed_dict = {features: np_features}
keypoints = sess.run(spatial_softmax, feed_dict)
@@ -3637,7 +3637,7 @@ class SpatialSoftmaxTests(test.TestCase):
nchannels)
# Make sure expected location keypoints matches actual location keypoints.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
feed_dict = {features: np_features}
keypoints = sess.run(spatial_softmax, feed_dict)
@@ -3669,7 +3669,7 @@ class SpatialSoftmaxTests(test.TestCase):
batch_size, nchannels)
# Make sure expected location keypoints matches actual location keypoints.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
feed_dict = {features: np_features1}
tf_keypoints1 = sess.run(spatial_softmax, feed_dict)
@@ -3696,7 +3696,7 @@ class SpatialSoftmaxTests(test.TestCase):
nchannels)
# Make sure expected location keypoints matches actual location keypoints.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
feed_dict = {features: np_features}
keypoints = sess.run(spatial_softmax, feed_dict)
@@ -3719,7 +3719,7 @@ class SpatialSoftmaxTests(test.TestCase):
nchannels)
# Make sure expected location keypoints matches actual location keypoints.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
feed_dict = {features: np_features}
keypoints = sess.run(spatial_softmax, feed_dict)
@@ -3731,7 +3731,7 @@ class SpatialSoftmaxTests(test.TestCase):
spatial_softmax = _layers.spatial_softmax(features)
net = _layers.fully_connected(spatial_softmax, 10)
np_features = np.zeros(batch_shape, dtype=np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
feed_dict = {features: np_features}
sess.run(net, feed_dict)
@@ -3741,7 +3741,7 @@ class StackTests(test.TestCase):
def testStackFullyConnected(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = np.random.uniform(size=(5, height * width * 3))
output = _layers.stack(images, _layers.fully_connected, [10, 20, 30])
self.assertEqual(output.op.name, 'Stack/fully_connected_3/Relu')
@@ -3749,7 +3749,7 @@ class StackTests(test.TestCase):
def testStackFullyConnectedFailOnReuse(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('test', reuse=True):
images = np.random.uniform(size=(5, height * width * 3))
with self.assertRaises(ValueError):
@@ -3757,7 +3757,7 @@ class StackTests(test.TestCase):
def testStackRelu(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform(
(5, height * width * 3), seed=1, name='images')
output = _layers.stack(images, layers_lib.relu, [10, 20, 30])
@@ -3766,7 +3766,7 @@ class StackTests(test.TestCase):
def testStackElu(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform(
(5, height * width * 3), seed=1, name='images')
output = _layers.stack(images, layers_lib.elu, [10, 20, 30])
@@ -3775,7 +3775,7 @@ class StackTests(test.TestCase):
def testStackConvolution2d(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, name='images')
output = _layers.stack(
@@ -3788,7 +3788,7 @@ class StackTests(test.TestCase):
def testStackWithScope(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, name='images')
output = _layers.stack(
@@ -3817,7 +3817,7 @@ class UnitNormTests(test.TestCase):
del shape[dim]
expected = np.ones(shape)
- with self.test_session():
+ with self.cached_session():
actual = norms.eval()
self.assertAllClose(expected, actual, 1e-4, 1e-4)
@@ -3849,7 +3849,7 @@ class UnitNormTests(test.TestCase):
norms = math_ops.sqrt(
math_ops.reduce_sum(math_ops.square(output), reduction_indices=dim))
- with self.test_session():
+ with self.cached_session():
actual = norms.eval({image: placeholder_value})
self.assertAllClose(expected, actual, 1e-4, 1e-4)
@@ -3875,7 +3875,7 @@ class PoincareNormalizeTest(test.TestCase):
x_np = np.random.random_sample(x_shape).astype(np.float32)
for dim in range(len(x_shape)):
y_np = self._PoincareNormalize(x_np, dim, epsilon)
- with self.test_session():
+ with self.cached_session():
x_tf = constant_op.constant(x_np, name='x')
y_tf = _layers.poincare_normalize(x_tf, dim, epsilon)
y_tf_eval = y_tf.eval()
@@ -3893,7 +3893,7 @@ class PoincareNormalizeTest(test.TestCase):
x_np = np.random.random_sample(x_shape).astype(np.float32)
dim = [1, 2]
y_np = self._PoincareNormalize(x_np, dim, epsilon)
- with self.test_session():
+ with self.cached_session():
x_tf = constant_op.constant(x_np, name='x')
y_tf = _layers.poincare_normalize(x_tf, dim, epsilon)
y_tf_eval = y_tf.eval()
@@ -3908,7 +3908,7 @@ class PoincareNormalizeTest(test.TestCase):
np.random.seed(1)
x_np = np.random.random_sample(x_shape).astype(np.float64)
for dim in range(len(x_shape)):
- with self.test_session():
+ with self.cached_session():
x_tf = constant_op.constant(x_np, name='x')
y_tf = _layers.poincare_normalize(x_tf, dim)
err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf,
@@ -4117,7 +4117,7 @@ class LegacyFullyConnectedTest(test.TestCase):
# Empty x is common if someone masks their input with tf.boolean_mask in
# order to drop missing entries, and in a particular batch all entries are
# missing.
- with self.test_session():
+ with self.cached_session():
x = np.array([]).reshape(0, 3)
self.assertEqual(0, array_ops.size(x).eval())
y = _layers.legacy_fully_connected(x, 2, activation_fn=nn_ops.softmax)
@@ -4131,7 +4131,7 @@ class LegacyFullyConnectedTest(test.TestCase):
y = _layers.legacy_fully_connected(x, 1)
# in the output we still only know the 2nd and 3rd dimensions statically.
self.assertEqual(y.get_shape().as_list(), [None, 4, 1])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
# we can feed in input with first dimension 2
shape_value = sess.run(
@@ -4162,7 +4162,7 @@ class LegacyFullyConnectedTest(test.TestCase):
self._unknown_dim_invalid_input(last_dim=None)
def test_1d_invalid_input(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError,
'rank of x must be at least 2 not: 1'):
x = constant_op.constant([[]], shape=[0])
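Every hunk in this file makes the same mechanical substitution: the deprecated self.test_session() helper is replaced with self.cached_session(), which returns a session that is reused across calls within a single test method. The sketch below illustrates the two usage styles seen in these hunks (explicit "as sess" with sess.run, and relying on the default session so Tensor.eval() works); it is a minimal example assuming a TensorFlow 1.x environment, and the test class and tensors are hypothetical, not part of this commit.

# Minimal sketch of the cached_session() pattern, assuming TensorFlow 1.x.
# The class name and tensors are illustrative only.
import tensorflow as tf


class CachedSessionExampleTest(tf.test.TestCase):

  def test_run_with_explicit_session(self):
    x = tf.constant([1.0, 2.0, 3.0])
    y = x * 2.0
    # cached_session() yields a session that is reused across calls made
    # within the same test method, rather than constructing a new one for
    # every `with` block.
    with self.cached_session() as sess:
      self.assertAllClose(sess.run(y), [2.0, 4.0, 6.0])

  def test_eval_with_default_session(self):
    x = tf.constant([1.0, 2.0, 3.0])
    y = x + 1.0
    # Entering the context also installs the session as the default one,
    # so Tensor.eval() works without passing a session explicitly.
    with self.cached_session():
      self.assertAllClose(y.eval(), [2.0, 3.0, 4.0])


if __name__ == '__main__':
  tf.test.main()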
diff --git a/tensorflow/contrib/layers/python/layers/normalization_test.py b/tensorflow/contrib/layers/python/layers/normalization_test.py
index 55272e5fd1..c8d3c91b10 100644
--- a/tensorflow/contrib/layers/python/layers/normalization_test.py
+++ b/tensorflow/contrib/layers/python/layers/normalization_test.py
@@ -106,7 +106,7 @@ class InstanceNormTest(test.TestCase):
images = random_ops.random_uniform(image_shape, seed=1)
output_train = normalization.instance_norm(images, scope='IN')
output_eval = normalization.instance_norm(images, scope='IN', reuse=True)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
# output_train and output_eval should be the same.
train_np, eval_np = sess.run([output_train, output_eval])
@@ -130,7 +130,7 @@ class InstanceNormTest(test.TestCase):
inputs = random_ops.random_uniform(input_shape, seed=0) * sigma + mu
output_op = normalization.instance_norm(
inputs, center=False, scale=False, data_format=data_format)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
outputs = sess.run(output_op)
# Make sure that there are no NaNs
@@ -287,7 +287,7 @@ class GroupNormTest(test.TestCase):
output_train = normalization.group_norm(images, groups=2, scope='IN')
output_eval = normalization.group_norm(images, groups=2, scope='IN',
reuse=True)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
# output_train and output_eval should be the same.
train_np, eval_np = sess.run([output_train, output_eval])
@@ -349,7 +349,7 @@ class GroupNormTest(test.TestCase):
channels_axis=channels_axis,
reduction_axes=reduction_axes,
mean_close_to_zero=mean_close_to_zero)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
outputs = sess.run(output_op)
# Make sure that there are no NaNs
diff --git a/tensorflow/contrib/layers/python/layers/optimizers_test.py b/tensorflow/contrib/layers/python/layers/optimizers_test.py
index 0f037e24ad..29dede2a49 100644
--- a/tensorflow/contrib/layers/python/layers/optimizers_test.py
+++ b/tensorflow/contrib/layers/python/layers/optimizers_test.py
@@ -165,7 +165,7 @@ class OptimizersTest(test.TestCase):
def testGradientNoise(self):
random_seed.set_random_seed(42)
- with self.test_session() as session:
+ with self.cached_session() as session:
x, var, loss, global_step = _setup_model()
train = optimizers_lib.optimize_loss(
loss,
@@ -182,7 +182,7 @@ class OptimizersTest(test.TestCase):
def testGradientNoiseWithClipping(self):
random_seed.set_random_seed(42)
- with self.test_session() as session:
+ with self.cached_session() as session:
x, var, loss, global_step = _setup_model()
train = optimizers_lib.optimize_loss(
loss,
@@ -198,7 +198,7 @@ class OptimizersTest(test.TestCase):
self.assertEqual(global_step_value, 1)
def testGradientClip(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
x, var, loss, global_step = _setup_model()
train = optimizers_lib.optimize_loss(
loss,
@@ -213,7 +213,7 @@ class OptimizersTest(test.TestCase):
self.assertEqual(global_step_value, 1)
def testAdaptiveGradientClip(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
x, var, loss, global_step = _setup_model()
clip_gradients = optimizers_lib.adaptive_clipping_fn()
train = optimizers_lib.optimize_loss(
@@ -234,7 +234,7 @@ class OptimizersTest(test.TestCase):
self.assertEqual(2, var_count)
def testGradientMultiply(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
x, var, loss, global_step = _setup_model()
train = optimizers_lib.optimize_loss(
loss,
@@ -433,7 +433,7 @@ class OptimizersTest(test.TestCase):
class AdaptiveClipping(test.TestCase):
def testAverages(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
scale = 2.
grad = array_ops.ones([3, 4]) * scale
log_norm = np.log(np.sqrt(scale**2 * grad.get_shape().num_elements()))
@@ -463,7 +463,7 @@ class AdaptiveClipping(test.TestCase):
self.assertAlmostEqual(float(sq_mean), log_norm**2, places=4)
def testClip(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
spike = 1000.
multiplier = array_ops.placeholder(dtypes.float32, [], "multiplier")
step = array_ops.placeholder(dtypes.int32, [], "step")
diff --git a/tensorflow/contrib/layers/python/layers/regularizers_test.py b/tensorflow/contrib/layers/python/layers/regularizers_test.py
index 07191eeda7..51faba30c7 100644
--- a/tensorflow/contrib/layers/python/layers/regularizers_test.py
+++ b/tensorflow/contrib/layers/python/layers/regularizers_test.py
@@ -71,7 +71,7 @@ class RegularizerTest(test.TestCase):
with self.assertRaises(ValueError):
regularizers.l1_l2_regularizer(0.5, 0)
- with self.test_session():
+ with self.cached_session():
shape = [5, 5, 5]
num_elem = 5 * 5 * 5
tensor = constant_op.constant(1.0, shape=shape)
@@ -84,7 +84,7 @@ class RegularizerTest(test.TestCase):
num_elem = 5 * 5 * 5
tensor = constant_op.constant(1.0, shape=shape)
loss = regularizers.l1_l2_regularizer(0.0, 1.0)(tensor)
- with self.test_session():
+ with self.cached_session():
self.assertEquals(loss.op.name, 'l1_l2_regularizer')
self.assertAlmostEqual(loss.eval(), num_elem / 2, 5)
@@ -93,7 +93,7 @@ class RegularizerTest(test.TestCase):
num_elem = 5 * 5 * 5
tensor = constant_op.constant(1.0, shape=shape)
loss = regularizers.l1_l2_regularizer(1.0, 0.0)(tensor)
- with self.test_session():
+ with self.cached_session():
self.assertEquals(loss.op.name, 'l1_l2_regularizer')
self.assertAlmostEqual(loss.eval(), num_elem, 5)
@@ -104,7 +104,7 @@ class RegularizerTest(test.TestCase):
self.assertEquals(loss, None)
def testL1L2RegularizerWithScope(self):
- with self.test_session():
+ with self.cached_session():
shape = [5, 5, 5]
num_elem = 5 * 5 * 5
tensor = constant_op.constant(1.0, shape=shape)
@@ -142,7 +142,7 @@ class RegularizerTest(test.TestCase):
array_weights_list = [[1.5], [2, 3, 4.2], [10, 42, 666.6]]
tensor_weights_list = [constant_op.constant(x) for x in array_weights_list]
expected = sum([2 * x for l in array_weights_list for x in l])
- with self.test_session():
+ with self.cached_session():
result = regularizers.apply_regularization(dummy_regularizer,
tensor_weights_list)
self.assertAllClose(expected, result.eval())
@@ -151,7 +151,7 @@ class RegularizerTest(test.TestCase):
regularizer = regularizers.l2_regularizer(0.0)
array_weights_list = [[1.5], [2, 3, 4.2], [10, 42, 666.6]]
tensor_weights_list = [constant_op.constant(x) for x in array_weights_list]
- with self.test_session():
+ with self.cached_session():
result = regularizers.apply_regularization(regularizer,
tensor_weights_list)
self.assertAllClose(0.0, result.eval())
@@ -161,7 +161,7 @@ class RegularizerTest(test.TestCase):
tensor_weights_list = [
constant_op.constant(x) for x in [[1.5], [2, 3, 4.2], [10, 42, 666.6]]
]
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
regularizers.apply_regularization(non_scalar_regularizer,
tensor_weights_list)
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
index c34b5a8017..2c7463acc0 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
@@ -58,7 +58,7 @@ class RevBlockTest(test.TestCase):
y1, y2 = block.forward(x1, x2)
x1_inv, x2_inv = block.backward(y1, y2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
x1, x2, x1_inv, x2_inv = sess.run([x1, x2, x1_inv, x2_inv])
@@ -81,7 +81,7 @@ class RevBlockTest(test.TestCase):
x1, x2 = block.backward(y1, y2)
y1_inv, y2_inv = block.forward(x1, x2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
y1, y2, y1_inv, y2_inv = sess.run([y1, y2, y1_inv, y2_inv])
@@ -151,7 +151,7 @@ class RevBlockTest(test.TestCase):
grads_rev = gradients_impl.gradients(loss_rev, wrt)
grads = gradients_impl.gradients(loss, wrt)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
y_val, yd_val, gd_val, g_val = sess.run([y, y_rev, grads_rev, grads])
self.assertAllClose(y_val, yd_val)
@@ -286,7 +286,7 @@ class RecomputeTest(test.TestCase):
for out, scope_vars in outputs_and_vars:
all_grads.append(gradients_impl.gradients(out, scope_vars))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
outputs = list(zip(*outputs_and_vars))[0]
outs, all_grads_val = sess.run([outputs, all_grads])
@@ -389,7 +389,7 @@ class RecomputeTest(test.TestCase):
layer_list.append(math_ops.sqrt(concat_n_wrap(*layer_list)))
grads = gradients_impl.gradients(layer_list[-1], layer_list[0])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(grads)
def testErrorOnClosedOverTensor(self):
diff --git a/tensorflow/contrib/layers/python/layers/summaries_test.py b/tensorflow/contrib/layers/python/layers/summaries_test.py
index a1ef06feec..2ec2af9d44 100644
--- a/tensorflow/contrib/layers/python/layers/summaries_test.py
+++ b/tensorflow/contrib/layers/python/layers/summaries_test.py
@@ -29,19 +29,19 @@ from tensorflow.python.platform import test
class SummariesTest(test.TestCase):
def test_summarize_scalar_tensor(self):
- with self.test_session():
+ with self.cached_session():
scalar_var = variables.Variable(1)
summary_op = summaries_lib.summarize_tensor(scalar_var)
self.assertEquals(summary_op.op.type, 'ScalarSummary')
def test_summarize_multidim_tensor(self):
- with self.test_session():
+ with self.cached_session():
tensor_var = variables.Variable([1, 2, 3])
summary_op = summaries_lib.summarize_tensor(tensor_var)
self.assertEquals(summary_op.op.type, 'HistogramSummary')
def test_summarize_activation(self):
- with self.test_session():
+ with self.cached_session():
var = variables.Variable(1)
op = array_ops.identity(var, name='SummaryTest')
summary_op = summaries_lib.summarize_activation(op)
@@ -52,7 +52,7 @@ class SummariesTest(test.TestCase):
self.assertIn(u'SummaryTest/activation', names)
def test_summarize_activation_relu(self):
- with self.test_session():
+ with self.cached_session():
var = variables.Variable(1)
op = nn_ops.relu(var, name='SummaryTest')
summary_op = summaries_lib.summarize_activation(op)
@@ -64,7 +64,7 @@ class SummariesTest(test.TestCase):
self.assertIn(u'SummaryTest/activation', names)
def test_summarize_activation_relu6(self):
- with self.test_session():
+ with self.cached_session():
var = variables.Variable(1)
op = nn_ops.relu6(var, name='SummaryTest')
summary_op = summaries_lib.summarize_activation(op)
@@ -77,7 +77,7 @@ class SummariesTest(test.TestCase):
self.assertIn(u'SummaryTest/activation', names)
def test_summarize_collection_regex(self):
- with self.test_session():
+ with self.cached_session():
var = variables.Variable(1)
array_ops.identity(var, name='Test1')
ops.add_to_collection('foo', array_ops.identity(var, name='Test2'))
diff --git a/tensorflow/contrib/layers/python/layers/utils_test.py b/tensorflow/contrib/layers/python/layers/utils_test.py
index a9bd89532a..34f63f5d86 100644
--- a/tensorflow/contrib/layers/python/layers/utils_test.py
+++ b/tensorflow/contrib/layers/python/layers/utils_test.py
@@ -42,7 +42,7 @@ class ConstantValueTest(test.TestCase):
c = constant_op.constant(v)
value = utils.constant_value(c)
self.assertEqual(value, v)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(c.eval(), v)
def test_variable(self):
@@ -60,7 +60,7 @@ class ConstantValueTest(test.TestCase):
x = array_ops.identity(p)
value = utils.constant_value(p)
self.assertEqual(value, None)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(x.eval(feed_dict={p: v}), v)
@@ -80,7 +80,7 @@ class StaticCondTest(test.TestCase):
expected = lambda v: b'fn1' if v else b'fn2'
for v in [True, False, 1, 0]:
o = utils.static_cond(v, fn1, fn2)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(o.eval(), expected(v))
def test_variable(self):
@@ -89,7 +89,7 @@ class StaticCondTest(test.TestCase):
expected = lambda v: b'fn1' if v else b'fn2'
for v in [True, False, 1, 0]:
o = utils.static_cond(v, fn1, fn2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
self.assertEqual(o.eval(), expected(v))
@@ -99,7 +99,7 @@ class StaticCondTest(test.TestCase):
expected = lambda v: -1 if v else -2
for v in [True, False, 1, 0]:
o = utils.static_cond(v, fn1, fn2)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(o.eval(), expected(v))
@@ -119,7 +119,7 @@ class SmartCondStaticTest(test.TestCase):
expected = lambda v: b'fn1' if v else b'fn2'
for v in [True, False, 1, 0]:
o = utils.smart_cond(constant_op.constant(v), fn1, fn2)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(o.eval(), expected(v))
def test_variable(self):
@@ -128,7 +128,7 @@ class SmartCondStaticTest(test.TestCase):
expected = lambda v: b'fn1' if v else b'fn2'
for v in [True, False, 1, 0]:
o = utils.smart_cond(constant_op.constant(v), fn1, fn2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
self.assertEqual(o.eval(), expected(v))
@@ -138,7 +138,7 @@ class SmartCondStaticTest(test.TestCase):
expected = lambda v: -1 if v else -2
for v in [True, False, 1, 0]:
o = utils.smart_cond(constant_op.constant(v), fn1, fn2)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(o.eval(), expected(v))
@@ -151,7 +151,7 @@ class SmartCondDynamicTest(test.TestCase):
p = array_ops.placeholder(dtypes.bool, [])
for v in [True, False, 1, 0]:
o = utils.smart_cond(p, fn1, fn2)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
def test_constant(self):
@@ -161,7 +161,7 @@ class SmartCondDynamicTest(test.TestCase):
p = array_ops.placeholder(dtypes.bool, [])
for v in [True, False, 1, 0]:
o = utils.smart_cond(p, fn1, fn2)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
def test_variable(self):
@@ -171,7 +171,7 @@ class SmartCondDynamicTest(test.TestCase):
p = array_ops.placeholder(dtypes.bool, [])
for v in [True, False, 1, 0]:
o = utils.smart_cond(p, fn1, fn2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
@@ -182,7 +182,7 @@ class SmartCondDynamicTest(test.TestCase):
p = array_ops.placeholder(dtypes.bool, [])
for v in [True, False, 1, 0]:
o = utils.smart_cond(p, fn1, fn2)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
diff --git a/tensorflow/contrib/layers/python/ops/sparse_ops_test.py b/tensorflow/contrib/layers/python/ops/sparse_ops_test.py
index d50750001e..b6c2cab64a 100644
--- a/tensorflow/contrib/layers/python/ops/sparse_ops_test.py
+++ b/tensorflow/contrib/layers/python/ops/sparse_ops_test.py
@@ -42,7 +42,7 @@ def _assert_sparse_tensor_value(test_case, expected, actual):
class DenseToSparseTensorTest(test.TestCase):
def test_dense_to_sparse_tensor_1d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
st = sparse_ops.dense_to_sparse_tensor([1, 0, 2, 0])
result = sess.run(st)
self.assertEqual(result.indices.dtype, np.int64)
@@ -53,7 +53,7 @@ class DenseToSparseTensorTest(test.TestCase):
self.assertAllEqual([4], result.dense_shape)
def test_dense_to_sparse_tensor_1d_float(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
st = sparse_ops.dense_to_sparse_tensor([1.5, 0.0, 2.3, 0.0])
result = sess.run(st)
self.assertEqual(result.indices.dtype, np.int64)
@@ -64,7 +64,7 @@ class DenseToSparseTensorTest(test.TestCase):
self.assertAllEqual([4], result.dense_shape)
def test_dense_to_sparse_tensor_1d_bool(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
st = sparse_ops.dense_to_sparse_tensor([True, False, True, False])
result = sess.run(st)
self.assertEqual(result.indices.dtype, np.int64)
@@ -75,7 +75,7 @@ class DenseToSparseTensorTest(test.TestCase):
self.assertAllEqual([4], result.dense_shape)
def test_dense_to_sparse_tensor_1d_str(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
st = sparse_ops.dense_to_sparse_tensor([b'qwe', b'', b'ewq', b''])
result = sess.run(st)
self.assertEqual(result.indices.dtype, np.int64)
@@ -86,7 +86,7 @@ class DenseToSparseTensorTest(test.TestCase):
self.assertAllEqual([4], result.dense_shape)
def test_dense_to_sparse_tensor_1d_str_special_ignore(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
st = sparse_ops.dense_to_sparse_tensor(
[b'qwe', b'', b'ewq', b''], ignore_value=b'qwe')
result = sess.run(st)
@@ -98,7 +98,7 @@ class DenseToSparseTensorTest(test.TestCase):
self.assertAllEqual([4], result.dense_shape)
def test_dense_to_sparse_tensor_2d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
st = sparse_ops.dense_to_sparse_tensor([[1, 2, 0, 0], [3, 4, 5, 0]])
result = sess.run(st)
self.assertAllEqual([[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]],
@@ -107,7 +107,7 @@ class DenseToSparseTensorTest(test.TestCase):
self.assertAllEqual([2, 4], result.dense_shape)
def test_dense_to_sparse_tensor_3d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
st = sparse_ops.dense_to_sparse_tensor([[[1, 2, 0, 0], [3, 4, 5, 0]],
[[7, 8, 0, 0], [9, 0, 0, 0]]])
result = sess.run(st)
@@ -117,7 +117,7 @@ class DenseToSparseTensorTest(test.TestCase):
self.assertAllEqual([2, 2, 4], result.dense_shape)
def test_dense_to_sparse_tensor_unknown_1d_shape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tensor = array_ops.placeholder(shape=[None], dtype=dtypes.int32)
st = sparse_ops.dense_to_sparse_tensor(tensor)
result = sess.run(st, feed_dict={tensor: [0, 100, 0, 3]})
@@ -126,7 +126,7 @@ class DenseToSparseTensorTest(test.TestCase):
self.assertAllEqual([4], result.dense_shape)
def test_dense_to_sparse_tensor_unknown_3d_shape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tensor = array_ops.placeholder(
shape=[None, None, None], dtype=dtypes.int32)
st = sparse_ops.dense_to_sparse_tensor(tensor)
@@ -142,7 +142,7 @@ class DenseToSparseTensorTest(test.TestCase):
def test_dense_to_sparse_unknown_rank(self):
ph = array_ops.placeholder(dtype=dtypes.int32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
st = sparse_ops.dense_to_sparse_tensor(ph)
result = sess.run(st, feed_dict={ph: [[1, 2, 0, 0], [3, 4, 5, 0]]})
self.assertAllEqual([[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]],
@@ -155,7 +155,7 @@ class SparseRowEnvelopeTest(test.TestCase):
def test_sparse_row_envelope(self):
expected_sparse_row_envelope = [1, 0, 3]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sparse_input = sparse_tensor.SparseTensor(
indices=[[0, 0], [2, 0], [2, 1], [2, 2]],
values=[0, 1, 2, 3],
@@ -167,7 +167,7 @@ class SparseRowEnvelopeTest(test.TestCase):
def test_sparse_row_envelope_unsorted_indices(self):
expected_sparse_row_envelope = [1, 0, 3]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sparse_input = sparse_tensor.SparseTensor(
indices=[[2, 0], [2, 2], [2, 1], [0, 0]],
values=[0, 1, 2, 3],
@@ -179,7 +179,7 @@ class SparseRowEnvelopeTest(test.TestCase):
def test_sparse_row_envelope_empty_in_the_end(self):
expected_sparse_row_envelope = [1, 0, 3, 0, 0]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sparse_input = sparse_tensor.SparseTensor(
indices=[[0, 0], [2, 0], [2, 1], [2, 2]],
values=[0, 1, 2, 3],
@@ -191,7 +191,7 @@ class SparseRowEnvelopeTest(test.TestCase):
def test_sparse_row_envelope_empty_3d(self):
expected_sparse_row_envelope = [1, 0, 3, 0, 0]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sparse_input = sparse_tensor.SparseTensor(
indices=[[0, 0, 0], [0, 2, 0], [0, 2, 1], [0, 2, 2]],
values=[0, 1, 2, 3],
@@ -207,7 +207,7 @@ class IndicatorToSparseIdsTest(test.TestCase):
def test_indicators_to_sparse_ids_1d(self):
indicators = (0, 0, 1, 0)
sparse_ids = sparse_ops.indicators_to_sparse_ids(indicators)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
indices=((0,),),
values=(2,),
@@ -220,7 +220,7 @@ class IndicatorToSparseIdsTest(test.TestCase):
(1, 0, 0, 1),
)
sparse_ids = sparse_ops.indicators_to_sparse_ids(indicators)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
values=(2, 0, 3),
@@ -235,7 +235,7 @@ class IndicatorToSparseIdsTest(test.TestCase):
((1, 0, 0, 1, 1), (0, 0, 1, 0, 0)),
)
sparse_ids = sparse_ops.indicators_to_sparse_ids(indicators)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
indices=(
(0, 0, 0),
@@ -255,7 +255,7 @@ class IndicatorToSparseIdsTest(test.TestCase):
)
sparse_ids = sparse_ops.indicators_to_sparse_ids(
indicators, dtype=dtypes.int16)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
values=np.array((2, 0, 3), dtype=np.int16),
@@ -269,7 +269,7 @@ class IndicatorToSparseIdsTest(test.TestCase):
)
sparse_ids = sparse_ops.indicators_to_sparse_ids(
indicators, ignore_value=-1)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)),
values=(2, 0, 3, 2),
@@ -282,7 +282,7 @@ class IndicatorToSparseIdsTest(test.TestCase):
(('B', '', '', 'C'), ('', '', 'D', '')),
)
sparse_ids = sparse_ops.indicators_to_sparse_ids(indicators)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)),
values=(2, 0, 3, 2),
@@ -296,7 +296,7 @@ class IndicatorToSparseIdsTest(test.TestCase):
)
sparse_ids = sparse_ops.indicators_to_sparse_ids(
indicators, ignore_value='x')
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)),
values=(2, 0, 3, 2),
@@ -311,7 +311,7 @@ class IndicatorToSparseIdsTest(test.TestCase):
indicators = array_ops.placeholder(
dtype=dtypes.int32, shape=(None, None, None))
sparse_ids = sparse_ops.indicators_to_sparse_ids(indicators)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)),
values=(2, 0, 3, 2),
@@ -325,7 +325,7 @@ class IndicatorToSparseIdsTest(test.TestCase):
)
indicators = array_ops.placeholder(dtype=dtypes.int32)
sparse_ids = sparse_ops.indicators_to_sparse_ids(indicators)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)),
values=(2, 0, 3, 2),
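
A note on the test_session() -> cached_session() substitutions in this and the
following contrib test files: cached_session() hands back one session that is
reused across calls within the same test method, instead of building a fresh
session per `with` block, so the swap is purely mechanical. A minimal sketch of
the pattern, assuming a TF 1.x-style test case (the test body here is
illustrative, not taken from the diff):

    import tensorflow as tf

    class SparseOpsSmokeTest(tf.test.TestCase):

      def test_round_trip(self):
        # cached_session() reuses a single session for the whole test method,
        # so entering the context manager repeatedly is cheap.
        with self.cached_session() as sess:
          x = tf.constant([1, 2, 3])
          self.assertAllEqual([1, 2, 3], sess.run(x))

    if __name__ == "__main__":
      tf.test.main()
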
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py
index 5e07b9313f..284a4f45f6 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py
@@ -147,7 +147,7 @@ class DataFeederTest(test.TestCase):
def test_unsupervised(self):
def func(feeder):
- with self.test_session():
+ with self.cached_session():
inp, _ = feeder.input_builder()
feed_dict_fn = feeder.get_feed_dict_fn()
feed_dict = feed_dict_fn()
@@ -181,7 +181,7 @@ class DataFeederTest(test.TestCase):
def test_epoch(self):
def func(feeder):
- with self.test_session():
+ with self.cached_session():
feeder.input_builder()
epoch = feeder.make_epoch_variable()
feed_dict_fn = feeder.get_feed_dict_fn()
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py
index 7e81f2b7d9..5e90d1fa20 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py
@@ -38,7 +38,7 @@ class GeneratorIoTest(test.TestCase):
'label': np.ones(1) * index - 32
}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = generator_io.generator_input_fn(
generator,
target_key='label',
@@ -68,7 +68,7 @@ class GeneratorIoTest(test.TestCase):
for index in range(2):
yield {'a': np.ones(1) * index}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = generator_io.generator_input_fn(
generator, target_key=None, batch_size=2, shuffle=False, num_epochs=1)
features = input_fn()
@@ -97,7 +97,7 @@ class GeneratorIoTest(test.TestCase):
'label2': np.ones(1) * index - 64,
}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = generator_io.generator_input_fn(
generator,
target_key=['label', 'label2'],
@@ -134,7 +134,7 @@ class GeneratorIoTest(test.TestCase):
'label': np.ones((3, 3)) * index - 32
}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = generator_io.generator_input_fn(
generator,
target_key='label',
@@ -162,7 +162,7 @@ class GeneratorIoTest(test.TestCase):
def testGeneratorInputFnWithXAsNonGeneratorFunction(self):
x = np.arange(32, 36)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, 'x must be generator function'):
failing_input_fn = generator_io.generator_input_fn(
x, batch_size=2, shuffle=False, num_epochs=1)
@@ -173,7 +173,7 @@ class GeneratorIoTest(test.TestCase):
def generator():
return np.arange(32, 36)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, 'x\(\) must be generator'):
failing_input_fn = generator_io.generator_input_fn(
generator, batch_size=2, shuffle=False, num_epochs=1)
@@ -184,7 +184,7 @@ class GeneratorIoTest(test.TestCase):
def generator():
yield np.arange(32, 36)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, 'x\(\) must yield dict'):
failing_input_fn = generator_io.generator_input_fn(
generator, batch_size=2, shuffle=False, num_epochs=1)
@@ -201,7 +201,7 @@ class GeneratorIoTest(test.TestCase):
}
y = np.arange(32, 36)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, 'target_key must be str or'
' Container of str'):
failing_input_fn = generator_io.generator_input_fn(
@@ -219,7 +219,7 @@ class GeneratorIoTest(test.TestCase):
}
y = ['label', np.arange(10)]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, 'target_key must be str or'
' Container of str'):
failing_input_fn = generator_io.generator_input_fn(
@@ -237,7 +237,7 @@ class GeneratorIoTest(test.TestCase):
}
y = ['label', 'target']
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(KeyError, 'target_key not in yielded dict'):
failing_input_fn = generator_io.generator_input_fn(
generator, target_key=y, batch_size=2, shuffle=False, num_epochs=1)
@@ -253,7 +253,7 @@ class GeneratorIoTest(test.TestCase):
'label': np.ones(1) * index - 32
}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = generator_io.generator_input_fn(
generator, target_key=None, batch_size=2, shuffle=False, num_epochs=1)
features = input_fn()
@@ -283,7 +283,7 @@ class GeneratorIoTest(test.TestCase):
'label': np.ones(1) * index - 32
}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = generator_io.generator_input_fn(
generator, target_key=None, batch_size=4, shuffle=False, num_epochs=1)
features = input_fn()
@@ -319,7 +319,7 @@ class GeneratorIoTest(test.TestCase):
'label': np.ones(1) * index - 32
}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = generator_io.generator_input_fn(
generator, target_key=None, batch_size=2, shuffle=False, num_epochs=1)
features = input_fn()
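
For context on the generator_io changes above, a hedged sketch of how
generator_input_fn is driven outside the test harness, mirroring the
target_key=None case exercised in the tests (the generator contents are made
up; the queue-runner plumbing matches what the tests rely on):

    import numpy as np
    import tensorflow as tf
    from tensorflow.contrib.learn.python.learn.learn_io import generator_io

    def generator():  # must be a generator *function* that yields dicts
      for index in range(3):
        yield {'a': np.ones(1) * index}

    input_fn = generator_io.generator_input_fn(
        generator, target_key=None, batch_size=2, shuffle=False, num_epochs=1)

    with tf.Session() as sess:
      features = input_fn()
      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(sess, coord=coord)
      print(sess.run(features['a']))  # first batch: [[0.], [1.]]
      coord.request_stop()
      coord.join(threads)
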
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py
index c738f0e8f3..396539a76a 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py
@@ -65,7 +65,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_ProducesExpectedOutputs(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=2, shuffle=False, num_epochs=1)
@@ -79,7 +79,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_ProducesOutputsForLargeBatchAndMultipleEpochs(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
index = np.arange(100, 102)
a = np.arange(2)
b = np.arange(32, 34)
@@ -107,7 +107,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_ProducesOutputsWhenDataSizeNotDividedByBatchSize(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
index = np.arange(100, 105)
a = np.arange(5)
b = np.arange(32, 37)
@@ -146,7 +146,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_OnlyX(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, _ = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y=None, batch_size=2, shuffle=False, num_epochs=1)
@@ -159,7 +159,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_ExcludesIndex(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=2, shuffle=False, num_epochs=1)
@@ -182,7 +182,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_RespectsEpoch_NoShuffle(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=4, shuffle=False, num_epochs=1)
@@ -192,7 +192,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_RespectsEpoch_WithShuffle(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=4, shuffle=True, num_epochs=1)
@@ -202,7 +202,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_RespectsEpoch_WithShuffleAutosize(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=2, shuffle=True, queue_capacity=None, num_epochs=2)
@@ -213,7 +213,7 @@ class PandasIoTest(test.TestCase):
if not HAS_PANDAS:
return
x, y = self.makeTestDataFrame()
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=3, shuffle=False, num_epochs=1)
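
The pandas_io tests above follow the same pattern; a short hedged sketch of
pandas_input_fn for reference (the frame contents are invented, and the
returned features exclude the DataFrame index, as the ExcludesIndex test
checks):

    import numpy as np
    import pandas as pd
    import tensorflow as tf
    from tensorflow.contrib.learn.python.learn.learn_io import pandas_io

    index = np.arange(100, 104)
    x = pd.DataFrame({'a': np.arange(4), 'b': np.arange(32, 36)}, index=index)
    y = pd.Series(np.arange(-28, -24), index=index)

    input_fn = pandas_io.pandas_input_fn(
        x, y, batch_size=2, shuffle=False, num_epochs=1)

    with tf.Session() as sess:
      features, target = input_fn()
      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(sess, coord=coord)
      print(sess.run([features['a'], target]))
      coord.request_stop()
      coord.join(threads)
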
diff --git a/tensorflow/contrib/learn/python/learn/ops/ops_test.py b/tensorflow/contrib/learn/python/learn/ops/ops_test.py
index 80d4923db3..ff190110c1 100644
--- a/tensorflow/contrib/learn/python/learn/ops/ops_test.py
+++ b/tensorflow/contrib/learn/python/learn/ops/ops_test.py
@@ -33,7 +33,7 @@ class OpsTest(test.TestCase):
"""Ops tests."""
def test_softmax_classifier(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
features = array_ops.placeholder(dtypes.float32, [None, 3])
labels = array_ops.placeholder(dtypes.float32, [None, 2])
weights = constant_op.constant([[0.1, 0.1], [0.1, 0.1], [0.1, 0.1]])
@@ -52,7 +52,7 @@ class OpsTest(test.TestCase):
ids_shape = (2, 3, 4)
embeds = np.random.randn(n_embed, d_embed)
ids = np.random.randint(0, n_embed, ids_shape)
- with self.test_session():
+ with self.cached_session():
embed_np = embeds[ids]
embed_tf = ops.embedding_lookup(embeds, ids).eval()
self.assertEqual(embed_np.shape, embed_tf.shape)
@@ -60,7 +60,7 @@ class OpsTest(test.TestCase):
def test_categorical_variable(self):
random_seed.set_random_seed(42)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
cat_var_idx = array_ops.placeholder(dtypes.int64, [2, 2])
embeddings = ops.categorical_variable(
cat_var_idx, n_classes=5, embedding_size=10, name="my_cat_var")
diff --git a/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py b/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py
index 95aec61955..5a7e4ebfea 100644
--- a/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py
+++ b/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py
@@ -31,7 +31,7 @@ class Seq2SeqOpsTest(test.TestCase):
"""Sequence-to-sequence tests."""
def test_sequence_classifier(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
decoding = [
array_ops.placeholder(dtypes.float32, [2, 2]) for _ in range(3)
]
@@ -60,7 +60,7 @@ class Seq2SeqOpsTest(test.TestCase):
def test_seq2seq_inputs(self):
inp = np.array([[[1, 0], [0, 1], [1, 0]], [[0, 1], [1, 0], [0, 1]]])
out = np.array([[[0, 1, 0], [1, 0, 0]], [[1, 0, 0], [0, 1, 0]]])
- with self.test_session() as session:
+ with self.cached_session() as session:
x = array_ops.placeholder(dtypes.float32, [2, 3, 2])
y = array_ops.placeholder(dtypes.float32, [2, 2, 3])
in_x, in_y, out_y = ops.seq2seq_inputs(x, y, 3, 2)
@@ -77,7 +77,7 @@ class Seq2SeqOpsTest(test.TestCase):
[[0, 0, 0], [0, 0, 0]]])
def test_rnn_decoder(self):
- with self.test_session():
+ with self.cached_session():
decoder_inputs = [
array_ops.placeholder(dtypes.float32, [2, 2]) for _ in range(3)
]
diff --git a/tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py b/tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py
index 423dcce8de..8390ddda90 100644
--- a/tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py
+++ b/tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py
@@ -29,7 +29,7 @@ from tensorflow.python.platform import test
class DecodeLibsvmOpTest(test.TestCase):
def testBasic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
content = [
"1 1:3.4 2:0.5 4:0.231", "1 2:2.5 3:inf 5:0.503",
"2 3:2.5 2:nan 1:0.105"
@@ -48,7 +48,7 @@ class DecodeLibsvmOpTest(test.TestCase):
[0, 0.105, np.nan, 2.5, 0, 0]])
def testNDimension(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
content = [["1 1:3.4 2:0.5 4:0.231", "1 1:3.4 2:0.5 4:0.231"],
["1 2:2.5 3:inf 5:0.503", "1 2:2.5 3:inf 5:0.503"],
["2 3:2.5 2:nan 1:0.105", "2 3:2.5 2:nan 1:0.105"]]
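
The strings in the decode_libsvm tests above use the usual LIBSVM
`label index:value ...` layout, and the expected dense rows show each index
landing directly in that column of a fixed-width feature vector. A hedged,
pure-Python walk-through of one record (the helper is illustrative, not the
real op, which returns labels plus a sparse feature tensor):

    import numpy as np

    def parse_libsvm_line(line, num_features=6):
      # Toy mirror of what DecodeLibsvm produces for a single record.
      parts = line.split()
      label = int(parts[0])
      row = np.zeros(num_features, dtype=np.float32)
      for item in parts[1:]:
        index, value = item.split(':')
        row[int(index)] = np.float32(value)  # 'nan' and 'inf' parse as floats
      return label, row

    print(parse_libsvm_line("2 3:2.5 2:nan 1:0.105"))
    # label 2, row ~ [0, 0.105, nan, 2.5, 0, 0]
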
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py
index a2d82cf800..553b116a3b 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py
@@ -30,7 +30,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase):
def testShardedMutableHashTable(self):
for num_shards in [1, 3, 10]:
- with self.test_session():
+ with self.cached_session():
default_val = -1
empty_key = 0
keys = constant_op.constant([11, 12, 13], dtypes.int64)
@@ -53,7 +53,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase):
def testShardedMutableHashTableVectors(self):
for num_shards in [1, 3, 10]:
- with self.test_session():
+ with self.cached_session():
default_val = [-0.1, 0.2]
empty_key = [0, 1]
keys = constant_op.constant([[11, 12], [13, 14], [15, 16]],
@@ -79,7 +79,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase):
output.eval())
def testExportSharded(self):
- with self.test_session():
+ with self.cached_session():
empty_key = -2
default_val = -1
num_shards = 2
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py
index 237a6812b7..51c4f68543 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py
@@ -36,13 +36,13 @@ class SparseFeatureColumnTest(TensorFlowTestCase):
self.assertTrue(isinstance(sfc.example_indices, ops.Tensor))
self.assertTrue(isinstance(sfc.feature_indices, ops.Tensor))
self.assertEqual(sfc.feature_values, None)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_example_indices, sfc.example_indices.eval())
self.assertAllEqual(expected_feature_indices, sfc.feature_indices.eval())
expected_feature_values = [1.0, 2.0, 3.0, 4.0]
sfc = SparseFeatureColumn([1, 1, 1, 2], [0, 1, 2, 0],
expected_feature_values)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_feature_values, sfc.feature_values.eval())
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index 9317e2bb6e..52b994ee92 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -283,6 +283,7 @@ def generated_test_models():
"sparse_to_dense",
"split",
"sqrt",
+ "square",
"squeeze",
"strided_slice",
"strided_slice_1d_exhaustive",
@@ -295,32 +296,70 @@ def generated_test_models():
"where",
]
-def gen_zip_test(name, test_name, **kwargs):
+def generated_test_conversion_modes():
+ """Returns a list of conversion modes."""
+
+ # TODO(nupurgarg): Add "pb2lite" when it's in open source. b/113614050.
+ return ["toco-extended", ""]
+
+def generated_test_models_all():
+ """Generates a list of all tests with the different converters.
+
+ Returns:
+ List of tuples representing (conversion mode, name of test).
+ """
+ conversion_modes = generated_test_conversion_modes()
+ tests = generated_test_models()
+ options = []
+ for conversion_mode in conversion_modes:
+ for test in tests:
+ if conversion_mode:
+ test += "_%s" % conversion_mode
+ options.append((conversion_mode, test))
+ return options
+
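
To make the output of generated_test_models_all() above concrete, here is a
hedged plain-Python rendering of the same loop (the two-model list is just for
illustration):

    def all_test_options(tests, conversion_modes=("toco-extended", "")):
      # Cross every conversion mode with every model; non-empty modes get a
      # suffixed test name, the default mode keeps the bare name.
      options = []
      for conversion_mode in conversion_modes:
        for test in tests:
          if conversion_mode:
            test += "_%s" % conversion_mode
          options.append((conversion_mode, test))
      return options

    print(all_test_options(["sqrt", "square"]))
    # [('toco-extended', 'sqrt_toco-extended'),
    #  ('toco-extended', 'square_toco-extended'),
    #  ('', 'sqrt'), ('', 'square')]
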
+def gen_zip_test(name, test_name, conversion_mode, **kwargs):
"""Generate a zipped-example test and its dependent zip files.
Args:
- name: Resulting cc_test target name
- test_name: Test targets this model. Comes from the list above.
- **kwargs: tf_cc_test kwargs.
+ name: str. Resulting cc_test target name
+ test_name: str. Test targets this model. Comes from the list above.
+ conversion_mode: str. Which conversion mode to run with. Comes from the
+ list above.
+ **kwargs: tf_cc_test kwargs
"""
+ toco = "//tensorflow/contrib/lite/toco:toco"
+ flags = ""
+ if conversion_mode:
+ # TODO(nupurgarg): Comment in when pb2lite is in open source. b/113614050.
+ # if conversion_mode == "pb2lite":
+ # toco = "//tensorflow/contrib/lite/experimental/pb2lite:pb2lite"
+ flags = "--ignore_toco_errors --run_with_extended"
+ kwargs["tags"].append("skip_already_failing")
+ kwargs["tags"].append("no_oss")
+ kwargs["tags"].append("notap")
+
gen_zipped_test_file(
name = "zip_%s" % test_name,
file = "%s.zip" % test_name,
+ toco = toco,
+ flags = flags,
)
tf_cc_test(name, **kwargs)
-def gen_zipped_test_file(name, file):
+def gen_zipped_test_file(name, file, toco, flags):
"""Generate a zip file of tests by using :generate_examples.
Args:
- name: Name of output. We will produce "`file`.files" as a target.
- file: The name of one of the generated_examples targets, e.g. "transpose"
+ name: str. Name of output. We will produce "`file`.files" as a target.
+ file: str. The name of one of the generated_examples targets, e.g. "transpose"
+ toco: str. Pathname of toco binary to run
+ flags: str. Any additional flags to include
"""
- toco = "//tensorflow/contrib/lite/toco:toco"
native.genrule(
name = file + ".files",
- cmd = ("$(locations :generate_examples) --toco $(locations %s) " % toco +
- " --zip_to_output " + file + " $(@D)"),
+ cmd = (("$(locations :generate_examples) --toco $(locations {0}) " +
+ " --zip_to_output {1} {2} $(@D)").format(toco, file, flags)),
outs = [file],
tools = [
":generate_examples",
diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h
index 9cf4bea73e..5e97b777fc 100644
--- a/tensorflow/contrib/lite/builtin_ops.h
+++ b/tensorflow/contrib/lite/builtin_ops.h
@@ -117,6 +117,7 @@ typedef enum {
kTfLiteBuiltinReduceMin = 89,
kTfLiteBuiltinFloorDiv = 90,
kTfLiteBuiltinReduceAny = 91,
+ kTfLiteBuiltinSquare = 92,
} TfLiteBuiltinOperator;
#ifdef __cplusplus
diff --git a/tensorflow/contrib/lite/c/builtin_op_data.h b/tensorflow/contrib/lite/c/builtin_op_data.h
index fa43e6a024..be9d551ee4 100644
--- a/tensorflow/contrib/lite/c/builtin_op_data.h
+++ b/tensorflow/contrib/lite/c/builtin_op_data.h
@@ -25,6 +25,9 @@ extern "C" {
// TODO(aselle): Consider using "if this then that" for testing.
+// IMPORTANT: All new members of structs must be added at the end to ensure
+// backwards compatibility.
+
// Possible padding types (for convolutions)
typedef enum {
kTfLitePaddingUnknown = 0,
@@ -71,11 +74,15 @@ typedef struct {
} TfLitePoolParams;
typedef struct {
+ // Parameters for DepthwiseConv version 1 or above.
TfLitePadding padding;
int stride_width;
int stride_height;
int depth_multiplier;
TfLiteFusedActivation activation;
+ // Parameters for DepthwiseConv version 2 or above.
+ int dilation_width_factor;
+ int dilation_height_factor;
} TfLiteDepthwiseConvParams;
typedef struct {
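
The "add new members at the end" rule introduced above is what lets
TfLiteDepthwiseConvParams grow its two dilation fields without breaking code
compiled against the version-1 layout: the offsets of the existing members do
not move. A hedged ctypes sketch of that property (field names follow the
diff; modelling the enum members as plain ints is my assumption):

    import ctypes

    class ParamsV1(ctypes.Structure):
      # DepthwiseConv params, version 1.
      _fields_ = [("padding", ctypes.c_int),
                  ("stride_width", ctypes.c_int),
                  ("stride_height", ctypes.c_int),
                  ("depth_multiplier", ctypes.c_int),
                  ("activation", ctypes.c_int)]

    class ParamsV2(ctypes.Structure):
      # Same prefix, with the version-2 dilation members appended at the end.
      _fields_ = ParamsV1._fields_ + [("dilation_width_factor", ctypes.c_int),
                                      ("dilation_height_factor", ctypes.c_int)]

    for name, _ in ParamsV1._fields_:
      assert getattr(ParamsV1, name).offset == getattr(ParamsV2, name).offset
    print(ctypes.sizeof(ParamsV1), ctypes.sizeof(ParamsV2))  # e.g. 20 28
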
diff --git a/tensorflow/contrib/lite/c/c_api_internal.h b/tensorflow/contrib/lite/c/c_api_internal.h
index 48df68a654..ee3dff6792 100644
--- a/tensorflow/contrib/lite/c/c_api_internal.h
+++ b/tensorflow/contrib/lite/c/c_api_internal.h
@@ -146,7 +146,7 @@ void TfLiteIntArrayFree(TfLiteIntArray* v);
#define TF_LITE_ENSURE_OK(context, status) \
do { \
if ((status) != kTfLiteOk) { \
- return status; \
+ return kTfLiteError; \
} \
} while (0)
@@ -374,6 +374,11 @@ typedef struct TfLiteContext {
// WARNING: This is an experimental interface that is subject to change.
void (*SetExternalContext)(struct TfLiteContext*, TfLiteExternalContextType,
TfLiteExternalContext*);
+
+ // Flag for allowing float16 precision for FP32 calculation.
+ // default: false.
+ // WARNING: This is an experimental API and subject to change.
+ bool allow_fp32_relax_to_fp16;
} TfLiteContext;
typedef struct _TfLiteRegistration {
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
index 1420fbcdc6..f4d2839b1b 100644
--- a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
@@ -216,6 +216,9 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
params->depth_multiplier = conv_params->depth_multiplier();
params->activation =
parse_activation(conv_params->fused_activation_function());
+
+ params->dilation_width_factor = conv_params->dilation_w_factor();
+ params->dilation_height_factor = conv_params->dilation_h_factor();
}
*builtin_data = reinterpret_cast<void*>(params);
break;
@@ -614,6 +617,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_LOGICAL_AND:
case BuiltinOperator_LOGICAL_NOT:
case BuiltinOperator_FLOOR_DIV:
+ case BuiltinOperator_SQUARE:
break;
}
return kTfLiteOk;
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
index 984f8bbc98..43ec5d53b8 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
+++ b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
@@ -157,6 +157,34 @@ TEST_F(DelegateTest, OnlyTFLite) {
ASSERT_THAT(GetValues(2), ElementsAre(1.1f, 4.4f, 9.9f, 17.6f));
}
+TEST_F(DelegateTest, MultipleInvokeCalls) {
+ // Call Invoke() multiple times on the same model.
+ AddTensors(10, {0, 1}, {2}, kTfLiteFloat32, {3});
+ AddTfLiteMulOp({0, 1}, {2});
+
+ ConfigureDelegate();
+
+ SetShape(0, {2, 2, 1});
+ SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
+ SetShape(1, {2, 2, 1});
+ SetValues(1, {1.0f, 2.0f, 3.0f, 4.0f});
+
+ ASSERT_TRUE(Invoke());
+
+ ASSERT_THAT(GetShape(2), ElementsAre(2, 2, 1));
+ ASSERT_THAT(GetValues(2), ElementsAre(1.1f, 4.4f, 9.9f, 17.6f));
+
+ SetShape(0, {2, 2, 1});
+ SetValues(1, {4.0f, 3.0f, 2.0f, 1.0f});
+ SetShape(1, {2, 2, 1});
+ SetValues(0, {4.4f, 3.3f, 2.2f, 1.1f});
+
+ ASSERT_TRUE(Invoke());
+
+ ASSERT_THAT(GetShape(2), ElementsAre(2, 2, 1));
+ ASSERT_THAT(GetValues(2), ElementsAre(17.6f, 9.9f, 4.4f, 1.1f));
+}
+
TEST_F(DelegateTest, MultipleInterpretersSameDelegate) {
// Build a graph, configure the delegate and set inputs.
{
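
The expectations in MultipleInvokeCalls above are just the element-wise
products of the two inputs, with both operands reversed before the second
Invoke(); a quick numpy check of those constants (illustrative only):

    import numpy as np

    a = np.array([1.1, 2.2, 3.3, 4.4], dtype=np.float32)
    b = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
    print(a * b)              # [ 1.1  4.4  9.9 17.6]  -> first Invoke()
    print(a[::-1] * b[::-1])  # [17.6  9.9  4.4  1.1]  -> second Invoke()
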
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
index e3eebac4da..c6587b3d3f 100644
--- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
@@ -1115,6 +1115,14 @@ class NNAPIDelegateKernel {
CHECK_NN(context, ANeuralNetworksModel_identifyInputsAndOutputs(
nn_model_.get(), inputs.size(), inputs.data(),
outputs.size(), outputs.data()));
+
+ // Set relaxed computation mode for fp32 if possible.
+ if (kAndroidSdkVersion >= kMinSdkVersionForNNAPI11) {
+ CHECK_NN(context,
+ ANeuralNetworksModel_relaxComputationFloat32toFloat16(
+ nn_model_.get(), context->allow_fp32_relax_to_fp16));
+ }
+
// Finalize the model
CHECK_NN(context, ANeuralNetworksModel_finish(nn_model_.get()));
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
index 4b01aefd6a..9626c54c74 100644
--- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
@@ -40,13 +40,15 @@ class FloatAddOpModel : public SingleOpModelWithNNAPI {
public:
FloatAddOpModel(const TensorData& input1, const TensorData& input2,
const TensorData& output,
- ActivationFunctionType activation_type) {
+ ActivationFunctionType activation_type,
+ bool allow_fp32_relax_to_fp16 = false) {
input1_ = AddInput(input1);
input2_ = AddInput(input2);
output_ = AddOutput(output);
SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions,
CreateAddOptions(builder_, activation_type).Union());
- BuildInterpreter({GetShape(input1_), GetShape(input2_)});
+ BuildInterpreter({GetShape(input1_), GetShape(input2_)},
+ allow_fp32_relax_to_fp16);
}
int input1() { return input1_; }
@@ -71,6 +73,19 @@ TEST(NNAPIDelegate, AddWithNoActivation) {
EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9, 0.4, 1.0, 1.3}));
}
+// Do a test with the NN API using no activation.
+// The test allows computing FP32 with FP16 precision. In this particular case,
+// calculating in FP32 or FP16 should produce the same results.
+TEST(NNAPIDelegate, AddWithNoActivationRelaxed) {
+ FloatAddOpModel m(
+ {TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE, true);
+ m.PopulateTensor<float>(m.input1(), {-2.0, -1.0, 1.0, 2.0});
+ m.PopulateTensor<float>(m.input2(), {1.0, 2.0, 3.0, 4.0});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.0, 1.0, 4.0, 6.0}));
+}
+
// Do a test with the NN api with relu.
TEST(NNAPIDelegate, AddWithRelu) {
FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
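
The AddWithNoActivationRelaxed test above works because every operand and
every expected result (-1, 1, 4, 6) is exactly representable in FP16, so
relaxing FP32 arithmetic to FP16 cannot change the answer. A hedged numpy
check of that claim:

    import numpy as np

    a = np.array([-2.0, -1.0, 1.0, 2.0])
    b = np.array([1.0, 2.0, 3.0, 4.0])
    fp32 = a.astype(np.float32) + b.astype(np.float32)
    fp16 = (a.astype(np.float16) + b.astype(np.float16)).astype(np.float32)
    assert np.array_equal(fp32, fp16)  # both are exactly [-1., 1., 4., 6.]
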
diff --git a/tensorflow/contrib/lite/experimental/writer/BUILD b/tensorflow/contrib/lite/experimental/writer/BUILD
new file mode 100644
index 0000000000..82d39c00ab
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/BUILD
@@ -0,0 +1,66 @@
+package(default_visibility = [
+ "//visibility:public",
+])
+
+licenses(["notice"]) # Apache 2.0
+
+cc_binary(
+ name = "option_writer_generator",
+ srcs = ["option_writer_generator.cc"],
+ deps = [
+ "//tensorflow/contrib/lite/schema:schema_fbs_with_reflection",
+ "@flatbuffers",
+ ],
+)
+
+cc_library(
+ name = "writer_lib",
+ srcs = [
+ "enum_mapping.h",
+ "writer_lib.cc",
+ ],
+ hdrs = [
+ "writer_lib.h",
+ ],
+ data = [
+ ":option_writer_gen",
+ ],
+ textual_hdrs = ["option_writer_generated.h"],
+ deps = [
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:schema_fbs_version",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "//tensorflow/contrib/lite/schema:schema_fbs_with_reflection",
+ ],
+)
+
+cc_binary(
+ name = "writer",
+ srcs = ["writer.cc"],
+ deps = [
+ ":writer_lib",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ ],
+)
+
+cc_test(
+ name = "writer_lib_test",
+ size = "small",
+ srcs = ["writer_lib_test.cc"],
+ deps = [
+ ":writer_lib",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "//tensorflow/contrib/lite/testing:util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+genrule(
+ name = "option_writer_gen",
+ outs = ["option_writer_generated.h"],
+ cmd = "$(location :option_writer_generator) $(@)",
+ tools = [":option_writer_generator"],
+)
diff --git a/tensorflow/contrib/lite/experimental/writer/enum_mapping.h b/tensorflow/contrib/lite/experimental/writer/enum_mapping.h
new file mode 100644
index 0000000000..8bc464fd71
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/enum_mapping.h
@@ -0,0 +1,116 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h"
+
+// TODO(aselle): Ideally extract this from the schema.
+
+namespace tflite {
+
+inline ActivationFunctionType TfLiteActivationToSchemaActivation(
+ TfLiteFusedActivation act) {
+ switch (act) {
+ case kTfLiteActNone:
+ return ActivationFunctionType_NONE;
+ case kTfLiteActRelu:
+ return ActivationFunctionType_RELU;
+ case kTfLiteActRelu1:
+ return ActivationFunctionType_RELU_N1_TO_1;
+ case kTfLiteActRelu6:
+ return ActivationFunctionType_RELU6;
+ case kTfLiteActTanh:
+ return ActivationFunctionType_TANH;
+ case kTfLiteActSignBit:
+ return ActivationFunctionType_SIGN_BIT;
+ case kTfLiteActSigmoid:
+ return ActivationFunctionType_NONE; // TODO(aselle): Add to schema
+ }
+ return ActivationFunctionType_NONE;
+}
+
+inline Padding TfLitePaddingToSchemaPadding(TfLitePadding padding) {
+ switch (padding) {
+ case kTfLitePaddingUnknown:
+ return Padding_SAME; // TODO(aselle): Consider an error.
+ case kTfLitePaddingSame:
+ return Padding_SAME;
+ case kTfLitePaddingValid:
+ return Padding_VALID;
+ }
+ return Padding_SAME; // TODO(aselle): Consider an error.
+}
+
+inline TensorType TfLiteTypeToSchemaType(TfLiteType type) {
+ switch (type) {
+ // case kTfLiteNoType: return TensorType_NONE;
+ case kTfLiteNoType:
+ return TensorType_FLOAT32; // TODO(aselle): Consider an error.
+ case kTfLiteFloat32:
+ return TensorType_FLOAT32;
+ case kTfLiteInt32:
+ return TensorType_INT32;
+ case kTfLiteUInt8:
+ return TensorType_UINT8;
+ case kTfLiteInt64:
+ return TensorType_INT64;
+ case kTfLiteString:
+ return TensorType_STRING;
+ case kTfLiteBool:
+ return TensorType_BOOL;
+ case kTfLiteInt16:
+ return TensorType_INT16;
+ case kTfLiteComplex64:
+ return TensorType_COMPLEX64;
+ }
+ // TODO(aselle): consider an error
+}
+
+inline FullyConnectedOptionsWeightsFormat
+FullyConnectedOptionsWeightsFormatToSchema(
+ TfLiteFullyConnectedWeightsFormat format) {
+ switch (format) {
+ case kTfLiteFullyConnectedWeightsFormatDefault:
+ return FullyConnectedOptionsWeightsFormat_DEFAULT;
+ case kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8:
+ return FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8;
+ }
+}
+
+inline LSTMKernelType LSTMKernelTypeToSchema(TfLiteLSTMKernelType type) {
+ switch (type) {
+ case kTfLiteLSTMFullKernel:
+ return LSTMKernelType_FULL;
+ case kTfLiteLSTMBasicKernel:
+ return LSTMKernelType_BASIC;
+ }
+}
+
+inline LSHProjectionType LSHProjectionTypeToSchema(
+ TfLiteLSHProjectionType type) {
+ switch (type) {
+ case kTfLiteLshProjectionUnknown:
+ return LSHProjectionType_UNKNOWN;
+ case kTfLiteLshProjectionSparse:
+ return LSHProjectionType_SPARSE;
+ case kTfLiteLshProjectionDense:
+ return LSHProjectionType_DENSE;
+ }
+}
+
+} // namespace tflite
+#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_
diff --git a/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc b/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc
new file mode 100644
index 0000000000..e6d5a776b3
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc
@@ -0,0 +1,370 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <ctype.h>
+#include <iostream>
+#include <unordered_map>
+#include <unordered_set>
+#include "flatbuffers/minireflect.h" // flatbuffers
+#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h"
+
+namespace tflite {
+namespace {
+// This is generated by grepping
+// cat third_party/tensorflow/contrib/lite/builtin_op_data.h
+//| grep "^} TfLite" | sed 's/^} TfLite\(.*\)Params;/\1Params/g' | grep -v "^}"
+static const char* param_structs[] = {"TfLiteConvParams",
+ "TfLitePoolParams",
+ "TfLiteDepthwiseConvParams",
+ "TfLiteSVDFParams",
+ "TfLiteRNNParams",
+ "TfLiteSequenceRNNParams",
+ "TfLiteFullyConnectedParams",
+ "TfLiteLSHProjectionParams",
+ "TfLiteSoftmaxParams",
+ "TfLiteConcatenationParams",
+ "TfLiteAddParams",
+ "TfLiteSpaceToBatchNDParams",
+ "TfLiteBatchToSpaceNDParams",
+ "TfLiteMulParams",
+ "TfLiteSubParams",
+ "TfLiteDivParams",
+ "TfLiteL2NormParams",
+ "TfLiteLocalResponseNormParams",
+ "TfLiteLSTMParams",
+ "TfLiteResizeBilinearParams",
+ "TfLitePadParams",
+ "TfLitePadV2Params",
+ "TfLiteReshapeParams",
+ "TfLiteSkipGramParams",
+ "TfLiteSpaceToDepthParams",
+ "TfLiteCastParams",
+ "TfLiteEmbeddingLookupSparseParams",
+ "TfLiteGatherParams",
+ "TfLiteTransposeParams",
+ "TfLiteReducerParams",
+ "TfLiteSplitParams",
+ "TfLiteSqueezeParams",
+ "TfLiteStridedSliceParams",
+ "TfLiteArgMaxParams",
+ "TfLiteArgMinParams",
+ "TfLiteTransposeConvParams",
+ "TfLiteSparseToDenseParams",
+ "TfLiteShapeParams",
+ "TfLiteFakeQuantParams",
+ "TfLitePackParams",
+ "TfLiteOneHotParams",
+ nullptr};
+} // namespace
+
+// Get rid of all underscores and make everything lower case to make name
+// matching work for stuff like 3D vs 3d or RNN vs Rnn.
+std::string ToCollapsed(const std::string& in) {
+ const char* s = in.c_str();
+ bool first = true;
+ std::string out;
+ while (*s != '\0') {
+ if (*s == '_') {
+ first = true;
+ } else if (first) {
+ out.push_back(tolower(*s));
+ first = false;
+ } else {
+ out.push_back(tolower(*s));
+ }
+ s++;
+ }
+ return out;
+}
+
+// A collection of information about builtin ops.
+class OpOptionData {
+ public:
+ OpOptionData() {
+ BuildOpList();
+ BuildOptionToTypeFunctionMap();
+ BuildOpToOptionMap();
+ }
+
+ // A list of builtin operations
+ const std::vector<std::string>& ops() const { return ops_; }
+ // Maps from operation name to option name (e.g. 'ADD' -> 'AddOptions')
+ const std::unordered_map<std::string, std::string>& op_to_option() {
+ return op_to_option_;
+ }
+ // Maps from option name to C struct, e.g. 'AddOptions' -> 'TfLiteAddOptions'
+ const std::unordered_map<std::string, std::string>& option_to_struct() {
+ return option_to_struct_;
+ }
+ // Maps from option to a flatbuffer type function that describes that option.
+ const std::unordered_map<std::string, flatbuffers::TypeFunction>&
+ option_to_type_function() {
+ return option_to_type_function_;
+ }
+
+ private:
+ void BuildOpList() {
+ for (const char* const* curr = EnumNamesBuiltinOperator(); *curr != nullptr;
+ ++curr) {
+ if (strlen(*curr) != 0) ops_.push_back(*curr);
+ }
+ }
+
+ void BuildOptionToTypeFunctionMap() {
+ auto d = tflite::BuiltinOptionsTypeTable();
+ for (int i = 0; i < d->num_elems; i++) {
+ flatbuffers::TypeCode code = d->type_codes[i];
+ if (code.sequence_ref != -1) {
+ option_to_type_function_.insert(
+ std::make_pair(d->names[i], d->type_refs[code.sequence_ref]));
+ }
+ }
+ }
+
+ void BuildOpToOptionMap() {
+ // Manually specified mappings between ops and options
+ op_to_option_["REDUCE_MAX"] = "ReducerOptions";
+ op_to_option_["REDUCE_MIN"] = "ReducerOptions";
+ op_to_option_["REDUCE_ANY"] = "ReducerOptions";
+ op_to_option_["UNPACK"] = "";
+ op_to_option_["SUM"] = "ReducerOptions";
+ op_to_option_["REDUCE_MAX"] = "ReducerOptions";
+ op_to_option_["REDUCE_PROD"] = "ReducerOptions";
+ op_to_option_["MEAN"] = "ReducerOptions";
+ op_to_option_["L2_POOL_2D"] = "Pool2DOptions";
+ op_to_option_["AVERAGE_POOL_2D"] = "Pool2DOptions";
+ op_to_option_["MAX_POOL_2D"] = "Pool2DOptions";
+ op_to_option_["L2_NORMALIZATION"] = "L2NormOptions";
+ op_to_option_["BIDIRECTIONAL_SEQUENCE_LSTM"] = "LSTMOptions";
+ op_to_option_["UNIDIRECTIONAL_SEQUENCE_LSTM"] = "LSTMOptions";
+ op_to_option_["BIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions";
+ op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions";
+ op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions";
+ // Manually specified mappings between ops and options (none)
+ op_to_option_["EMBEDDING_LOOKUP"] =
+ ""; // TODO(aselle): maybe something else.
+ op_to_option_["FLOOR"] = "";
+ op_to_option_["HASHTABLE_LOOKUP"] =
+ ""; // TODO(aselle): maybe something else.
+ op_to_option_["LOGISTIC"] = "";
+ op_to_option_["RELU"] = "";
+ op_to_option_["RELU_N1_TO_1"] = "";
+ op_to_option_["RELU6"] = "";
+ op_to_option_["TANH"] = "";
+ op_to_option_["CUSTOM"] = ""; // TODO(aselle): maybe something else.
+ op_to_option_["DELEGATE"] = ""; // TODO(aselle): maybe something else.
+ op_to_option_["PRELU"] = "";
+ op_to_option_["MAXIMUM"] = ""; // TODO(aselle): MaximumMinimumOptions
+ op_to_option_["MINIMUM"] = ""; // TODO(aselle): MaximumMinimumOptions
+ op_to_option_["SIN"] = "";
+ op_to_option_["LOG"] = "";
+ op_to_option_["SQRT"] = "";
+ op_to_option_["RSQRT"] = "";
+
+ // TODO(aselle): These are undesirable hacks. Consider changing C structs
+ option_to_struct_["Pool2DOptions"] = "TfLitePoolParams";
+ option_to_struct_["Conv2DOptions"] = "TfLiteConvParams";
+ option_to_struct_["DepthwiseConv2DOptions"] = "TfLiteDepthwiseConvParams";
+ option_to_struct_["LocalResponseNormalizationOptions"] =
+ "TfLiteLocalResponseNormParams";
+ // Now for every op, try to find an option.
+ bool fatal = false;
+ for (auto op_name : ops_) {
+ bool found_option = false;
+ auto d = tflite::BuiltinOptionsTypeTable();
+ std::string collapsed_option_name_guess =
+ ToCollapsed(op_name) + "options";
+ // O(n^2), but n is small here.
+ for (int i = 0; i < d->num_elems; i++) {
+ std::string option_name = d->names[i];
+ std::string collapsed_option_name = ToCollapsed(option_name);
+ if (collapsed_option_name_guess == collapsed_option_name) {
+ op_to_option_.insert(std::make_pair(op_name, option_name));
+ found_option = true;
+ break;
+ }
+ }
+ auto it = op_to_option_.find(op_name);
+ if (it == op_to_option_.end()) {
+ std::cerr << "Didn't find option for " << op_name << std::endl;
+ fatal = true;
+ } else if (!it->second.empty()) {
+ std::string option_name = it->second;
+
+ if (option_to_struct_.find(option_name) == option_to_struct_.end()) {
+ bool param_struct_found = false;
+ std::string params_guess = std::string("TfLite") + option_name;
+ size_t start = params_guess.find("Options");
+ size_t len = strlen("Options");
+ params_guess.replace(start, len, "Params");
+ for (auto* param = param_structs; *param != nullptr; param++) {
+ if (*param == params_guess) {
+ param_struct_found = true;
+ break;
+ }
+ }
+ if (!param_struct_found) {
+ std::cerr << "Failed to get param struct for option " << option_name
+ << std::endl;
+ fatal = true;
+ } else {
+ option_to_struct_.insert(std::make_pair(option_name, params_guess));
+ }
+ }
+ }
+ }
+ }
+
+ private:
+ std::vector<std::string> ops_;
+ std::unordered_map<std::string, std::string> op_to_option_;
+ std::unordered_map<std::string, std::string> option_to_struct_;
+ std::unordered_map<std::string, flatbuffers::TypeFunction>
+ option_to_type_function_;
+};
+
+void GenerateImportForOp(FILE* fp, const std::string& op_name,
+ const std::string& option_name,
+ const std::string& option_type,
+ const flatbuffers::TypeTable* options,
+ const std::string& struct_name) {
+ // Skip tricky ones for now
+ if (struct_name == "TfLiteResizeBilinearParams") return;
+ if (struct_name == "TfLiteSqueezeParams") return;
+ if (struct_name == "TfLiteEmbeddingLookupSparseParams") return;
+ if (struct_name == "TfLiteReshapeParams") return;
+
+ fprintf(fp, " case BuiltinOperator_%s: {\n", op_name.c_str());
+ fprintf(fp,
+ " const auto* params = reinterpret_cast<const "
+ "%s*>(builtin_op_data);\n",
+ struct_name.c_str());
+
+ for (size_t i = 0; i < options->num_elems; i++) {
+ std::string elem_name = options->names[i];
+ // TODO(aselle): Irregular naming in builtins
+ if (elem_name == "fused_activation_function")
+ elem_name = "activation";
+ else if (elem_name == "stride_w")
+ elem_name = "stride_width";
+ else if (elem_name == "stride_h")
+ elem_name = "stride_height";
+ else if (elem_name == "dilation_h_factor")
+ elem_name = "dilation_height_factor";
+ else if (elem_name == "dilation_w_factor")
+ elem_name = "dilation_width_factor";
+ else if (elem_name == "new_shape")
+ elem_name = "shape";
+
+ flatbuffers::TypeCode code = options->type_codes[i];
+ auto contained_type = code.sequence_ref != -1
+ ? options->type_refs[code.sequence_ref]
+ : nullptr;
+ std::string mapper = "";
+ if (contained_type == TensorTypeTypeTable) {
+ mapper = "TfLiteTypeToSchemaType";
+ } else if (contained_type == ActivationFunctionTypeTypeTable) {
+ mapper = "TfLiteActivationToSchemaActivation";
+ } else if (contained_type == PaddingTypeTable) {
+ mapper = "TfLitePaddingToSchemaPadding";
+ } else if (contained_type == FullyConnectedOptionsWeightsFormatTypeTable) {
+ mapper = "FullyConnectedOptionsWeightsFormatToSchema";
+ } else if (contained_type == LSTMKernelTypeTypeTable) {
+ mapper = "LSTMKernelTypeToSchema";
+ } else if (contained_type == LSHProjectionTypeTypeTable) {
+ mapper = "LSHProjectionTypeToSchema";
+ }
+
+ fprintf(fp,
+ " auto val%zu = "
+ "%s(params->%s);\n",
+ i, mapper.c_str(), elem_name.c_str());
+ }
+ fprintf(fp, " auto union_type = Create%s(*fbb", option_name.c_str());
+ for (size_t i = 0; i < options->num_elems; i++) {
+ fprintf(fp, ", val%zu", i);
+ }
+ fprintf(fp, ").Union();\n");
+ fprintf(fp, " return std::make_pair(%s, union_type);\n",
+ option_type.c_str());
+ fprintf(fp, " }\n break;\n");
+}
+
+void GenerateImport(OpOptionData* option, FILE* fp) {
+ std::unordered_set<std::string> ignores;
+ ignores.insert("CONCAT_EMBEDDINGS");
+ ignores.insert("CALL");
+
+ // Emit all ops that don't have an options struct (or are ignored) as a
+ // single block of fall-through case labels.
+ for (const auto& op_name : option->ops()) {
+ auto option_it = option->op_to_option().find(op_name);
+ if (!option_it->second.empty() && ignores.find(op_name) == ignores.end())
+ continue;
+ fprintf(fp, " case BuiltinOperator_%s:\n", op_name.c_str());
+ }
+ fprintf(fp,
+ " return std::make_pair(BuiltinOptions_NONE, "
+ "flatbuffers::Offset<void>());\n break;\n");
+
+ // Iterate over each op.
+ for (const auto& op_name : option->ops()) {
+ if (ignores.find(op_name) != ignores.end()) continue;
+ // Get to the option and struct names, continuing if not found.
+ auto option_it = option->op_to_option().find(op_name);
+ if (option_it->second.empty()) continue;
+ std::string option_name = option_it->second;
+ std::string option_type = "BuiltinOptions_" + option_name;
+ auto option_func_it = option->option_to_type_function().find(option_name);
+ if (option_func_it == option->option_to_type_function().end()) continue;
+ auto struct_name_it = option->option_to_struct().find(option_name);
+ if (struct_name_it == option->option_to_struct().end()) {
+ // If no C struct, then it better have no arguments.
+ auto type_info = option_func_it->second();
+ if (type_info->num_elems != 0) {
+ // We have non-zero arguments in the schema, this means there
+ // should be a struct.
+ fprintf(stderr,
+ "Op %s uses option struct %s which has no builtin struct\n",
+ op_name.c_str(), option_name.c_str());
+ exit(1);
+ }
+ fprintf(fp, " case BuiltinOperator_%s:\n", op_name.c_str());
+ fprintf(fp, " return std::make_pair(%s, Create%s(*fbb).Union());",
+ option_type.c_str(), option_name.c_str());
+ } else {
+ // If C struct, then we need to assign all properties
+ auto struct_name = struct_name_it->second;
+ GenerateImportForOp(fp, op_name, option_name, option_type,
+ option_func_it->second(), struct_name);
+ }
+ }
+ // TODO(aselle): Handle unhandled cases more gracefully.
+ fprintf(fp,
+ "default: return std::make_pair(BuiltinOptions_NONE, "
+ "flatbuffers::Offset<void>());\n break;\n");
+}
+
+} // namespace tflite
+
+int main(int argc, char* argv[]) {
+ tflite::OpOptionData option;
+ if (argc != 2) {
+ fprintf(stderr, "Usage: %s <fname out>\n", argv[0]);
+ return 1;
+ }
+ FILE* fp = fopen(argv[1], "w");
+ tflite::GenerateImport(&option, fp);
+ fclose(fp);
+}
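
The generator above pairs each builtin op with its options table mostly by
"collapsing" names (drop underscores, lowercase) and comparing, falling back
to the hand-written op_to_option_ entries when the guess misses. A hedged
Python rendering of that heuristic (the schema names are real, the helper is
illustrative):

    def to_collapsed(name):
      # Mirror of ToCollapsed(): drop underscores and lowercase everything.
      return name.replace("_", "").lower()

    def guess_option(op_name, option_names):
      guess = to_collapsed(op_name) + "options"
      for option_name in option_names:
        if to_collapsed(option_name) == guess:
          return option_name
      return None  # the generator then falls back to its manual table

    print(guess_option("DEPTHWISE_CONV_2D",
                       ["Conv2DOptions", "DepthwiseConv2DOptions"]))
    # DepthwiseConv2DOptions
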
diff --git a/tensorflow/contrib/lite/experimental/writer/writer.cc b/tensorflow/contrib/lite/experimental/writer/writer.cc
new file mode 100644
index 0000000000..20ede214fb
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/writer.cc
@@ -0,0 +1,41 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Just does a read/write loop of tflite file format using the interpreter as
+// an intermediate.
+//
+// Usage:
+// writer <input tflite> <output tflite>
+
+#include <iostream>
+
+#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+
+int main(int argc, char* argv[]) {
+ if (argc != 3) {
+ fprintf(stderr, "Usage: %s input_file output_file\n", argv[0]);
+ return 1;
+ }
+ std::unique_ptr<tflite::FlatBufferModel> model =
+ tflite::FlatBufferModel::BuildFromFile(argv[1]);
+ std::unique_ptr<tflite::Interpreter> interpreter;
+ tflite::ops::builtin::BuiltinOpResolver builtin_op_resolver;
+ tflite::InterpreterBuilder(*model, builtin_op_resolver)(&interpreter);
+ tflite::InterpreterWriter writer(interpreter.get());
+ writer.Write(argv[2]);
+
+ return 0;
+}
diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib.cc b/tensorflow/contrib/lite/experimental/writer/writer_lib.cc
new file mode 100644
index 0000000000..555a9cc4b0
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/writer_lib.cc
@@ -0,0 +1,287 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h"
+#include <cstdlib>
+#include <cstring>
+#include <unordered_map>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context_util.h"
+#include "tensorflow/contrib/lite/experimental/writer/enum_mapping.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h"
+#include "tensorflow/contrib/lite/version.h"
+
+namespace tflite {
+template <class T>
+using Offset = flatbuffers::Offset<T>;
+template <class T>
+using Vector = flatbuffers::Vector<T>;
+using FlatBufferBuilder = flatbuffers::FlatBufferBuilder;
+
+std::pair<BuiltinOptions, Offset<void>> CreateBuiltinUnion(
+ FlatBufferBuilder* fbb, enum BuiltinOperator op, void* builtin_op_data) {
+ switch (op) {
+#include "tensorflow/contrib/lite/experimental/writer/option_writer_generated.h"
+ }
+ return std::make_pair(BuiltinOptions_NONE, Offset<void>());
+}
+
+template <class T_OUTPUT, class T_INPUT>
+Offset<Vector<T_OUTPUT>> InterpreterWriter::ExportVector(FlatBufferBuilder* fbb,
+ const T_INPUT& v) {
+ std::vector<T_OUTPUT> inputs(v.begin(), v.end());
+ return fbb->template CreateVector<T_OUTPUT>(inputs);
+}
+
+Offset<Vector<Offset<Operator>>> InterpreterWriter::ExportOperators(
+ FlatBufferBuilder* fbb) {
+ std::vector<Offset<Operator>> operators;
+
+ std::vector<int> operator_to_opcode;
+ // TODO(aselle): Augment this once we put execution plan in schema.
+ operator_to_opcode.resize(interpreter_->nodes_size(), -1);
+ for (int op_index : interpreter_->execution_plan()) {
+ const auto* node_and_registration =
+ interpreter_->node_and_registration(op_index);
+ const TfLiteRegistration* registration = &node_and_registration->second;
+ if (!registration->custom_name) {
+ operator_to_opcode[op_index] =
+ GetOpCodeForBuiltin(registration->builtin_code);
+ } else {
+ operator_to_opcode[op_index] =
+ GetOpCodeForCustom(registration->custom_name);
+ }
+ }
+ // Second pass: serialize the operators.
+ for (int op_index : interpreter_->execution_plan()) {
+ const auto* node_and_registration =
+ interpreter_->node_and_registration(op_index);
+ const TfLiteNode& node = node_and_registration->first;
+ const TfLiteRegistration& registration = node_and_registration->second;
+ Offset<void> builtin_options;
+ BuiltinOptions builtin_options_type = BuiltinOptions_NONE;
+ // Custom data
+ // TODO(aselle): Custom options format is not known by default. Just assume
+ // for now.
+ auto custom_options_format = CustomOptionsFormat_FLEXBUFFERS;
+ Offset<Vector<uint8_t>> custom_options = 0;
+
+ if (!registration.custom_name) {
+ // builtin
+ auto builtin_options_and_type = CreateBuiltinUnion(
+ fbb, static_cast<enum BuiltinOperator>(registration.builtin_code),
+ node.builtin_data);
+ builtin_options = builtin_options_and_type.second;
+ builtin_options_type = builtin_options_and_type.first;
+ } else {
+ auto custom_writer = custom_op_to_writer_.find(registration.custom_name);
+ if (custom_writer != custom_op_to_writer_.end() &&
+ custom_writer->second) {
+ // delegate to custom writer if it exists
+ custom_writer->second(fbb, interpreter_, op_index, &custom_options,
+ &custom_options_format);
+ } else {
+ // use the custom data as-is
+ custom_options = fbb->CreateVector(
+ reinterpret_cast<const uint8_t*>(node.custom_initial_data),
+ node.custom_initial_data_size);
+ }
+ }
+
+ int opcode_index = operator_to_opcode[op_index];
+ std::vector<int> written_inputs =
+ RemapTensorIndicesToWritten(TfLiteIntArrayView(node.inputs));
+ std::vector<int> written_outputs =
+ RemapTensorIndicesToWritten(TfLiteIntArrayView(node.outputs));
+ auto inputs = ExportVector<int32_t>(fbb, written_inputs);
+ auto outputs = ExportVector<int32_t>(fbb, written_outputs);
+ operators.push_back(CreateOperator(*fbb, opcode_index, inputs, outputs,
+ builtin_options_type, builtin_options,
+ custom_options, custom_options_format));
+ }
+
+ return fbb->template CreateVector<Offset<Operator>>(operators);
+}
+
+Offset<Vector<Offset<Tensor>>> InterpreterWriter::ExportTensors(
+ FlatBufferBuilder* fbb) {
+ // Initialized to -1.
+ // A value of -1 means this tensor will not be exported.
+ tensor_to_written_tensor_.resize(interpreter_->tensors_size(), -1);
+
+ std::vector<Offset<Tensor>> tensors;
+
+ // Make a map from tensor index to whether the tensor is a temporary.
+ std::vector<bool> tensor_is_temporary(interpreter_->tensors_size(), false);
+ for (int op_index = 0; op_index < interpreter_->nodes_size(); ++op_index) {
+ const auto* node_and_registration =
+ interpreter_->node_and_registration(op_index);
+ for (auto tensor_index :
+ TfLiteIntArrayView(node_and_registration->first.temporaries))
+ tensor_is_temporary[tensor_index] = true;
+ }
+
+ // Now we need to remap all used tensor indices
+ int curr_output_index = 0;
+ for (int tensor_index = 0; tensor_index < interpreter_->tensors_size();
+ tensor_index++) {
+ // Temporary tensors and unused tensors will not be written.
+ if (!tensor_is_temporary[tensor_index] &&
+ unused_tensors_.find(tensor_index) == unused_tensors_.end()) {
+ tensor_to_written_tensor_[tensor_index] = curr_output_index++;
+ }
+ }
+
+ for (int tensor_index = 0; tensor_index < interpreter_->tensors_size();
+ ++tensor_index) {
+ // Tensor not exported.
+ if (tensor_to_written_tensor_[tensor_index] == -1) continue;
+
+ if (TfLiteTensor* tensor = interpreter_->tensor(tensor_index)) {
+ // We only need to convert non-temporaries
+ if (tensor->allocation_type != kTfLiteArenaRw &&
+ tensor->allocation_type != kTfLiteMmapRo &&
+ tensor->allocation_type != kTfLiteArenaRwPersistent)
+ continue;
+ // Allocate a buffer index
+ int buffer_index = 0; // This is null
+ if (tensor->allocation_type == kTfLiteMmapRo) {
+ buffer_index = buffers_.size();
+ buffers_.push_back(std::make_pair(
+ reinterpret_cast<const uint8_t*>(tensor->data.raw), tensor->bytes));
+ }
+ // Primitive type.
+ TensorType type = TfLiteTypeToSchemaType(tensor->type);
+ // Handle quantization
+ const Offset<Vector<float>> null_array;
+ Offset<Vector<float>> scale_array;
+ Offset<Vector<int64_t>> zero_point_array;
+ if (tensor->params.scale != 0.f) {
+ // We have quantization, make a single-argument array (multi-channel
+ // quant needs updating here).
+ scale_array = fbb->CreateVector<float>({tensor->params.scale});
+ zero_point_array =
+ fbb->CreateVector<int64_t>({tensor->params.zero_point});
+ }
+ Offset<QuantizationParameters> quantization_params =
+ CreateQuantizationParameters(*fbb, null_array, null_array,
+ scale_array, zero_point_array);
+ // Shape
+ TfLiteIntArrayView shape_view(tensor->dims);
+ std::vector<int> shape =
+ std::vector<int>(shape_view.begin(), shape_view.end());
+
+ tensors.push_back(CreateTensor(*fbb, ExportVector<int32_t>(fbb, shape),
+ type, buffer_index,
+ fbb->CreateString(tensor->name),
+ quantization_params, tensor->is_variable));
+ }
+ }
+ return fbb->template CreateVector<Offset<Tensor>>(tensors);
+}
+
+Offset<Vector<Offset<Buffer>>> InterpreterWriter::ExportBuffers(
+ FlatBufferBuilder* fbb) {
+ std::vector<Offset<Buffer>> buffer_vector;
+ for (auto buffer : buffers_) {
+ auto data_offset = fbb->CreateVector(buffer.first, buffer.second);
+ buffer_vector.push_back(CreateBuffer(*fbb, data_offset));
+ }
+ return fbb->template CreateVector<Offset<Buffer>>(buffer_vector);
+}
+
+Offset<Vector<Offset<OperatorCode>>> InterpreterWriter::CreateOpCodeTable(
+ FlatBufferBuilder* fbb) {
+ std::vector<Offset<OperatorCode>> codes;
+ for (auto it : opcodes_) {
+ const char* custom_name = it.custom.empty() ? nullptr : it.custom.c_str();
+ codes.push_back(CreateOperatorCodeDirect(
+ *fbb, static_cast<BuiltinOperator>(it.builtin), custom_name));
+ }
+ return fbb->template CreateVector<Offset<OperatorCode>>(codes);
+}
+
+template <class T>
+std::vector<int> InterpreterWriter::RemapTensorIndicesToWritten(
+ const T& input) {
+ std::vector<int> output;
+ output.reserve(input.size());
+ for (int x : input) {
+ if (tensor_to_written_tensor_[x] != -1) {
+ output.push_back(tensor_to_written_tensor_[x]);
+ }
+ }
+ return output;
+}
+
+TfLiteStatus InterpreterWriter::GetBuffer(std::unique_ptr<uint8_t[]>* out,
+ size_t* size) {
+ if (!out || !size) return kTfLiteError;
+ FlatBufferBuilder builder(/*initial_size=*/10240);
+
+ std::vector<Offset<SubGraph>> subgraphs_as_vector;
+ { // subgraph specific stuff
+ auto tensors = ExportTensors(&builder);
+ std::vector<int> written_inputs =
+ RemapTensorIndicesToWritten(interpreter_->inputs());
+ std::vector<int> written_outputs =
+ RemapTensorIndicesToWritten(interpreter_->outputs());
+ auto inputs = ExportVector<int32_t>(&builder, written_inputs);
+ auto outputs = ExportVector<int32_t>(&builder, written_outputs);
+
+ auto ops = ExportOperators(&builder);
+ subgraphs_as_vector.push_back(
+ CreateSubGraph(builder, tensors, inputs, outputs, ops, /* name */ 0));
+ }
+ Offset<Vector<Offset<Buffer>>> buffers = ExportBuffers(&builder);
+
+ auto description = builder.CreateString("Exported from Interpreter.");
+
+ auto op_codes = CreateOpCodeTable(&builder);
+ auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes,
+ builder.CreateVector(subgraphs_as_vector),
+ description, buffers);
+ ::tflite::FinishModelBuffer(builder, model);
+ const uint8_t* buffer = builder.GetBufferPointer();
+ *size = builder.GetSize();
+ (*out).reset(new uint8_t[*size]);
+ memcpy(out->get(), buffer, *size);
+ return kTfLiteOk;
+}
+
+TfLiteStatus InterpreterWriter::Write(const std::string& filename) {
+ std::unique_ptr<uint8_t[]> buffer;
+ size_t size;
+ TF_LITE_ENSURE_STATUS(GetBuffer(&buffer, &size));
+
+ FILE* fp = fopen(filename.c_str(), "wb");
+ if (!fp) return kTfLiteError;
+
+ if (fwrite(buffer.get(), 1, size, fp) != size) return kTfLiteError;
+ if (fclose(fp)) return kTfLiteError;
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus InterpreterWriter::RegisterCustomWriter(
+ const std::string& custom_name, CustomWriter custom_writer) {
+ if (custom_op_to_writer_.find(custom_name) != custom_op_to_writer_.end()) {
+ return kTfLiteError;
+ }
+ custom_op_to_writer_.insert(std::make_pair(custom_name, custom_writer));
+ return kTfLiteOk;
+}
+
+} // namespace tflite
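
The writer above never emits temporaries or tensors listed in `unused_tensors_`, so every tensor index has to be remapped to its position in the written model before it is stored in a node or in the subgraph inputs/outputs. A minimal standalone sketch of that convention (hypothetical values, not part of this patch):

#include <iostream>
#include <vector>

int main() {
  // -1 marks a tensor that is not written (temporary or unused).
  std::vector<int> tensor_to_written_tensor = {0, -1, 1};
  std::vector<int> node_inputs = {0, 1, 2};
  std::vector<int> remapped;
  for (int t : node_inputs) {
    if (tensor_to_written_tensor[t] != -1) {
      remapped.push_back(tensor_to_written_tensor[t]);
    }
  }
  for (int t : remapped) std::cout << t << " ";  // Prints "0 1".
  return 0;
}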
diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib.h b/tensorflow/contrib/lite/experimental/writer/writer_lib.h
new file mode 100644
index 0000000000..a5f14697cf
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/writer_lib.h
@@ -0,0 +1,131 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Writes a flatbuffer of a currently loaded TensorFlow Lite interpreter.
+//
+// Usage:
+// From command line:
+// bazel run third_party/tensorflow/contrib/lite/experimental/writer:writer
+// -- foo.tflite foo.out.tflite
+//
+// From C++
+// std::unique_ptr<Interpreter> interpreter;
+// // Build Interpreter however
+// // ... <omitted>
+// InterpreterWriter(interpreter.get()).Write("output.tflite");
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_
+#include <iostream>
+#include <unordered_map>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context_util.h"
+#include "tensorflow/contrib/lite/experimental/writer/enum_mapping.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h"
+#include "tensorflow/contrib/lite/version.h"
+
+namespace tflite {
+
+// Handles writing a running TensorFlow Lite interpreter to the serialized
+// TF Lite file format.
+class InterpreterWriter {
+ public:
+ typedef flatbuffers::Offset<Operator> (*CustomWriter)(
+ flatbuffers::FlatBufferBuilder* fbb, Interpreter* interpreter,
+ int node_index,
+ flatbuffers::Offset<flatbuffers::Vector<uint8_t>>* output_options,
+ CustomOptionsFormat* custom_options_format);
+
+  // Constructs an interpreter writer for the specified `interpreter`. Then,
+  // use .Write() or .GetBuffer(...) to extract the data.
+ explicit InterpreterWriter(Interpreter* interpreter)
+ : interpreter_(interpreter) {
+ buffers_.push_back(std::make_pair(nullptr, 0));
+ }
+
+ // Get a buffer and size of a serialized flatbuffer.
+ TfLiteStatus GetBuffer(std::unique_ptr<uint8_t[]>* out, size_t* size);
+ // Write the serialized flatbuffer to the prescribed `filename`.
+ TfLiteStatus Write(const std::string& filename);
+ // Registers a custom writer for a custom op. The customization allows the
+ // caller to change the custom data.
+ TfLiteStatus RegisterCustomWriter(const std::string& custom_name,
+ CustomWriter custom_writer);
+  // Specifies tensors that are unused and shouldn't be written.
+ void SetUnusedTensors(const std::set<int>& unused_tensors) {
+ unused_tensors_ = unused_tensors;
+ }
+
+ private:
+ template <class T>
+ using Offset = flatbuffers::Offset<T>;
+ template <class T_OUTPUT, class T_INPUT>
+ Offset<flatbuffers::Vector<T_OUTPUT>> ExportVector(
+ flatbuffers::FlatBufferBuilder* fbb, const T_INPUT& v);
+ Offset<flatbuffers::Vector<Offset<Tensor>>> ExportTensors(
+ flatbuffers::FlatBufferBuilder* fbb);
+ Offset<flatbuffers::Vector<Offset<Operator>>> ExportOperators(
+ flatbuffers::FlatBufferBuilder* fbb);
+ Offset<flatbuffers::Vector<Offset<OperatorCode>>> CreateOpCodeTable(
+ flatbuffers::FlatBufferBuilder* fbb);
+ Offset<flatbuffers::Vector<Offset<Buffer>>> ExportBuffers(
+ flatbuffers::FlatBufferBuilder* fbb);
+
+ template <class T>
+ std::vector<int> RemapTensorIndicesToWritten(const T& input);
+
+ int GetOpCodeForBuiltin(int builtin_op_index) {
+ std::pair<decltype(builtin_op_to_opcode_)::iterator, bool> result =
+ builtin_op_to_opcode_.insert(
+ std::make_pair(builtin_op_index, opcodes_.size()));
+ if (result.second) {
+ opcodes_.push_back({builtin_op_index, ""});
+ }
+ return result.first->second;
+ }
+
+ int GetOpCodeForCustom(const std::string& custom_name) {
+ std::pair<decltype(custom_op_to_opcode_)::iterator, bool> result =
+ custom_op_to_opcode_.insert(
+ std::make_pair(custom_name, opcodes_.size()));
+ if (result.second) {
+ opcodes_.push_back({BuiltinOperator_CUSTOM, custom_name});
+ }
+ return result.first->second;
+ }
+
+ // The interpreter we are writing
+ Interpreter* interpreter_;
+ // Keep track of byte buffers
+ std::vector<std::pair<const uint8_t*, size_t>> buffers_;
+ // List of op codes and mappings from builtin or custom op to opcode
+ struct OpCode {
+ int builtin;
+ std::string custom;
+ };
+ std::set<int> unused_tensors_;
+  // Maps each tensor index in the interpreter to its index in the written file.
+ // This is different due to temporary and unused tensors not being written.
+ std::vector<int> tensor_to_written_tensor_;
+ // List of used opcodes
+ std::vector<OpCode> opcodes_;
+ std::unordered_map<int, int> builtin_op_to_opcode_;
+ std::unordered_map<std::string, int> custom_op_to_opcode_;
+ std::unordered_map<std::string, CustomWriter> custom_op_to_writer_;
+};
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_
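
For reference, a minimal usage sketch of the API declared above, assuming `interpreter` is a fully built `std::unique_ptr<Interpreter>` as in the header comment (error handling abbreviated; a sketch, not part of this patch):

tflite::InterpreterWriter writer(interpreter.get());
// Optionally drop tensors that should not appear in the exported model.
writer.SetUnusedTensors({/* tensor indices */});

// Either serialize to memory ...
std::unique_ptr<uint8_t[]> buffer;
size_t size = 0;
if (writer.GetBuffer(&buffer, &size) != kTfLiteOk) {
  // Handle serialization failure.
}
// ... or write directly to disk.
if (writer.Write("/tmp/exported.tflite") != kTfLiteOk) {
  // Handle I/O failure.
}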
diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc b/tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc
new file mode 100644
index 0000000000..49194a76c8
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc
@@ -0,0 +1,62 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h"
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/testing/util.h"
+
+namespace tflite {
+// Builds a small interpreter with one ADD node and round-trips it through the writer.
+// TODO(b/113731921): add more tests.
+TEST(Writer, BasicTest) {
+ Interpreter interpreter;
+ interpreter.AddTensors(3);
+ float foo[] = {1, 2, 3};
+ interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "a", {3},
+ TfLiteQuantizationParams());
+ interpreter.SetTensorParametersReadOnly(
+ 1, kTfLiteFloat32, "b", {3}, TfLiteQuantizationParams(),
+ reinterpret_cast<char*>(foo), sizeof(foo));
+ interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "c", {3},
+ TfLiteQuantizationParams());
+ interpreter.SetInputs({0, 1});
+ interpreter.SetOutputs({2});
+ const char* initial_data = "";
+ tflite::ops::builtin::BuiltinOpResolver resolver;
+ TfLiteAddParams* builtin_data =
+ reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
+ builtin_data->activation = kTfLiteActNone;
+ const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_ADD, 1);
+ interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0,
+ reinterpret_cast<void*>(builtin_data), reg);
+
+ InterpreterWriter writer(&interpreter);
+ writer.Write("/tmp/test.tflite");
+ std::unique_ptr<FlatBufferModel> model =
+ FlatBufferModel::BuildFromFile("/tmp/test.tflite");
+ InterpreterBuilder builder(*model, resolver);
+ std::unique_ptr<Interpreter> new_interpreter;
+ builder(&new_interpreter);
+}
+
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
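
The test above ignores the statuses returned by the writer and the builder. A slightly stricter variant, reusing the same `interpreter` and `resolver` (a sketch, not part of this patch), might look like:

InterpreterWriter writer(&interpreter);
ASSERT_EQ(writer.Write("/tmp/test.tflite"), kTfLiteOk);
std::unique_ptr<FlatBufferModel> model =
    FlatBufferModel::BuildFromFile("/tmp/test.tflite");
ASSERT_TRUE(model != nullptr);
InterpreterBuilder builder(*model, resolver);
std::unique_ptr<Interpreter> new_interpreter;
ASSERT_EQ(builder(&new_interpreter), kTfLiteOk);
ASSERT_TRUE(new_interpreter != nullptr);
EXPECT_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);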
diff --git a/tensorflow/contrib/lite/g3doc/_index.yaml b/tensorflow/contrib/lite/g3doc/_index.yaml
index 9119e49117..b3f21e21ac 100644
--- a/tensorflow/contrib/lite/g3doc/_index.yaml
+++ b/tensorflow/contrib/lite/g3doc/_index.yaml
@@ -5,7 +5,8 @@ landing_page:
rows:
- heading: TensorFlow Lite is a lightweight solution for mobile and embedded devices.
items:
- - description: >
+ - classname: devsite-landing-row-50
+ description: >
TensorFlow Lite is TensorFlow’s lightweight solution for mobile and
embedded devices. It enables on-device machine learning inference with
low latency and a small binary size. TensorFlow Lite also supports
@@ -33,7 +34,7 @@ landing_page:
icon_name: chevron_right
foreground: theme
background: grey
- - code_block: |
+ code_block: |
<pre class = "prettyprint">
$ toco --input_file=$(pwd)/mobilenet_v1_1.0_224/frozen_graph.pb \
--input_format=TENSORFLOW_GRAPHDEF \
diff --git a/tensorflow/contrib/lite/g3doc/models.md b/tensorflow/contrib/lite/g3doc/models.md
index 88f6cda420..a4267eee4c 100644
--- a/tensorflow/contrib/lite/g3doc/models.md
+++ b/tensorflow/contrib/lite/g3doc/models.md
@@ -7,65 +7,64 @@ Model Name | Paper_Model_Files^
--------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | --------------------: | ---------------------:
DenseNet | [paper](https://arxiv.org/abs/1608.06993), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/densenet_2018_04_27.tgz) | 43.6 Mb | 64.2% | 85.6% | 894 ms | 1262 ms
SqueezeNet | [paper](https://arxiv.org/abs/1602.07360), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz) | 5.0 Mb | 49.0% | 72.9% | 224 ms | 255 ms
-NASNet mobile | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz) | 21.4 Mb | 74.2% | 91.7% | 261 ms | 389 ms
-NASNet large | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_large_2018_04_27.tgz) | 355.3 Mb | 82.8% | 96.2% | 6697 ms | 7940 ms
-ResNet_V2_50 | [paper](https://arxiv.org/abs/1603.05027), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/resnet_v2_50_2018_04_27.tgz) | 102.3 Mb | 68.1% | 88.4% | 942 ms | 1008 ms
-ResNet_V2_101 | [paper](https://arxiv.org/abs/1603.05027), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite_11_05_08/resnet_v2_101.tgz) | 178.3 Mb | 70.4% | 89.6% | 1880 ms | 1970 ms
-Inception_V3 | [paper](http://arxiv.org/abs/1512.00567), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v3_2018_04_27.tgz) | 95.3 Mb | 78.2% | 94.0% | 1433 ms | 1522 ms
-Inception_V4 | [paper](http://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz) | 170.7 Mb | 80.4% | 95.2% | 2986 ms | 3139 ms
-Inception_ResNet_V2 | [paper](https://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz) | 121.0 Mb | 77.8% | 94.1% | 2731 ms | 2926 ms
-Mobilenet_V1_0.25_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128.tgz) | 1.9 Mb | 41.6% | 66.6% | 6.2 ms | 13.0 ms
-Mobilenet_V1_0.25_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_160.tgz) | 1.9 Mb | 45.7% | 70.6% | 8.6 ms | 19.5 ms
-Mobilenet_V1_0.25_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_192.tgz) | 1.9 Mb | 47.5% | 72.4% | 12.1 ms | 27.8 ms
-Mobilenet_V1_0.25_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_224.tgz) | 1.9 Mb | 50.0% | 74.4% | 16.2 ms | 37.3 ms
-Mobilenet_V1_0.50_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_128.tgz) | 5.3 Mb | 56.5% | 79.5% | 18.1 ms | 29.9 ms
-Mobilenet_V1_0.50_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_160.tgz) | 5.3 Mb | 59.3% | 82.1% | 26.8 ms | 45.9 ms
-Mobilenet_V1_0.50_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_192.tgz) | 5.3 Mb | 62.0% | 83.7% | 35.6 ms | 65.3 ms
-Mobilenet_V1_0.50_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_224.tgz) | 5.3 Mb | 63.5% | 85.0% | 47.6 ms | 164.2 ms
-Mobilenet_V1_0.75_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_128.tgz) | 10.3 Mb | 62.3% | 84.1% | 34.6 ms | 48.7 ms
-Mobilenet_V1_0.75_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_160.tgz) | 10.3 Mb | 65.5% | 86.1% | 51.3 ms | 75.2 ms
-Mobilenet_V1_0.75_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_192.tgz) | 10.3 Mb | 67.4% | 87.4% | 71.7 ms | 107.0 ms
-Mobilenet_V1_0.75_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_224.tgz) | 10.3 Mb | 68.6% | 88.3% | 95.7 ms | 143.4 ms
-Mobilenet_V1_1.0_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_128.tgz) | 16.9 Mb | 65.5% | 85.9% | 57.4 ms | 76.8 ms
-Mobilenet_V1_1.0_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_160.tgz) | 16.9 Mb | 68.3% | 87.8% | 86.0 ms | 117.7 ms
-Mobilenet_V1_1.0_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_192.tgz) | 16.9 Mb | 70.2% | 89.3% | 118.6 ms | 167.3 ms
-Mobilenet_V1_1.0_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz) | 16.9 Mb | 71.3% | 90.1% | 160.1 ms | 224.3 ms
-Mobilenet_V2_1.0_224 | [paper](https://arxiv.org/pdf/1801.04381.pdf), [tflite&pb](http://download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224.tgz) | 14.0 Mb | 71.9% | 90.1% | 117 ms |
+NASNet mobile | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz) | 21.4 Mb | 73.9% | 91.5% | 261 ms | 389 ms
+NASNet large | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_large_2018_04_27.tgz) | 355.3 Mb | 82.6% | 96.1% | 6697 ms | 7940 ms
+ResNet_V2_101 | [paper](https://arxiv.org/abs/1603.05027), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite_11_05_08/resnet_v2_101.tgz) | 178.3 Mb | 76.8% | 93.6% | 1880 ms | 1970 ms
+Inception_V3 | [paper](http://arxiv.org/abs/1512.00567), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v3_2018_04_27.tgz) | 95.3 Mb | 77.9% | 93.8% | 1433 ms | 1522 ms
+Inception_V4 | [paper](http://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz) | 170.7 Mb | 80.1% | 95.1% | 2986 ms | 3139 ms
+Inception_ResNet_V2 | [paper](https://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz) | 121.0 Mb | 77.5% | 94.0% | 2731 ms | 2926 ms
+Mobilenet_V1_0.25_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128.tgz) | 1.9 Mb | 41.4% | 66.2% | 6.2 ms | 13.0 ms
+Mobilenet_V1_0.25_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_160.tgz) | 1.9 Mb | 45.4% | 70.2% | 8.6 ms | 19.5 ms
+Mobilenet_V1_0.25_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_192.tgz) | 1.9 Mb | 47.1% | 72.0% | 12.1 ms | 27.8 ms
+Mobilenet_V1_0.25_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_224.tgz) | 1.9 Mb | 49.7% | 74.1% | 16.2 ms | 37.3 ms
+Mobilenet_V1_0.50_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_128.tgz) | 5.3 Mb | 56.2% | 79.3% | 18.1 ms | 29.9 ms
+Mobilenet_V1_0.50_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_160.tgz) | 5.3 Mb | 59.0% | 81.8% | 26.8 ms | 45.9 ms
+Mobilenet_V1_0.50_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_192.tgz) | 5.3 Mb | 61.7% | 83.5% | 35.6 ms | 65.3 ms
+Mobilenet_V1_0.50_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_224.tgz) | 5.3 Mb | 63.2% | 84.9% | 47.6 ms | 164.2 ms
+Mobilenet_V1_0.75_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_128.tgz) | 10.3 Mb | 62.0% | 83.8% | 34.6 ms | 48.7 ms
+Mobilenet_V1_0.75_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_160.tgz) | 10.3 Mb | 65.2% | 85.9% | 51.3 ms | 75.2 ms
+Mobilenet_V1_0.75_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_192.tgz) | 10.3 Mb | 67.1% | 87.2% | 71.7 ms | 107.0 ms
+Mobilenet_V1_0.75_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_224.tgz) | 10.3 Mb | 68.3% | 88.1% | 95.7 ms | 143.4 ms
+Mobilenet_V1_1.0_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_128.tgz) | 16.9 Mb | 65.2% | 85.7% | 57.4 ms | 76.8 ms
+Mobilenet_V1_1.0_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_160.tgz) | 16.9 Mb | 68.0% | 87.7% | 86.0 ms | 117.7 ms
+Mobilenet_V1_1.0_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_192.tgz) | 16.9 Mb | 69.9% | 89.1% | 118.6 ms | 167.3 ms
+Mobilenet_V1_1.0_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz) | 16.9 Mb | 71.0% | 89.9% | 160.1 ms | 224.3 ms
+Mobilenet_V2_1.0_224 | [paper](https://arxiv.org/pdf/1801.04381.pdf), [tflite&pb](http://download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224.tgz) | 14.0 Mb | 71.8% | 90.6% | 117 ms |
^ The model files include both TF Lite FlatBuffer and Tensorflow frozen Graph.
^^ The performance numbers are generated in the benchmark on Pixel-2 using
single thread large core.
-^^ Accuracy numbers were computed using the [TFLite accuracy tool](../tools/accuracy/ilsvrc)
-after excluding blacklisted images.
+^^ Accuracy numbers were computed using the
+[TFLite accuracy tool](../tools/accuracy/ilsvrc).
## Image classification (Quantized Models)
Model Name | Paper_Model_Files | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance
--------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | ------------------:
-Mobilenet_V1_0.25_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_128_quant.tgz) | 0.5 Mb | 39.8% | 64.8% | 3.7 ms
-Mobilenet_V1_0.25_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_160_quant.tgz) | 0.5 Mb | 43.0% | 68.4% | 5.5 ms
-Mobilenet_V1_0.25_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_192_quant.tgz) | 0.5 Mb | 46.0% | 71.2% | 7.9 ms
-Mobilenet_V1_0.25_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_224_quant.tgz) | 0.5 Mb | 48.5% | 73.1% | 10.4 ms
-Mobilenet_V1_0.50_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_128_quant.tgz) | 1.4 Mb | 55.2% | 78.4% | 8.8 ms
-Mobilenet_V1_0.50_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_160_quant.tgz) | 1.4 Mb | 57.5% | 80.7% | 13.0 ms
-Mobilenet_V1_0.50_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_192_quant.tgz) | 1.4 Mb | 60.2% | 82.3% | 18.3 ms
-Mobilenet_V1_0.50_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_224_quant.tgz) | 1.4 Mb | 61.5% | 83.5% | 24.7 ms
-Mobilenet_V1_0.75_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_128_quant.tgz) | 2.6 Mb | 56.2% | 79.4% | 16.2 ms
-Mobilenet_V1_0.75_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_160_quant.tgz) | 2.6 Mb | 62.7% | 83.9% | 24.3 ms
-Mobilenet_V1_0.75_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_192_quant.tgz) | 2.6 Mb | 66.4% | 86.4% | 33.8 ms
-Mobilenet_V1_0.75_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_224_quant.tgz) | 2.6 Mb | 67.2% | 87.0% | 45.4 ms
-Mobilenet_V1_1.0_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_128_quant.tgz) | 4.3 Mb | 63.6% | 84.3% | 24.9 ms
-Mobilenet_V1_1.0_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_160_quant.tgz) | 4.3 Mb | 67.2% | 86.9% | 37.4 ms
-Mobilenet_V1_1.0_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_192_quant.tgz) | 4.3 Mb | 69.4% | 88.3% | 51.9 ms
-Mobilenet_V1_1.0_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz) | 4.3 Mb | 70.2% | 89.1% | 70.2 ms
-Mobilenet_v2_1.0_224_quant | [paper](https://arxiv.org/abs/1806.08342), [tflite&pb](http://download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224_quant.tgz) | 3.4 Mb | 71.1% | 90.1% | 80.3 ms
-Inception_v3_quant | [paper](https://arxiv.org/abs/1806.08342),[tflite&pb](http://download.tensorflow.org/models/tflite_11_05_08/inception_v3_quant.tgz) | 23 Mb | 77.5% | 93.6% | 637 ms
+Mobilenet_V1_0.25_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_128_quant.tgz) | 0.5 Mb | 39.5% | 64.4% | 3.7 ms
+Mobilenet_V1_0.25_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_160_quant.tgz) | 0.5 Mb | 42.8% | 68.1% | 5.5 ms
+Mobilenet_V1_0.25_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_192_quant.tgz) | 0.5 Mb | 45.7% | 70.8% | 7.9 ms
+Mobilenet_V1_0.25_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_224_quant.tgz) | 0.5 Mb | 48.2% | 72.8% | 10.4 ms
+Mobilenet_V1_0.50_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_128_quant.tgz) | 1.4 Mb | 54.9% | 78.1% | 8.8 ms
+Mobilenet_V1_0.50_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_160_quant.tgz) | 1.4 Mb | 57.2% | 80.5% | 13.0 ms
+Mobilenet_V1_0.50_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_192_quant.tgz) | 1.4 Mb | 59.9% | 82.1% | 18.3 ms
+Mobilenet_V1_0.50_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_224_quant.tgz) | 1.4 Mb | 61.2% | 83.2% | 24.7 ms
+Mobilenet_V1_0.75_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_128_quant.tgz) | 2.6 Mb | 55.9% | 79.1% | 16.2 ms
+Mobilenet_V1_0.75_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_160_quant.tgz) | 2.6 Mb | 62.4% | 83.7% | 24.3 ms
+Mobilenet_V1_0.75_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_192_quant.tgz) | 2.6 Mb | 66.1% | 86.2% | 33.8 ms
+Mobilenet_V1_0.75_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_224_quant.tgz) | 2.6 Mb | 66.9% | 86.9% | 45.4 ms
+Mobilenet_V1_1.0_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_128_quant.tgz) | 4.3 Mb | 63.3% | 84.1% | 24.9 ms
+Mobilenet_V1_1.0_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_160_quant.tgz) | 4.3 Mb | 66.9% | 86.7% | 37.4 ms
+Mobilenet_V1_1.0_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_192_quant.tgz) | 4.3 Mb | 69.1% | 88.1% | 51.9 ms
+Mobilenet_V1_1.0_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz) | 4.3 Mb | 70.0% | 89.0% | 70.2 ms
+Mobilenet_v2_1.0_224_quant | [paper](https://arxiv.org/abs/1806.08342), [tflite&pb](http://download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224_quant.tgz) | 3.4 Mb | 70.8% | 89.9% | 80.3 ms
+Inception_v3_quant | [paper](https://arxiv.org/abs/1806.08342),[tflite&pb](http://download.tensorflow.org/models/tflite_11_05_08/inception_v3_quant.tgz) | 23 Mb | 77.5% | 93.7% | 637 ms
## Other models
-Lite FlatBuffer ----------------------- | :----------------: Smart Reply 1.0
-Android |
+Model | TF Lite FlatBuffer
+----------------------- | :----------------:
[reference](https://research.googleblog.com/2017/11/on-device-conversational-modeling-with.html),
[tflite](https://storage.googleapis.com/download.tensorflow.org/models/smartreply_1.0_2017_11_01.zip)
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index 3f8f4d198f..2657bcd42b 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -123,6 +123,7 @@ Interpreter::Interpreter(ErrorReporter* error_reporter)
context_.AddTensors = AddTensors;
context_.tensors = nullptr;
context_.tensors_size = 0;
+ context_.allow_fp32_relax_to_fp16 = false;
context_.recommended_num_threads = -1;
context_.GetExternalContext = GetExternalContext;
context_.SetExternalContext = SetExternalContext;
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
index f0cd178c19..aa2bc4def6 100644
--- a/tensorflow/contrib/lite/interpreter.h
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -336,6 +336,19 @@ class Interpreter {
// Set the number of threads available to the interpreter.
void SetNumThreads(int num_threads);
+  // Allows float16 precision for FP32 calculations when possible.
+  // Default: not allowed.
+ // WARNING: This is an experimental API and subject to change.
+ void SetAllowFp16PrecisionForFp32(bool allow) {
+ context_.allow_fp32_relax_to_fp16 = allow;
+ }
+
+ // Get the half precision flag.
+ // WARNING: This is an experimental API and subject to change.
+ bool GetAllowFp16PrecisionForFp32() const {
+ return context_.allow_fp32_relax_to_fp16;
+ }
+
// Allow a delegate to look at the graph and modify the graph to handle
// parts of the graph themselves. After this is called, the graph may
// contain new nodes that replace 1 or more nodes.
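
A minimal sketch of opting into the new relaxation flag, assuming `model` is a loaded `FlatBufferModel` (held by pointer here) and `resolver` is an op resolver (experimental API, per the comments above; not part of this patch):

std::unique_ptr<tflite::Interpreter> interpreter;
tflite::InterpreterBuilder(*model, resolver)(&interpreter);
// Allow kernels/delegates to run FP32 ops at FP16 precision when they can.
interpreter->SetAllowFp16PrecisionForFp32(true);
// Delegates and kernels can read the flag back before deciding how to run.
bool relaxed = interpreter->GetAllowFp16PrecisionForFp32();  // true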
diff --git a/tensorflow/contrib/lite/java/demo/README.md b/tensorflow/contrib/lite/java/demo/README.md
index e3cea19e16..6a3f0651d0 100644
--- a/tensorflow/contrib/lite/java/demo/README.md
+++ b/tensorflow/contrib/lite/java/demo/README.md
@@ -20,9 +20,6 @@ code to merge.
- Make sure to install the latest version of Bazel. Some distributions
ship with Bazel 0.5.4, which is too old.
- Bazel requires Android Build Tools `26.0.1` or higher.
- - **Bazel is incompatible with NDK revisions 15 and above,** with revision
- 16 being a compile-breaking change. [Download an older version manually
- instead of using the SDK Manager.](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#install-bazel-and-android-prerequisites)
- You also need to install the Android Support Repository, available
through Android Studio under `Android SDK Manager -> SDK Tools ->
Android Support Repository`.
@@ -37,8 +34,7 @@ code to merge.
- Make sure the `api_level` in `WORKSPACE` is set to an SDK version that
you have installed.
- By default, Android Studio will install the SDK to `~/Android/Sdk` and
- the NDK to `~/Android/Sdk/ndk-bundle` (but the NDK should be a manual
- download until Bazel supports NDK 16. See bullet points under (1)).
+ the NDK to `~/Android/Sdk/ndk-bundle`.
2. Build the app with Bazel. The demo needs C++11:
diff --git a/tensorflow/contrib/lite/kernels/conv_test.cc b/tensorflow/contrib/lite/kernels/conv_test.cc
index 411615aa62..f7e6f083ed 100644
--- a/tensorflow/contrib/lite/kernels/conv_test.cc
+++ b/tensorflow/contrib/lite/kernels/conv_test.cc
@@ -177,6 +177,30 @@ TEST_P(ConvolutionOpTest, SimpleTestFloat32WithChannels) {
}));
}
+TEST_P(ConvolutionOpTest, InputAndFilterSameWidthHeight) {
+ ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 1}},
+ {TensorType_FLOAT32, {1, 2, 4, 1}},
+ {TensorType_FLOAT32, {}});
+
+ m.SetInput({
+ // First batch
+ 1, 1, 1, 1, // row = 1
+ 2, 2, 2, 2, // row = 2
+ // Second batch
+ 1, 2, 3, 4, // row = 1
+ 1, 2, 3, 4, // row = 2
+ });
+ m.SetFilter({
+ 1, 2, 3, 4, // row = 1
+ -1, -1, 1, 1, // row = 2
+ });
+ m.SetBias({0});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({10, 34}));
+}
+
TEST_P(ConvolutionOpTest, PointwiseFloat32) {
ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
{TensorType_FLOAT32, {1, 1, 1, 2}},
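
A quick hand check of the expected values in `InputAndFilterSameWidthHeight` above: because the filter spans the full input width and height, each batch collapses to a single output value, the row-by-row dot product of that batch with the filter:

Batch 1: (1,1,1,1)·(1,2,3,4) + (2,2,2,2)·(-1,-1,1,1) = 10 + 0 = 10
Batch 2: (1,2,3,4)·(1,2,3,4) + (1,2,3,4)·(-1,-1,1,1) = 30 + 4 = 34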
diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
index 347515f289..3e1ce60113 100644
--- a/tensorflow/contrib/lite/kernels/depthwise_conv.cc
+++ b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
@@ -126,23 +126,28 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Matching GetWindowedOutputSize in TensorFlow.
auto padding = params->padding;
- auto compute_out_size = [padding](int imageSize, int filterSize,
- int stride) -> int {
+ auto compute_out_size = [padding](int image_size, int filter_size, int stride,
+ int dilation_rate) -> int {
+ int effective_filter_size = (filter_size - 1) * dilation_rate + 1;
return padding == kTfLitePaddingSame
- ? (imageSize + stride - 1) / stride
+ ? (image_size + stride - 1) / stride
: padding == kTfLitePaddingValid
- ? (imageSize - filterSize + stride) / stride
+ ? (image_size - effective_filter_size + stride) / stride
: 0;
};
- int out_width = compute_out_size(width, filter_width, params->stride_width);
+ int out_width = compute_out_size(width, filter_width, params->stride_width,
+ params->dilation_width_factor);
int out_height =
- compute_out_size(height, filter_height, params->stride_height);
+ compute_out_size(height, filter_height, params->stride_height,
+ params->dilation_height_factor);
- data->padding.height = ComputePadding(params->stride_height, 1, height,
- filter_height, out_height);
+ data->padding.height =
+ ComputePadding(params->stride_height, params->dilation_height_factor,
+ height, filter_height, out_height);
data->padding.width =
- ComputePadding(params->stride_width, 1, width, filter_width, out_width);
+ ComputePadding(params->stride_width, params->dilation_width_factor, width,
+ filter_width, out_width);
// Note that quantized inference requires that all tensors have their
// parameters set. This is usually done during quantized training.
@@ -177,8 +182,19 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
void (*depthwise_conv)(const float*, const Dims<4>&, const float*,
const Dims<4>&, const float*, const Dims<4>&, int, int,
- int, int, int, float, float, float*, const Dims<4>&);
- if (kernel_type == kReference) {
+ int, int, int, int, int, float, float, float*,
+ const Dims<4>&);
+ KernelType effective_kernel_type;
+ // TODO(suharshs): Currently only the reference implementation supports
+ // dilations.
+ if ((params->dilation_width_factor != 1) ||
+ (params->dilation_height_factor != 1)) {
+ effective_kernel_type = kReference;
+ } else {
+ effective_kernel_type = kernel_type;
+ }
+
+ if (effective_kernel_type == kReference) {
depthwise_conv = &reference_ops::DepthwiseConv;
} else {
depthwise_conv = &optimized_ops::DepthwiseConv;
@@ -188,7 +204,8 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
GetTensorData<float>(input), GetTensorDims(input),
GetTensorData<float>(filter), GetTensorDims(filter),
GetTensorData<float>(bias), GetTensorDims(bias), params->stride_width,
- params->stride_height, data->padding.width, data->padding.height,
+ params->stride_height, params->dilation_width_factor,
+ params->dilation_height_factor, data->padding.width, data->padding.height,
params->depth_multiplier, output_activation_min, output_activation_max,
GetTensorData<float>(output), GetTensorDims(output));
}
@@ -204,9 +221,20 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
void (*depthwise_conv)(const uint8*, const Dims<4>&, int32, const uint8*,
const Dims<4>&, int32, const int32*, const Dims<4>&,
- int, int, int, int, int, int32, int32, int, int32,
- int32, uint8*, const Dims<4>&);
- if (kernel_type == kReference) {
+ int, int, int, int, int, int, int, int32, int32, int,
+ int32, int32, uint8*, const Dims<4>&);
+
+ KernelType effective_kernel_type;
+ // TODO(suharshs): Currently only the reference implementation supports
+ // dilations.
+ if ((params->dilation_width_factor != 1) ||
+ (params->dilation_height_factor != 1)) {
+ effective_kernel_type = kReference;
+ } else {
+ effective_kernel_type = kernel_type;
+ }
+
+ if (effective_kernel_type == kReference) {
depthwise_conv = &reference_ops::DepthwiseConv;
} else {
depthwise_conv = &optimized_ops::DepthwiseConv;
@@ -216,7 +244,8 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset,
GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset,
GetTensorData<int32_t>(bias), GetTensorDims(bias), params->stride_width,
- params->stride_height, data->padding.width, data->padding.height,
+ params->stride_height, params->dilation_width_factor,
+ params->dilation_height_factor, data->padding.width, data->padding.height,
params->depth_multiplier, output_offset, data->output_multiplier,
data->output_shift, data->output_activation_min,
data->output_activation_max, GetTensorData<uint8_t>(output),
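
The output-size computation above uses the standard dilation formula, effective_filter_size = (filter_size - 1) * dilation_rate + 1. A small standalone check with the numbers from the dilated tests below (a sketch, not part of this patch):

#include <cassert>

// VALID-padding output size with dilation, mirroring compute_out_size above.
int ComputeValidOutSize(int image_size, int filter_size, int stride,
                        int dilation_rate) {
  int effective_filter_size = (filter_size - 1) * dilation_rate + 1;
  return (image_size - effective_filter_size + stride) / stride;
}

int main() {
  // 9x9 image, 3x3 filter, stride 1, dilation 3 -> effective filter 7 -> 3x3 out.
  assert(ComputeValidOutSize(9, 3, 1, 3) == 3);
  // Without dilation the same setup gives a 7x7 output.
  assert(ComputeValidOutSize(9, 3, 1, 1) == 7);
  return 0;
}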
diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc
index c00cafb9fb..2af26ab80a 100644
--- a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc
+++ b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc
@@ -30,7 +30,8 @@ class BaseDepthwiseConvolutionOpModel : public SingleOpModel {
// stride values.
BaseDepthwiseConvolutionOpModel(const TensorData& input,
const TensorData& filter,
- const TensorData& output) {
+ const TensorData& output,
+ int dilation_factor = 1) {
input_ = AddInput(input);
filter_ = AddInput(filter);
@@ -56,7 +57,8 @@ class BaseDepthwiseConvolutionOpModel : public SingleOpModel {
BuiltinOperator_DEPTHWISE_CONV_2D,
BuiltinOptions_DepthwiseConv2DOptions,
CreateDepthwiseConv2DOptions(builder_, Padding_VALID, 1, 1, depth_mul,
- ActivationFunctionType_NONE)
+ ActivationFunctionType_NONE,
+ dilation_factor, dilation_factor)
.Union());
BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)});
@@ -110,6 +112,58 @@ TEST(DepthwiseConvolutionOpTest, SimpleTest) {
}));
}
+TEST(DepthwiseConvolutionOpTest, SimpleDilatedTest) {
+ const int depth = 1;
+ const int image_width = 9;
+ const int image_height = 9;
+ const int image_batch_count = 1;
+ const int filter_size = 3;
+ const int filter_count = 1;
+ const int dilation_factor = 3;
+ DepthwiseConvolutionOpModel m(
+ {TensorType_FLOAT32,
+ {image_batch_count, image_height, image_width, depth}},
+ {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
+ {TensorType_FLOAT32, {}}, dilation_factor);
+
+ // The image matrix is:
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // clang-format off
+ m.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0});
+ // clang-format on
+ // The filter matrix is:
+ // | 1 | 2 | 3 |
+ // | 4 | 5 | 6 |
+ // | 7 | 8 | 9 |
+ m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9});
+ // No bias for this test.
+ m.SetBias({0});
+ m.Invoke();
+
+  // With a dilation rate of 3 the effective filter size is 7x7, so the output
+  // shrinks from 7x7 (undilated) to 3x3; at every output position only the
+  // filter's center tap (weight 5) lands on the block of ones. Specifically:
+ // | 5 | 5 | 5 |
+ // | 5 | 5 | 5 |
+ // | 5 | 5 | 5 |
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5}));
+}
+
class QuantizedDepthwiseConvolutionOpModel
: public BaseDepthwiseConvolutionOpModel {
public:
@@ -207,6 +261,64 @@ TEST(QuantizedDepthwiseConvolutionOpTest,
ElementsAreArray(ArrayFloatNear(float_op.GetOutput(), 1)));
}
+TEST(QuantizedDepthwiseConvolutionOpTest, SimpleDilatedTest) {
+ const int depth = 1;
+ const int image_width = 9;
+ const int image_height = 9;
+ const int image_batch_count = 1;
+ const int filter_size = 3;
+ const int filter_count = 1;
+ const int dilation_factor = 3;
+ QuantizedDepthwiseConvolutionOpModel m(
+ {TensorType_UINT8,
+ {image_batch_count, image_height, image_width, depth},
+ 0,
+ 255},
+ {TensorType_UINT8,
+ {depth, filter_size, filter_size, filter_count},
+ 0,
+ 255},
+ {TensorType_UINT8, {}, 0, 255}, dilation_factor);
+
+ // The image matrix is:
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // clang-format off
+ m.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0});
+ // clang-format on
+ // The filter matrix is:
+ // | 1 | 2 | 3 |
+ // | 4 | 5 | 6 |
+ // | 7 | 8 | 9 |
+ m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9});
+ // No bias for this test.
+ m.SetBias({0});
+ m.Invoke();
+
+  // With a dilation rate of 3 the effective filter size is 7x7, so the output
+  // shrinks from 7x7 (undilated) to 3x3; at every output position only the
+  // filter's center tap (weight 5) lands on the block of ones. Specifically:
+ // | 5 | 5 | 5 |
+ // | 5 | 5 | 5 |
+ // | 5 | 5 | 5 |
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5}));
+}
+
} // namespace
} // namespace tflite
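
A quick arithmetic check of why both dilated tests above expect an all-5 output, using output position (0, 0) as the example:

taps sampled (dilation 3): rows {0, 3, 6} x cols {0, 3, 6}
only input(3, 3) = 1 is nonzero, and it aligns with filter(1, 1) = 5
output(0, 0) = 5 * 1 = 5
Every other output position shifts the taps by at most 2, so exactly one tap
still lands inside the 3x3 block of ones, always against the center weight 5.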
diff --git a/tensorflow/contrib/lite/kernels/elementwise.cc b/tensorflow/contrib/lite/kernels/elementwise.cc
index 04995d70dd..8c624b3208 100644
--- a/tensorflow/contrib/lite/kernels/elementwise.cc
+++ b/tensorflow/contrib/lite/kernels/elementwise.cc
@@ -90,6 +90,10 @@ TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
return EvalNumeric(context, node, [](float f) { return 1.f / std::sqrt(f); });
}
+TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) {
+ return EvalNumeric(context, node, [](float f) { return f * f; });
+}
+
TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
return EvalLogical(context, node, [](bool v) { return !v; });
}
@@ -129,6 +133,14 @@ TfLiteRegistration* Register_RSQRT() {
return &r;
}
+TfLiteRegistration* Register_SQUARE() {
+ static TfLiteRegistration r = {
+ /*init=*/nullptr, /*free=*/nullptr,
+ elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+ elementwise::SquareEval};
+ return &r;
+}
+
TfLiteRegistration* Register_LOGICAL_NOT() {
static TfLiteRegistration r = {
/*init=*/nullptr, /*free=*/nullptr,
diff --git a/tensorflow/contrib/lite/kernels/elementwise_test.cc b/tensorflow/contrib/lite/kernels/elementwise_test.cc
index b9d7d73c52..5dd89a0eae 100644
--- a/tensorflow/contrib/lite/kernels/elementwise_test.cc
+++ b/tensorflow/contrib/lite/kernels/elementwise_test.cc
@@ -92,6 +92,15 @@ TEST(ElementWise, Rsqrt) {
EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
}
+TEST(ElementWise, Square) {
+ ElementWiseOpFloatModel m(BuiltinOperator_SQUARE, {1, 1, 4, 1});
+ m.PopulateTensor<float>(m.input(), {1, 2, 0.5, -3.0});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<float>(m.output()),
+ ElementsAreArray(ArrayFloatNear({1, 4.0, 0.25, 9.0})));
+ EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
+}
+
TEST(ElementWise, LogicalNot) {
ElementWiseOpBoolModel m(BuiltinOperator_LOGICAL_NOT, {1, 1, 4, 1});
m.PopulateTensor<bool>(m.input(), {true, false, true, false});
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
index 7f6eea2d5d..70810ca784 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
@@ -1067,6 +1067,26 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
}
}
+inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height,
+ int dilation_width_factor, int dilation_height_factor,
+ int pad_width, int pad_height, int depth_multiplier,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+  // TODO(suharshs): An optimized implementation of dilated depthwise conv
+  // still needs to be implemented.
+ TFLITE_DCHECK(dilation_width_factor == 1);
+ TFLITE_DCHECK(dilation_height_factor == 1);
+
+ DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride_width, stride_height, pad_width, pad_height,
+ depth_multiplier, output_activation_min, output_activation_max,
+ output_data, output_dims);
+}
+
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
index 3fd00c8930..f707279600 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
@@ -1964,6 +1964,30 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
}
}
+inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height,
+ int dilation_width_factor, int dilation_height_factor,
+ int pad_width, int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+  // TODO(suharshs): An optimized implementation of dilated depthwise conv is
+  // not supported yet.
+ TFLITE_DCHECK(dilation_width_factor == 1);
+ TFLITE_DCHECK(dilation_height_factor == 1);
+
+ DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width,
+ stride_height, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
// Legacy, for compatibility with old checked-in code.
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
index 5fb31889fe..59f0e3c927 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
@@ -113,8 +113,8 @@ class EigenTensorConvFunctor {
filter_width * filter_height * input_depth;
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
- EigenMatrix output(output_data, 1, filter_count);
- ConstEigenMatrix input(input_data, 1, k);
+ EigenMatrix output(output_data, input_batches, filter_count);
+ ConstEigenMatrix input(input_data, input_batches, k);
ConstEigenMatrix filter(filter_data, k, filter_count);
MatMulConvFunctor<Eigen::ThreadPoolDevice, T>()(device, output, input,
filter, dim_pair);
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index 2c8e8f90e3..659a65a8ea 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -200,6 +200,8 @@ struct TTypes {
UnalignedConstMatrix;
};
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
// TODO(b/62193649): this function is only needed as long
// as we have the --variable_batch hack.
template <typename Scalar, int N>
@@ -212,6 +214,18 @@ MatrixMap<Scalar> MapAsMatrixWithGivenNumberOfRows(Scalar* data,
return MatrixMap<Scalar>(data, rows, cols);
}
+// TODO(b/62193649): this function is only needed as long
+// as we have the --variable_batch hack.
+template <typename Scalar>
+MatrixMap<Scalar> MapAsMatrixWithGivenNumberOfRows(Scalar* data,
+ const RuntimeShape& shape,
+ int rows) {
+ const int flatsize = shape.FlatSize();
+ TFLITE_DCHECK_EQ(flatsize % rows, 0);
+ const int cols = flatsize / rows;
+ return MatrixMap<Scalar>(data, rows, cols);
+}
+
// This is like the template-parameter version, except that the power-of-two is
// passed as a function parameter. The template version is to be preferred,
// since some target hardware optimizations depend on the range of the exponent.
@@ -260,16 +274,16 @@ inline bool AreSameDims(const Dims<4>& dims1, const Dims<4>& dims2) {
return true;
}
-inline void AddBiasAndEvalActivationFunction(const float* bias_data,
- const Dims<4>& bias_dims,
- float* array_data,
- const Dims<4>& array_dims,
- float output_activation_min,
- float output_activation_max) {
+inline void AddBiasAndEvalActivationFunction(float output_activation_min,
+ float output_activation_max,
+ const RuntimeShape& bias_shape,
+ const float* bias_data,
+ const RuntimeShape& array_shape,
+ float* array_data) {
#ifdef USE_NEON
gemmlowp::ScopedProfilingLabel label("AddBiasAndEvalActivationFunction");
- const int bias_size = FlatSize(bias_dims);
- const int array_size = FlatSize(array_dims);
+ const int bias_size = bias_shape.FlatSize();
+ const int array_size = array_shape.FlatSize();
TFLITE_DCHECK_EQ((array_size % bias_size), 0);
float* array_ptr = array_data;
float* array_end_ptr = array_ptr + array_size;
@@ -319,8 +333,8 @@ inline void AddBiasAndEvalActivationFunction(const float* bias_data,
}
#else // not NEON
gemmlowp::ScopedProfilingLabel label("AddBiasAndEvalActivationFunction");
- const int bias_size = FlatSize(bias_dims);
- const int array_size = FlatSize(array_dims);
+ const int bias_size = bias_shape.FlatSize();
+ const int array_size = array_shape.FlatSize();
TFLITE_DCHECK_EQ((array_size % bias_size), 0);
for (int array_offset = 0; array_offset < array_size;
array_offset += bias_size) {
@@ -333,6 +347,19 @@ inline void AddBiasAndEvalActivationFunction(const float* bias_data,
#endif
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void AddBiasAndEvalActivationFunction(const float* bias_data,
+ const Dims<4>& bias_dims,
+ float* array_data,
+ const Dims<4>& array_dims,
+ float output_activation_min,
+ float output_activation_max) {
+ AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
+ DimsToShape(bias_dims), bias_data,
+ DimsToShape(array_dims), array_data);
+}
+
// Note: This to be converted to RuntimeShapes along with Conv.
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
@@ -380,21 +407,24 @@ inline void optimized_ops_preload_l1_keep(const uint8* ptr) {
// to a matrix*vector product. LSTM cells contain a fully-connected node;
// when quantized, this becomes a special type of GEMV operation where
// the output is 16bit-quantized, thus needs its own special path.
-inline void GEMVForLstmCell(const uint8* input_data, const Dims<4>& input_dims,
- const uint8* weights_data,
- const Dims<4>& weights_dims,
- uint8 weights_zero_point, const int32* bias_data,
- const Dims<4>& bias_dims, int32 accum_multiplier,
- int accum_shift, int16* output_data,
- const Dims<4>& output_dims) {
+inline void GEMVForLstmCell(const RuntimeShape& input_shape,
+ const uint8* input_data,
+ const RuntimeShape& weights_shape,
+ const uint8* weights_data, uint8 weights_zero_point,
+ const RuntimeShape& bias_shape,
+ const int32* bias_data, int32 accum_multiplier,
+ int accum_shift, const RuntimeShape& output_shape,
+ int16* output_data) {
gemmlowp::ScopedProfilingLabel label("GEMVForLstmCell");
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(bias_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
- TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_dims, 0), 1);
- const int input_size = FlatSizeSkipDim(input_dims, 3);
- const int output_size = MatchingArraySize(weights_dims, 1, output_dims, 0);
+ TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
+ TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
+ TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
+ const int output_dim_count = output_shape.DimensionsCount();
+ const int weights_dim_count = weights_shape.DimensionsCount();
+ TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_shape, output_dim_count - 1), 1);
+ const int input_size = FlatSizeSkipDim(input_shape, 0);
+ const int output_size = MatchingDim(weights_shape, weights_dim_count - 2,
+ output_shape, output_dim_count - 1);
// This special fast path for quantized LSTM cells does not try to support
// odd sizes that we haven't encountered in any LSTM cell, that would
// require special code (that would go untested until any LSTM cell
@@ -567,18 +597,21 @@ inline void GEMVForLstmCell(const uint8* input_data, const Dims<4>& input_dims,
#ifdef GEMMLOWP_NEON
inline void GEMVForLstmCellWithSymmetricRange(
- const uint8* input_data, const Dims<4>& input_dims,
- const uint8* weights_data, const Dims<4>& weights_dims,
- const int32* bias_data, const Dims<4>& bias_dims, int32 accum_multiplier,
- int accum_shift, int16* output_data, const Dims<4>& output_dims) {
+ const RuntimeShape& input_shape, const uint8* input_data,
+ const RuntimeShape& weights_shape, const uint8* weights_data,
+ const RuntimeShape& bias_shape, const int32* bias_data,
+ int32 accum_multiplier, int accum_shift, const RuntimeShape& output_shape,
+ int16* output_data) {
gemmlowp::ScopedProfilingLabel label("GEMVForLstmCellWithSymmetricRange");
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(bias_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
- TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_dims, 0), 1);
- const int input_size = FlatSizeSkipDim(input_dims, 3);
- const int output_size = MatchingArraySize(weights_dims, 1, output_dims, 0);
+ TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
+ TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
+ TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
+ const int output_dim_count = output_shape.DimensionsCount();
+ const int weights_dim_count = weights_shape.DimensionsCount();
+ TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_shape, output_dim_count - 1), 1);
+ const int input_size = FlatSizeSkipDim(input_shape, 0);
+ const int output_size = MatchingDim(weights_shape, weights_dim_count - 2,
+ output_shape, output_dim_count - 1);
// This special fast path for quantized LSTM cells does not try to support
// odd sizes that we haven't encountered in any LSTM cell, that would
// require special code (that would go untested until any LSTM cell
@@ -854,14 +887,16 @@ inline void GEMVForLstmCellWithSymmetricRange(
}
#endif
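
The two GEMV kernels above specialize the case where a quantized fully-connected node has a single batch, so the GEMM collapses to a matrix*vector product with int32 accumulation and a 16-bit requantized output. A rough scalar sketch of that computation follows; the function name is hypothetical, the input zero-point handling done by the callers is ignored, and a plain float rescale stands in for the fixed-point accum_multiplier/accum_shift arithmetic of the real NEON code.

#include <algorithm>
#include <cstdint>
#include <vector>

// Illustration only: a generic quantized GEMV of the kind the fast paths above
// specialize. Not the TFLite kernel; requantization is simplified to a float
// scale.
std::vector<int16_t> QuantizedGemvSketch(const std::vector<uint8_t>& input,
                                         const std::vector<uint8_t>& weights,
                                         int weights_zero_point,
                                         const std::vector<int32_t>& bias,
                                         float accum_scale) {
  const int output_size = static_cast<int>(bias.size());
  const int input_size = static_cast<int>(input.size());
  std::vector<int16_t> output(output_size);
  for (int i = 0; i < output_size; ++i) {
    int32_t acc = bias[i];
    for (int j = 0; j < input_size; ++j) {
      acc += (static_cast<int32_t>(weights[i * input_size + j]) -
              weights_zero_point) *
             static_cast<int32_t>(input[j]);
    }
    const int32_t rescaled = static_cast<int32_t>(acc * accum_scale);
    output[i] =
        static_cast<int16_t>(std::min(32767, std::max(-32768, rescaled)));
  }
  return output;
}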
-inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
- const float* weights_data,
- const Dims<4>& weights_dims, const float* bias_data,
- const Dims<4>& bias_dims,
- float output_activation_min,
- float output_activation_max, float* output_data,
- const Dims<4>& output_dims) {
+inline void FullyConnected(
+ const FullyConnectedParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& weights_shape,
+ const float* weights_data, const RuntimeShape& bias_shape,
+ const float* bias_data, const RuntimeShape& output_shape,
+ float* output_data) {
gemmlowp::ScopedProfilingLabel label("FullyConnected");
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
+
// TODO(b/62193649): this convoluted shape computation (determining
// input_rows from the weights_dims, then MapAsMatrixWithGivenNumberOfRows)
// is because the current --variable_batch hack consists in overwriting the
@@ -870,18 +905,38 @@ inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
// When that is fixed, this should become:
// const auto input_matrix_map =
// MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
- const int input_rows = ArraySize(weights_dims, 0);
+ const int dims_count = weights_shape.DimensionsCount();
+ const int input_rows = weights_shape.Dims(dims_count - 1);
const auto input_matrix_map =
- MapAsMatrixWithGivenNumberOfRows(input_data, input_dims, input_rows);
+ MapAsMatrixWithGivenNumberOfRows(input_data, input_shape, input_rows);
const auto filter_matrix_map =
- MapAsMatrixWithFirstDimAsRows(weights_data, weights_dims);
+ MapAsMatrixWithLastDimAsRows(weights_data, weights_shape);
auto output_matrix_map =
- MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+ MapAsMatrixWithLastDimAsRows(output_data, output_shape);
Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map);
- AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data,
- output_dims, output_activation_min,
- output_activation_max);
+ AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
+ bias_shape, bias_data, output_shape,
+ output_data);
+}
+
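+
+The TODO above is the reason MapAsMatrixWithGivenNumberOfRows exists: under the --variable_batch hack one input dimension is overwritten with the runtime batch, and nothing records which one, so the row count is taken from the weights instead. A minimal sketch of that arithmetic, with a hypothetical helper name and plain ints rather than RuntimeShape:
+
+#include <cassert>
+
+// Illustration only: the input's flat size is trustworthy, but which dimension
+// holds the runtime batch is not tracked, so rows are taken from the weights'
+// accumulation depth and the effective batch falls out as the quotient.
+inline int EffectiveBatch(int input_flat_size, int weights_accum_depth) {
+  assert(input_flat_size % weights_accum_depth == 0);
+  return input_flat_size / weights_accum_depth;
+}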
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
+ const float* weights_data,
+ const Dims<4>& weights_dims, const float* bias_data,
+ const Dims<4>& bias_dims,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::FullyConnectedParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ FullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(weights_dims), weights_data,
+ DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data);
}
// legacy, for compatibility with old checked-in code
@@ -899,20 +954,23 @@ void FullyConnected(const float* input_data, const Dims<4>& input_dims,
#ifdef USE_NEON
inline void FullyConnectedAsGEMV(
- const uint8* input_data, const Dims<4>& input_dims, int32 input_offset,
- const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims, int32 output_offset,
+ const RuntimeShape& input_shape, const uint8* input_data,
+ int32 input_offset, const RuntimeShape& filter_shape,
+ const uint8* filter_data, int32 filter_offset,
+ const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset,
int32 output_multiplier, int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
+ int32 output_activation_max, const RuntimeShape& output_shape,
+ uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("FullyConnectedAsGEMV/8bit");
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(bias_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
- TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_dims, 0), 1);
- const int input_size = FlatSizeSkipDim(input_dims, 3);
- const int output_size = MatchingArraySize(filter_dims, 1, output_dims, 0);
+ TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
+ TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
+ TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
+ const int output_dim_count = output_shape.DimensionsCount();
+ const int filter_dim_count = filter_shape.DimensionsCount();
+ TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_shape, output_dim_count - 1), 1);
+ const int input_size = FlatSizeSkipDim(input_shape, 0);
+ const int output_size = MatchingDim(filter_shape, filter_dim_count - 2,
+ output_shape, output_dim_count - 1);
static constexpr int kPeel = 4;
const bool shift_left = (output_shift <= 0);
for (int k = 0; k < input_size; k += 64) {
@@ -1083,42 +1141,47 @@ struct GemmlowpOutputPipeline {
}
};
-inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims,
- gemmlowp::GemmContext* gemm_context) {
+inline void FullyConnected(
+ const FullyConnectedParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& filter_shape,
+ const uint8* filter_data, const RuntimeShape& bias_shape,
+ const int32* bias_data, const RuntimeShape& output_shape,
+ uint8* output_data, gemmlowp::GemmContext* gemm_context) {
gemmlowp::ScopedProfilingLabel label("FullyConnected/8bit");
+ const int32 input_offset = params.input_offset;
+ const int32 filter_offset = params.weights_offset;
+ const int32 output_offset = params.output_offset;
+ const int32 output_multiplier = params.output_multiplier;
+ const int output_shift = params.output_shift;
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
+ TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
+ TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
// TODO(benoitjacob): This really should be:
// const int batches = ArraySize(output_dims, 1);
// but the current --variable_batch hack consists in overwriting the 3rd
// dimension with the runtime batch size, as we don't keep track for each
// array of which dimension is the batch dimension in it.
- const int batches = FlatSizeSkipDim(output_dims, 0);
+ const int output_dim_count = output_shape.DimensionsCount();
+ const int filter_dim_count = filter_shape.DimensionsCount();
+ const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
#ifdef USE_NEON
- const int output_size = MatchingArraySize(filter_dims, 1, output_dims, 0);
+ const int output_size = MatchingDim(filter_shape, filter_dim_count - 2,
+ output_shape, output_dim_count - 1);
if (batches == 1 && !(output_size % 4)) {
return FullyConnectedAsGEMV(
- input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max, output_data,
- output_dims);
+ input_shape, input_data, input_offset, filter_shape, filter_data,
+ filter_offset, bias_shape, bias_data, output_offset, output_multiplier,
+ output_shift, output_activation_min, output_activation_max,
+ output_shape, output_data);
}
#endif // USE_NEON
- const int filter_rows = filter_dims.sizes[1];
- const int filter_cols = filter_dims.sizes[0];
- TFLITE_DCHECK_EQ(filter_dims.sizes[2], 1);
- TFLITE_DCHECK_EQ(filter_dims.sizes[3], 1);
- const int output_rows = output_dims.sizes[0];
+ const int filter_rows = filter_shape.Dims(filter_dim_count - 2);
+ const int filter_cols = filter_shape.Dims(filter_dim_count - 1);
+ TFLITE_DCHECK_EQ(filter_shape.FlatSize(), filter_rows * filter_cols);
+ const int output_rows = output_shape.Dims(output_dim_count - 1);
TFLITE_DCHECK_EQ(output_rows, filter_rows);
- TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows);
- TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1);
- TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1);
- TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1);
+ TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
filter_data, output_rows, filter_cols, filter_cols);
@@ -1135,30 +1198,65 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
input_offset, output_pipeline);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::FullyConnectedParams op_params;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ FullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(output_dims), output_data,
+ gemm_context);
+}
+
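+
+For reference, the 8-bit GEMM above (and the batches == 1 GEMV fast path before it) computes the standard quantized fully-connected formula. The scalar sketch below shows it for a single batch; the name is hypothetical and a float output_scale stands in for the fixed-point output_multiplier/output_shift pipeline used by the real code.
+
+#include <algorithm>
+#include <cstdint>
+#include <vector>
+
+// Illustration only: one batch of quantized fully-connected arithmetic.
+std::vector<uint8_t> QuantizedFullyConnectedSketch(
+    const std::vector<uint8_t>& input, int32_t input_offset,
+    const std::vector<uint8_t>& filter, int32_t filter_offset,
+    const std::vector<int32_t>& bias, int accum_depth, int output_depth,
+    float output_scale, int32_t output_offset, int32_t act_min,
+    int32_t act_max) {
+  std::vector<uint8_t> output(output_depth);
+  for (int out = 0; out < output_depth; ++out) {
+    int32_t acc = 0;
+    for (int d = 0; d < accum_depth; ++d) {
+      acc += (static_cast<int32_t>(filter[out * accum_depth + d]) +
+              filter_offset) *
+             (static_cast<int32_t>(input[d]) + input_offset);
+    }
+    acc += bias[out];
+    int32_t scaled = static_cast<int32_t>(acc * output_scale) + output_offset;
+    scaled = std::max(act_min, std::min(act_max, scaled));
+    output[out] = static_cast<uint8_t>(scaled);
+  }
+  return output;
+}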
inline void FullyConnected(
- const uint8* input_data, const Dims<4>& input_dims, int32 input_offset,
- const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data_int32, const Dims<4>& bias_dims, int32 output_offset,
- int32 output_multiplier, int output_shift, int32 output_activation_min,
- int32 output_activation_max, int16* output_data, const Dims<4>& output_dims,
- gemmlowp::GemmContext* gemm_context) {
+ const FullyConnectedParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& filter_shape,
+ const uint8* filter_data, const RuntimeShape& bias_shape,
+ const int32* bias_data_int32, const RuntimeShape& output_shape,
+ int16* output_data, gemmlowp::GemmContext* gemm_context) {
gemmlowp::ScopedProfilingLabel label("FullyConnected/Uint8Int16");
+ const int32 input_offset = params.input_offset;
+ const int32 filter_offset = params.weights_offset;
+ const int32 output_offset = params.output_offset;
+ const int32 output_multiplier = params.output_multiplier;
+ const int output_shift = params.output_shift;
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
// This is a copy of the reference implementation. We do not currently have a
// properly optimized version.
(void)gemm_context; // only used in properly optimized code.
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
TFLITE_DCHECK_EQ(output_offset, 0);
+ TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
+ TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
// TODO(benoitjacob): This really should be:
// const int batches = ArraySize(output_dims, 1);
// but the current --variable_batch hack consists in overwriting the 3rd
// dimension with the runtime batch size, as we don't keep track for each
// array of which dimension is the batch dimension in it.
- const int batches = FlatSizeSkipDim(output_dims, 0);
- const int output_depth = MatchingArraySize(filter_dims, 1, output_dims, 0);
- const int accum_depth = ArraySize(filter_dims, 0);
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
+ const int output_dim_count = output_shape.DimensionsCount();
+ const int filter_dim_count = filter_shape.DimensionsCount();
+ const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
+ const int output_depth = MatchingDim(filter_shape, filter_dim_count - 2,
+ output_shape, output_dim_count - 1);
+ const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
// Implementation of the fully connected node suited to the inside of an LSTM
// cell. The operands are 8-bit integers, the accumulators are internally
@@ -1169,17 +1267,17 @@ inline void FullyConnected(
if (batches == 1 && input_offset == -128 && output_activation_min == -32768 &&
output_activation_max == 32767) {
if (filter_offset == -128 && !(output_depth % 4) && !(accum_depth % 64)) {
- GEMVForLstmCellWithSymmetricRange(input_data, input_dims, filter_data,
- filter_dims, bias_data_int32, bias_dims,
- output_multiplier, -output_shift,
- output_data, output_dims);
+ GEMVForLstmCellWithSymmetricRange(
+ input_shape, input_data, filter_shape, filter_data, bias_shape,
+ bias_data_int32, output_multiplier, -output_shift, output_shape,
+ output_data);
return;
}
if (!(output_depth % 4) && !(accum_depth % 8)) {
- GEMVForLstmCell(input_data, input_dims, filter_data, filter_dims,
- filter_offset, bias_data_int32, bias_dims,
- output_multiplier, -output_shift, output_data,
- output_dims);
+ GEMVForLstmCell(input_shape, input_data, filter_shape, filter_data,
+ filter_offset, bias_shape, bias_data_int32,
+ output_multiplier, -output_shift, output_shape,
+ output_data);
return;
}
}
@@ -1213,6 +1311,31 @@ inline void FullyConnected(
input_offset, output_pipeline);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void FullyConnected(
+ const uint8* input_data, const Dims<4>& input_dims, int32 input_offset,
+ const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data_int32, const Dims<4>& bias_dims, int32 output_offset,
+ int32 output_multiplier, int output_shift, int32 output_activation_min,
+ int32 output_activation_max, int16* output_data, const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::FullyConnectedParams op_params;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ FullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data_int32, DimsToShape(output_dims), output_data,
+ gemm_context);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
@@ -1555,26 +1678,34 @@ struct ShuffledFullyConnectedWorkerTask : gemmlowp::Task {
};
inline void ShuffledFullyConnected(
- const uint8* input_data, const Dims<4>& input_dims,
- const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
- const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
- int output_shift, int32 output_activation_min, int32 output_activation_max,
- int16* output_data, const Dims<4>& output_dims,
- uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) {
+ const FullyConnectedParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& weights_shape,
+ const uint8* shuffled_weights_data, const RuntimeShape& bias_shape,
+ const int32* bias_data, const RuntimeShape& output_shape,
+ int16* output_data, uint8* shuffled_input_workspace_data,
+ gemmlowp::GemmContext* gemm_context) {
gemmlowp::ScopedProfilingLabel label("ShuffledFullyConnected/8bit");
+ const int32 output_multiplier = params.output_multiplier;
+ const int output_shift = params.output_shift;
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
(void)gemm_context; // only used in optimized code.
TFLITE_DCHECK_EQ(output_activation_min, -32768);
TFLITE_DCHECK_EQ(output_activation_max, 32767);
+ TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
+ TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
+ TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
// TODO(benoitjacob): This really should be:
// const int batches = ArraySize(output_dims, 1);
// but the current --variable_batch hack consists in overwriting the 3rd
// dimension with the runtime batch size, as we don't keep track for each
// array of which dimension is the batch dimension in it.
- const int batches = FlatSizeSkipDim(output_dims, 0);
- const int output_depth = MatchingArraySize(weights_dims, 1, output_dims, 0);
- const int accum_depth = ArraySize(weights_dims, 0);
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims));
+ const int output_dim_count = output_shape.DimensionsCount();
+ const int weights_dim_count = weights_shape.DimensionsCount();
+ const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
+ const int output_depth = MatchingDim(weights_shape, weights_dim_count - 2,
+ output_shape, output_dim_count - 1);
+ const int accum_depth = weights_shape.Dims(weights_dim_count - 1);
TFLITE_DCHECK((accum_depth % 16) == 0);
TFLITE_DCHECK((output_depth % 4) == 0);
// Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
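
As a side note on the comment above, flipping the sign bit is what turns an asymmetric uint8 weight into the signed value w - 128, so the shuffled kernel can work in signed arithmetic. A one-line sketch (hypothetical helper, two's-complement representation assumed):

#include <cstdint>

// Illustration only: xor-ing 0x80 reinterprets a uint8 weight as (w - 128).
inline int8_t FlipSignBitSketch(uint8_t w) {
  return static_cast<int8_t>(w ^ 0x80);  // e.g. 0 -> -128, 128 -> 0, 255 -> 127
}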
@@ -1671,13 +1802,39 @@ inline void ShuffledFullyConnected(
gemm_context->workers_pool()->Execute(tasks);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void ShuffledFullyConnected(
+ const uint8* input_data, const Dims<4>& input_dims,
+ const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
+ const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
+ int output_shift, int32 output_activation_min, int32 output_activation_max,
+ int16* output_data, const Dims<4>& output_dims,
+ uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) {
+ tflite::FullyConnectedParams op_params;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ ShuffledFullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(weights_dims), shuffled_weights_data,
+ DimsToShape(bias_dims), bias_data,
+ DimsToShape(output_dims), output_data,
+ shuffled_input_workspace_data, gemm_context);
+}
+
template <typename T>
-inline void ExtractPatchIntoBufferColumn(
- const Dims<4>& input_dims, int w, int h, int b, int kheight, int kwidth,
- int stride_width, int stride_height, int pad_width, int pad_height,
- int in_width, int in_height, int in_depth, int single_buffer_length,
- int buffer_id, const T* in_data, T* conv_buffer_data, uint8 byte_zero) {
+inline void ExtractPatchIntoBufferColumn(const RuntimeShape& input_shape, int w,
+ int h, int b, int kheight, int kwidth,
+ int stride_width, int stride_height,
+ int pad_width, int pad_height,
+ int in_width, int in_height,
+ int in_depth, int single_buffer_length,
+ int buffer_id, const T* in_data,
+ T* conv_buffer_data, uint8 zero_byte) {
gemmlowp::ScopedProfilingLabel label("ExtractPatchIntoBufferColumn");
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
// This chunk of code reshapes all the inputs corresponding to
// output (b, h, w) to a column vector in conv_buffer(:, buffer_id).
const int kwidth_times_indepth = kwidth * in_depth;
@@ -1699,7 +1856,7 @@ inline void ExtractPatchIntoBufferColumn(
const int output_row_offset = (buffer_id * single_buffer_length);
int out_offset =
output_row_offset + (h_offset * kwidth + w_offset) * in_depth;
- int in_offset = Offset(input_dims, 0, iw_start, ih_start, b);
+ int in_offset = Offset(input_shape, b, ih_start, iw_start, 0);
// Express all of the calculations as padding around the input patch.
const int top_padding = h_offset;
@@ -1713,7 +1870,7 @@ inline void ExtractPatchIntoBufferColumn(
// patch that are off the edge of the input image.
if (top_padding > 0) {
const int top_row_elements = (top_padding * kwidth * in_depth);
- memset(conv_buffer_data + output_row_offset, byte_zero,
+ memset(conv_buffer_data + output_row_offset, zero_byte,
(top_row_elements * sizeof(T)));
}
@@ -1730,14 +1887,14 @@ inline void ExtractPatchIntoBufferColumn(
for (int ih = ih_start; ih < ih_end; ++ih) {
if (left_padding > 0) {
const int left_start = (out_offset - (left_padding * in_depth));
- memset(conv_buffer_data + left_start, byte_zero,
+ memset(conv_buffer_data + left_start, zero_byte,
(left_padding * in_depth * sizeof(T)));
}
memcpy(conv_buffer_data + out_offset, in_data + in_offset,
single_row_num * sizeof(T));
if (right_padding > 0) {
const int right_start = (out_offset + single_row_num);
- memset(conv_buffer_data + right_start, byte_zero,
+ memset(conv_buffer_data + right_start, zero_byte,
(right_padding * in_depth * sizeof(T)));
}
out_offset += kwidth_times_indepth;
@@ -1752,61 +1909,64 @@ inline void ExtractPatchIntoBufferColumn(
const int bottom_start =
output_row_offset +
((top_padding + (ih_end - ih_start)) * kwidth * in_depth);
- memset(conv_buffer_data + bottom_start, byte_zero,
+ memset(conv_buffer_data + bottom_start, zero_byte,
(bottom_row_elements * sizeof(T)));
}
}
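
The Offset() change inside ExtractPatchIntoBufferColumn above is representative of the whole migration: legacy Dims<4> calls list coordinates innermost-first (c, x, y, b), while RuntimeShape calls list them outermost-first (b, y, x, c), and both resolve to the same element of a packed NHWC buffer. A minimal sketch with a hypothetical helper, not tflite::Offset:

// Illustration only: flat index of element (b, y, x, c) in a packed NHWC
// buffer of shape [batches, height, width, channels].
inline int OffsetNHWCSketch(int height, int width, int channels, int b, int y,
                            int x, int c) {
  return ((b * height + y) * width + x) * channels + c;
}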
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
template <typename T>
-void DilatedIm2col(const T* input_data, const Dims<4>& input_dims,
- const Dims<4>& filter_dims, int stride_width,
- int stride_height, int dilation_width_factor,
- int dilation_height_factor, int pad_width, int pad_height,
- const Dims<4>& output_dims, uint8 byte_zero,
- T* im2col_data) {
+inline void ExtractPatchIntoBufferColumn(
+ const Dims<4>& input_dims, int w, int h, int b, int kheight, int kwidth,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int in_width, int in_height, int in_depth, int single_buffer_length,
+ int buffer_id, const T* in_data, T* conv_buffer_data, uint8 zero_byte) {
+ ExtractPatchIntoBufferColumn(
+ DimsToShape(input_dims), w, h, b, kheight, kwidth, stride_width,
+ stride_height, pad_width, pad_height, in_width, in_height, in_depth,
+ single_buffer_length, buffer_id, in_data, conv_buffer_data, zero_byte);
+}
+
+template <typename T>
+void DilatedIm2col(const ConvParams& params, uint8 zero_byte,
+ const RuntimeShape& input_shape, const T* input_data,
+ const RuntimeShape& filter_shape,
+ const RuntimeShape& output_shape, T* im2col_data) {
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
// For dilated convolution, the input pixels are not contiguous, therefore we
// can't use the same optimizations as Im2Col(). Though note this code would
// work fine for the non-dilated case too (though likely a bit slower).
gemmlowp::ScopedProfilingLabel label("DilatedIm2col");
TFLITE_DCHECK(dilation_width_factor != 1 || dilation_height_factor != 1);
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
TFLITE_DCHECK(im2col_data);
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- MatchingArraySize(output_dims, 0, filter_dims, 3);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ MatchingDim(output_shape, 3, filter_shape, 0);
// Construct the MxN sized im2col matrix.
// The rows M, are sub-ordered B x H x W
- Dims<4> row_dims;
- row_dims.sizes[0] = output_width;
- row_dims.sizes[1] = output_height;
- row_dims.sizes[2] = batches;
- row_dims.sizes[3] = 1;
- ComputeStrides(&row_dims);
-
+ const RuntimeShape row_shape({1, batches, output_height, output_width});
// The columns, N, are sub-ordered Kh x Kw x Din
- Dims<4> col_dims;
- col_dims.sizes[0] = input_depth;
- col_dims.sizes[1] = filter_width;
- col_dims.sizes[2] = filter_height;
- col_dims.sizes[3] = 1;
- ComputeStrides(&col_dims);
-
+ const RuntimeShape col_shape({1, filter_height, filter_width, input_depth});
// Use dimensions M and N to construct dims for indexing directly into im2col
- Dims<4> im2col_dims;
- im2col_dims.sizes[0] = FlatSize(col_dims);
- im2col_dims.sizes[1] = FlatSize(row_dims);
- im2col_dims.sizes[2] = 1;
- im2col_dims.sizes[3] = 1;
- ComputeStrides(&im2col_dims);
+ const RuntimeShape im2col_shape(
+ {1, 1, row_shape.FlatSize(), col_shape.FlatSize()});
// Loop through the output rows (B x H x W)
for (int batch = 0; batch < batches; ++batch) {
@@ -1814,7 +1974,7 @@ void DilatedIm2col(const T* input_data, const Dims<4>& input_dims,
for (int out_x = 0; out_x < output_width; ++out_x) {
// Each im2col row is an output pixel. Arrange the input data in this
// row in an order we can conveniently multiply with the filter data.
- int row_offset = Offset(row_dims, out_x, out_y, batch, 0);
+ int row_offset = Offset(row_shape, 0, batch, out_y, out_x);
const int in_x_origin = (out_x * stride_width) - pad_width;
const int in_y_origin = (out_y * stride_height) - pad_height;
// Loop through all the pixels of the filter (Kh x Kw)
@@ -1825,25 +1985,25 @@ void DilatedIm2col(const T* input_data, const Dims<4>& input_dims,
// Loop through all the filter pixels in this row.
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
const int in_x = in_x_origin + dilation_width_factor * filter_x;
- int col_offset = Offset(col_dims, 0, filter_x, filter_y, 0);
+ int col_offset = Offset(col_shape, 0, filter_y, filter_x, 0);
T* dst = im2col_data +
- Offset(im2col_dims, col_offset, row_offset, 0, 0);
+ Offset(im2col_shape, 0, 0, row_offset, col_offset);
if ((in_x >= 0) && (in_x < input_width)) {
// Filter pixel is within the input, copy the input data.
T const* src =
- input_data + Offset(input_dims, 0, in_x, in_y, batch);
+ input_data + Offset(input_shape, batch, in_y, in_x, 0);
memcpy(dst, src, input_depth * sizeof(T));
} else {
// Filter pixel is outside the input, zero it out.
- memset(dst, byte_zero, input_depth * sizeof(T));
+ memset(dst, zero_byte, input_depth * sizeof(T));
}
}
} else {
// Filter row is outside the input, zero out the entire filter row.
- int col_offset = Offset(col_dims, 0, 0, filter_y, 0);
- T* dst =
- im2col_data + Offset(im2col_dims, col_offset, row_offset, 0, 0);
- memset(dst, byte_zero, filter_width * input_depth * sizeof(T));
+ int col_offset = Offset(col_shape, 0, filter_y, 0, 0);
+ T* dst = im2col_data +
+ Offset(im2col_shape, 0, 0, row_offset, col_offset);
+ memset(dst, zero_byte, filter_width * input_depth * sizeof(T));
}
}
}
@@ -1851,21 +2011,49 @@ void DilatedIm2col(const T* input_data, const Dims<4>& input_dims,
}
}
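
DilatedIm2col above builds the im2col matrix whose layout the comments describe: one row per output pixel (B x H_out x W_out) and one column per filter tap (Kh x Kw x C_in). A small sketch of that bookkeeping with plain ints and hypothetical names, rather than RuntimeShape:

// Illustration only: the M x N size of the im2col matrix.
struct Im2colDimsSketch {
  int rows;  // B * H_out * W_out
  int cols;  // Kh * Kw * C_in
};
inline Im2colDimsSketch Im2colSizeSketch(int batches, int out_height,
                                         int out_width, int filter_height,
                                         int filter_width, int in_depth) {
  return {batches * out_height * out_width,
          filter_height * filter_width * in_depth};
}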
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
template <typename T>
-void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width,
- int stride_height, int pad_width, int pad_height, int kheight,
- int kwidth, uint8 byte_zero, T* output_data,
- const Dims<4>& output_dims) {
+void DilatedIm2col(const T* input_data, const Dims<4>& input_dims,
+ const Dims<4>& filter_dims, int stride_width,
+ int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ const Dims<4>& output_dims, uint8 zero_byte,
+ T* im2col_data) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+
+ DilatedIm2col(op_params, zero_byte, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), DimsToShape(output_dims),
+ im2col_data);
+}
+
+template <typename T>
+void Im2col(const ConvParams& params, int kheight, int kwidth, uint8 zero_byte,
+ const RuntimeShape& input_shape, const T* input_data,
+ const RuntimeShape& output_shape, T* output_data) {
gemmlowp::ScopedProfilingLabel label("Im2col");
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_depth = ArraySize(input_dims, 0);
- const int input_width = ArraySize(input_dims, 1);
- const int input_height = ArraySize(input_dims, 2);
- const int output_depth = ArraySize(output_dims, 0);
- const int output_width = ArraySize(output_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_depth = input_shape.Dims(3);
+ const int input_width = input_shape.Dims(2);
+ const int input_height = input_shape.Dims(1);
+ const int output_depth = output_shape.Dims(3);
+ const int output_width = output_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
int buffer_id = 0;
// Loop over the output nodes.
@@ -1873,93 +2061,155 @@ void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width,
for (int h = 0; h < output_height; ++h) {
for (int w = 0; w < output_width; ++w) {
ExtractPatchIntoBufferColumn(
- input_dims, w, h, b, kheight, kwidth, stride_width, stride_height,
+ input_shape, w, h, b, kheight, kwidth, stride_width, stride_height,
pad_width, pad_height, input_width, input_height, input_depth,
- output_depth, buffer_id, input_data, output_data, byte_zero);
+ output_depth, buffer_id, input_data, output_data, zero_byte);
++buffer_id;
}
}
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <typename T>
+void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width,
+ int stride_height, int pad_width, int pad_height, int kheight,
+ int kwidth, uint8 zero_byte, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = 1;
+ op_params.dilation_height_factor = 1;
+
+ Im2col(op_params, kheight, kwidth, zero_byte, DimsToShape(input_dims),
+ input_data, DimsToShape(output_dims), output_data);
+}
+
// legacy, for compatibility with old checked-in code
template <typename T>
void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
int pad_width, int pad_height, int kheight, int kwidth,
- uint8 byte_zero, T* output_data, const Dims<4>& output_dims) {
+ uint8 zero_byte, T* output_data, const Dims<4>& output_dims) {
Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight,
- kwidth, byte_zero, output_data, output_dims);
+ kwidth, zero_byte, output_data, output_dims);
}
-inline void Conv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int dilation_width_factor,
- int dilation_height_factor, int pad_width, int pad_height,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims,
- float* im2col_data, const Dims<4>& im2col_dims) {
+inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& filter_shape,
+ const float* filter_data, const RuntimeShape& bias_shape,
+ const float* bias_data, const RuntimeShape& output_shape,
+ float* output_data, const RuntimeShape& im2col_shape,
+ float* im2col_data) {
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
(void)im2col_data;
- (void)im2col_dims;
+ (void)im2col_shape;
gemmlowp::ScopedProfilingLabel label("Conv");
// NB: static_cast<float>(0x00000000) == 0.0f
const uint8 float_zero_byte = 0x00;
const float* gemm_input_data = nullptr;
- const Dims<4>* gemm_input_dims = nullptr;
- const int filter_width = ArraySize(filter_dims, 1);
- const int filter_height = ArraySize(filter_dims, 2);
+ const RuntimeShape* gemm_input_shape = nullptr;
+ const int filter_width = filter_shape.Dims(2);
+ const int filter_height = filter_shape.Dims(1);
const bool need_dilated_im2col =
dilation_width_factor != 1 || dilation_height_factor != 1;
const bool need_im2col = stride_width != 1 || stride_height != 1 ||
filter_width != 1 || filter_height != 1;
if (need_dilated_im2col) {
- DilatedIm2col(input_data, input_dims, filter_dims, stride_width,
- stride_height, dilation_width_factor, dilation_height_factor,
- pad_width, pad_height, output_dims, float_zero_byte,
- im2col_data);
+ DilatedIm2col(params, float_zero_byte, input_shape, input_data,
+ filter_shape, output_shape, im2col_data);
gemm_input_data = im2col_data;
- gemm_input_dims = &im2col_dims;
+ gemm_input_shape = &im2col_shape;
} else if (need_im2col) {
TFLITE_DCHECK(im2col_data);
- Im2col(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_height, filter_width, float_zero_byte,
- im2col_data, im2col_dims);
+ Im2col(params, filter_height, filter_width, float_zero_byte, input_shape,
+ input_data, im2col_shape, im2col_data);
gemm_input_data = im2col_data;
- gemm_input_dims = &im2col_dims;
+ gemm_input_shape = &im2col_shape;
} else {
// TODO(aselle): We need to make sure to not send im2col if it is not
// needed.
TFLITE_DCHECK(!im2col_data);
gemm_input_data = input_data;
- gemm_input_dims = &input_dims;
+ gemm_input_shape = &input_shape;
}
const auto im2col_matrix_map =
- MapAsMatrixWithFirstDimAsRows(gemm_input_data, *gemm_input_dims);
+ MapAsMatrixWithLastDimAsRows(gemm_input_data, *gemm_input_shape);
const auto filter_matrix_map =
- MapAsMatrixWithLastDimAsCols(filter_data, filter_dims);
+ MapAsMatrixWithFirstDimAsCols(filter_data, filter_shape);
auto output_matrix_map =
- MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+ MapAsMatrixWithLastDimAsRows(output_data, output_shape);
Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map);
- AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data,
- output_dims, output_activation_min,
- output_activation_max);
+ AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
+ bias_shape, bias_data, output_shape,
+ output_data);
}
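
The float Conv above only materializes an im2col buffer when the convolution cannot be handed to the GEMM directly, and picks the dilated variant when any dilation factor differs from 1. A sketch of that dispatch condition (hypothetical helper, plain ints):

// Illustration only: mirrors the need_dilated_im2col / need_im2col logic.
inline bool NeedsIm2colSketch(int stride_width, int stride_height,
                              int dilation_width, int dilation_height,
                              int filter_width, int filter_height) {
  const bool dilated = dilation_width != 1 || dilation_height != 1;
  const bool spatial = stride_width != 1 || stride_height != 1 ||
                       filter_width != 1 || filter_height != 1;
  return dilated || spatial;
}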
-inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims,
- const int8_t* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, float* scaling_factors_ptr,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims,
- int8_t* im2col_data, const Dims<4>& im2col_dims) {
- const int batch_size = input_dims.sizes[3];
- const int filter_width = ArraySize(filter_dims, 1);
- const int filter_height = ArraySize(filter_dims, 2);
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims,
+ float* im2col_data, const Dims<4>& im2col_dims) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
+ filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data);
+}
+
+inline void HybridConv(const ConvParams& params, float* scaling_factors_ptr,
+ const RuntimeShape& input_shape,
+ const int8_t* input_data,
+ const RuntimeShape& filter_shape,
+ const int8_t* filter_data,
+ const RuntimeShape& bias_shape, const float* bias_data,
+ const RuntimeShape& output_shape, float* output_data,
+ const RuntimeShape& im2col_shape, int8_t* im2col_data) {
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(im2col_shape.DimensionsCount(), 4);
+
+ const int batch_size = input_shape.Dims(0);
+ const int filter_width = filter_shape.Dims(2);
+ const int filter_height = filter_shape.Dims(1);
const int8_t* gemm_input_data = nullptr;
int num_input;
@@ -1970,25 +2220,22 @@ inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims,
TFLITE_DCHECK(im2col_data);
// symmetric quantization assumes zero point of 0.
const int input_zero_point = 0;
- Im2col(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_height, filter_width, input_zero_point,
- im2col_data, im2col_dims);
+
+ Im2col(params, filter_height, filter_width, input_zero_point, input_shape,
+ input_data, im2col_shape, im2col_data);
gemm_input_data = im2col_data;
- num_input = im2col_dims.sizes[0] * im2col_dims.sizes[1] *
- im2col_dims.sizes[2] * im2col_dims.sizes[3];
+ num_input = im2col_shape.FlatSize();
} else {
TFLITE_DCHECK(!im2col_data);
gemm_input_data = input_data;
- num_input = input_dims.sizes[0] * input_dims.sizes[1] *
- input_dims.sizes[2] * input_dims.sizes[3];
+ num_input = input_shape.FlatSize();
}
// Flatten 4D matrices into 2D matrices for matrix multiplication.
// Flatten so that each filter has its own row.
- const int filter_rows = filter_dims.sizes[3];
- const int filter_cols =
- filter_dims.sizes[0] * filter_dims.sizes[1] * filter_dims.sizes[2];
+ const int filter_rows = filter_shape.Dims(0);
+ const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
// In MatrixBatchVectorMultiplyAccumulate, each output value is the
// dot product of one row of the first matrix with one row of the second
@@ -1998,15 +2245,14 @@ inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims,
const int gemm_input_cols = filter_cols;
const int gemm_input_rows = num_input / gemm_input_cols;
- const int output_cols = output_dims.sizes[0];
- const int output_rows =
- output_dims.sizes[1] * output_dims.sizes[2] * output_dims.sizes[3];
+ const int output_cols = output_shape.Dims(3);
+ const int output_rows = FlatSizeSkipDim(output_shape, 3);
TFLITE_DCHECK_EQ(output_cols, filter_rows);
TFLITE_DCHECK_EQ(output_rows, gemm_input_rows);
- TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_cols);
- TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1);
- TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1);
- TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1);
+ TFLITE_DCHECK_EQ(bias_shape.Dims(3), output_cols);
+ TFLITE_DCHECK_EQ(bias_shape.Dims(2), 1);
+ TFLITE_DCHECK_EQ(bias_shape.Dims(1), 1);
+ TFLITE_DCHECK_EQ(bias_shape.Dims(0), 1);
// MatrixBatchVectorMultiplyAccumulate assumes that each row of the second
// input matrix has its own scale factor. This code duplicates the scale
@@ -2023,11 +2269,39 @@ inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims,
scaling_factors_ptr, /*n_batch=*/gemm_input_rows, output_data,
/*result_stride=*/1);
- AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data,
- output_dims, output_activation_min,
- output_activation_max);
+ AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
+ bias_shape, bias_data, output_shape,
+ output_data);
+}
+
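+
+The comments inside HybridConv describe the dot-product convention of the matrix/batch-vector multiply it relies on: every output value is the dot product of one filter row with one row of the row-major input-patch matrix, scaled by that row's per-batch scaling factor. A scalar sketch of that convention with a hypothetical helper name (not the real tensor_utils routine):
+
+#include <cstddef>
+#include <cstdint>
+#include <vector>
+
+// Illustration only: row-by-row dot products with per-row scaling, written out
+// naively.
+void RowDotRowAccumulateSketch(const std::vector<int8_t>& filters,
+                               int filter_rows, int filter_cols,
+                               const std::vector<int8_t>& patches,
+                               int num_patch_rows,
+                               const std::vector<float>& row_scales,
+                               std::vector<float>* output) {
+  output->assign(static_cast<std::size_t>(num_patch_rows) * filter_rows, 0.0f);
+  for (int r = 0; r < num_patch_rows; ++r) {
+    for (int f = 0; f < filter_rows; ++f) {
+      int32_t acc = 0;
+      for (int k = 0; k < filter_cols; ++k) {
+        acc += static_cast<int32_t>(filters[f * filter_cols + k]) *
+               static_cast<int32_t>(patches[r * filter_cols + k]);
+      }
+      (*output)[static_cast<std::size_t>(r) * filter_rows + f] +=
+          row_scales[r] * acc;
+    }
+  }
+}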
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims,
+ const int8_t* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, float* scaling_factors_ptr,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims,
+ int8_t* im2col_data, const Dims<4>& im2col_dims) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ HybridConv(op_params, scaling_factors_ptr, DimsToShape(input_dims),
+ input_data, DimsToShape(filter_dims), filter_data,
+ DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
template <FusedActivationFunctionType Ac>
void Conv(const float* input_data, const Dims<4>& input_dims,
const float* filter_data, const Dims<4>& filter_dims,
@@ -2045,6 +2319,7 @@ void Conv(const float* input_data, const Dims<4>& input_dims,
im2col_dims);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void Conv(const float* input_data, const Dims<4>& input_dims,
@@ -2061,6 +2336,7 @@ void Conv(const float* input_data, const Dims<4>& input_dims,
im2col_data, im2col_dims);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void Conv(const float* input_data, const Dims<4>& input_dims,
@@ -2074,27 +2350,33 @@ void Conv(const float* input_data, const Dims<4>& input_dims,
output_dims, im2col_data, im2col_dims);
}
-inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int dilation_width_factor,
- int dilation_height_factor, int pad_width, int pad_height,
- int32 output_offset, int32 output_multiplier, int output_shift,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims,
- uint8* im2col_data, const Dims<4>& im2col_dims,
- gemmlowp::GemmContext* gemm_context) {
+inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& filter_shape,
+ const uint8* filter_data, const RuntimeShape& bias_shape,
+ const int32* bias_data, const RuntimeShape& output_shape,
+ uint8* output_data, const RuntimeShape& im2col_shape,
+ uint8* im2col_data, gemmlowp::GemmContext* gemm_context) {
gemmlowp::ScopedProfilingLabel label("Conv/8bit");
-
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ const int32 input_offset = params.input_offset;
+ const int32 filter_offset = params.weights_offset;
+ const int32 output_offset = params.output_offset;
+ const int32 output_multiplier = params.output_multiplier;
+ const int output_shift = params.output_shift;
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(im2col_shape.DimensionsCount(), 4);
const uint8* gemm_input_data = nullptr;
- const Dims<4>* gemm_input_dims = nullptr;
- const int filter_width = ArraySize(filter_dims, 1);
- const int filter_height = ArraySize(filter_dims, 2);
+ const RuntimeShape* gemm_input_shape = nullptr;
+ const int filter_width = filter_shape.Dims(2);
+ const int filter_height = filter_shape.Dims(1);
const bool need_dilated_im2col =
dilation_width_factor != 1 || dilation_height_factor != 1;
const bool need_im2col = stride_width != 1 || stride_height != 1 ||
@@ -2104,53 +2386,47 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
const int input_zero_point = -input_offset;
TFLITE_DCHECK_GE(input_zero_point, 0);
TFLITE_DCHECK_LE(input_zero_point, 255);
- DilatedIm2col(input_data, input_dims, filter_dims, stride_width,
- stride_height, dilation_width_factor, dilation_height_factor,
- pad_width, pad_height, output_dims, input_zero_point,
- im2col_data);
+ DilatedIm2col(params, input_zero_point, input_shape, input_data,
+ filter_shape, output_shape, im2col_data);
gemm_input_data = im2col_data;
- gemm_input_dims = &im2col_dims;
+ gemm_input_shape = &im2col_shape;
} else if (need_im2col) {
TFLITE_DCHECK(im2col_data);
const int input_zero_point = -input_offset;
TFLITE_DCHECK_GE(input_zero_point, 0);
TFLITE_DCHECK_LE(input_zero_point, 255);
- Im2col(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_height, filter_width, input_zero_point,
- im2col_data, im2col_dims);
+ Im2col(params, filter_height, filter_width, input_zero_point, input_shape,
+ input_data, im2col_shape, im2col_data);
gemm_input_data = im2col_data;
- gemm_input_dims = &im2col_dims;
+ gemm_input_shape = &im2col_shape;
} else {
TFLITE_DCHECK(!im2col_data);
gemm_input_data = input_data;
- gemm_input_dims = &input_dims;
+ gemm_input_shape = &input_shape;
}
- const int gemm_input_rows = gemm_input_dims->sizes[0];
+ const int gemm_input_rows = gemm_input_shape->Dims(3);
// Using FlatSizeSkipDim causes segfault in some contexts (see b/79927784).
// The root cause has not yet been identified though. Same applies below for
// the other calls commented out. This is a partial rollback of cl/196819423.
- // const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_dims, 0);
- const int gemm_input_cols = gemm_input_dims->sizes[1] *
- gemm_input_dims->sizes[2] *
- gemm_input_dims->sizes[3];
- const int filter_rows = filter_dims.sizes[3];
+ // const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_shape, 3);
+ const int gemm_input_cols = gemm_input_shape->Dims(0) *
+ gemm_input_shape->Dims(1) *
+ gemm_input_shape->Dims(2);
+ const int filter_rows = filter_shape.Dims(0);
// See b/79927784.
- // const int filter_cols = FlatSizeSkipDim(filter_dims, 3);
+ // const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
const int filter_cols =
- filter_dims.sizes[0] * filter_dims.sizes[1] * filter_dims.sizes[2];
- const int output_rows = output_dims.sizes[0];
+ filter_shape.Dims(1) * filter_shape.Dims(2) * filter_shape.Dims(3);
+ const int output_rows = output_shape.Dims(3);
// See b/79927784.
- // const int output_cols = FlatSizeSkipDim(output_dims, 0);
+ // const int output_cols = FlatSizeSkipDim(output_shape, 3);
const int output_cols =
- output_dims.sizes[1] * output_dims.sizes[2] * output_dims.sizes[3];
+ output_shape.Dims(0) * output_shape.Dims(1) * output_shape.Dims(2);
TFLITE_DCHECK_EQ(output_rows, filter_rows);
TFLITE_DCHECK_EQ(output_cols, gemm_input_cols);
TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows);
- TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows);
- TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1);
- TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1);
- TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1);
+ TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
filter_data, filter_rows, filter_cols);
gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
@@ -2166,6 +2442,43 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
input_offset, output_pipeline);
}
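
Before handing the matrices to gemmlowp, the 8-bit Conv above flattens the 4-D shapes into the 2-D GEMM operands that the DCHECKs cross-check. A sketch of that shape bookkeeping (plain ints, illustration only):

// Illustration only: the GEMM operand sizes for conv-as-im2col and the
// consistency conditions the DCHECKs above enforce.
struct ConvGemmShapesSketch {
  int filter_rows, filter_cols;  // output_depth, Kh * Kw * C_in
  int input_rows, input_cols;    // Kh * Kw * C_in, B * H_out * W_out
  int output_rows, output_cols;  // output_depth, B * H_out * W_out
  bool Consistent() const {
    return filter_cols == input_rows && output_rows == filter_rows &&
           output_cols == input_cols;
  }
};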
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ int32 output_offset, int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims,
+ uint8* im2col_data, const Dims<4>& im2col_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
+ filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data, gemm_context);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
int32 input_offset, const uint8* filter_data,
const Dims<4>& filter_dims, int32 filter_offset,
@@ -2184,6 +2497,7 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
im2col_data, im2col_dims, gemm_context);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
@@ -2213,6 +2527,7 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
im2col_data, im2col_dims, gemm_context);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void Conv(const uint8* input_data, const Dims<4>& input_dims,
@@ -2236,13 +2551,14 @@ void Conv(const uint8* input_data, const Dims<4>& input_dims,
im2col_data, im2col_dims, gemm_context);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac, typename T>
void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
int pad_width, int pad_height, int kheight, int kwidth,
- uint8 byte_zero, T* output_data, const Dims<4>& output_dims) {
+ uint8 zero_byte, T* output_data, const Dims<4>& output_dims) {
Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight,
- kwidth, byte_zero, output_data, output_dims);
+ kwidth, zero_byte, output_data, output_dims);
}
// legacy, for compatibility with old checked-in code
@@ -2266,6 +2582,7 @@ void ConvAsGemm(const float* input_data, const Dims<4>& input_dims,
output_dims);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims,
@@ -2320,9 +2637,9 @@ inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params,
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape input_shape =
+ const RuntimeShape input_shape =
RuntimeShape::ExtendedShape(4, unextended_input_shape);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
const int input_depth = input_shape.Dims(3);
@@ -2361,9 +2678,9 @@ inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params,
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape input_shape =
+ const RuntimeShape input_shape =
RuntimeShape::ExtendedShape(4, unextended_input_shape);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
const int output_depth = output_shape.Dims(3);
@@ -3191,7 +3508,7 @@ void BroadcastDiv4DSlow(const ArithmeticParams& params,
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
NdArrayDesc<4> desc1;
@@ -3458,10 +3775,11 @@ void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
bool gemm_already_performed = false;
#ifdef GEMMLOWP_NEON
if (fc_batches == 1 && !(fc_output_depth % 4) && !(fc_accum_depth % 8)) {
- GEMVForLstmCell(concat_temp_data_uint8, concat_temp_dims,
- weights_data_uint8, weights_dims, weights_zero_point,
- bias_data_int32, bias_dims, accum_multiplier, accum_shift,
- activ_temp_data_int16, activ_temp_dims);
+ GEMVForLstmCell(DimsToShape(concat_temp_dims), concat_temp_data_uint8,
+ DimsToShape(weights_dims), weights_data_uint8,
+ weights_zero_point, DimsToShape(bias_dims), bias_data_int32,
+ accum_multiplier, accum_shift, DimsToShape(activ_temp_dims),
+ activ_temp_data_int16);
gemm_already_performed = true;
}
#endif
@@ -5442,9 +5760,9 @@ inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
gemmlowp::ScopedProfilingLabel label("ResizeBilinear");
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape input_shape =
+ const RuntimeShape input_shape =
RuntimeShape::ExtendedShape(4, unextended_input_shape);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
@@ -5491,9 +5809,9 @@ inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
gemmlowp::ScopedProfilingLabel label("ResizeBilinear");
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape input_shape =
+ const RuntimeShape input_shape =
RuntimeShape::ExtendedShape(4, unextended_input_shape);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
@@ -5552,9 +5870,9 @@ inline void BatchToSpaceND(
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape input1_shape =
+ const RuntimeShape input1_shape =
RuntimeShape::ExtendedShape(4, unextended_input1_shape);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
const int output_width = output_shape.Dims(2);
@@ -5638,8 +5956,10 @@ inline void PadImpl(const tflite::PadParams& op_params,
const P* pad_value_ptr, const RuntimeShape& output_shape,
T* output_data) {
gemmlowp::ScopedProfilingLabel label("Pad");
- RuntimeShape ext_input_shape = RuntimeShape::ExtendedShape(4, input_shape);
- RuntimeShape ext_output_shape = RuntimeShape::ExtendedShape(4, output_shape);
+ const RuntimeShape ext_input_shape =
+ RuntimeShape::ExtendedShape(4, input_shape);
+ const RuntimeShape ext_output_shape =
+ RuntimeShape::ExtendedShape(4, output_shape);
TFLITE_DCHECK_LE(op_params.left_padding_count, 4);
TFLITE_DCHECK_LE(op_params.right_padding_count, 4);
@@ -5771,7 +6091,7 @@ inline void Slice(const tflite::SliceParams& op_params,
const RuntimeShape& input_shape, const T* input_data,
const RuntimeShape& output_shape, T* output_data) {
gemmlowp::ScopedProfilingLabel label("Slice");
- RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape);
+ const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape);
// TODO(dkalenichenko): This op only supports 4D tensors or smaller.
TFLITE_DCHECK_LE(op_params.begin_count, 4);
TFLITE_DCHECK_LE(op_params.size_count, 4);
@@ -5832,58 +6152,45 @@ void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
}
template <typename T>
-void TransposeIm2col(const T* input_data, const Dims<4>& input_dims,
- const Dims<4>& filter_dims, int stride_width,
- int stride_height, int pad_width, int pad_height,
- const Dims<4>& output_dims, uint8 zero_byte,
- T* im2col_data) {
+void TransposeIm2col(const ConvParams& params, uint8 zero_byte,
+ const RuntimeShape& input_shape, const T* input_data,
+ const RuntimeShape& filter_shape,
+ const RuntimeShape& output_shape, T* im2col_data) {
gemmlowp::ScopedProfilingLabel label("TransposeIm2col");
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
TFLITE_DCHECK(im2col_data);
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 3);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- MatchingArraySize(output_dims, 0, filter_dims, 0); // output_depth
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int input_depth = MatchingDim(input_shape, 3, filter_shape, 0);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ MatchingDim(output_shape, 3, filter_shape, 3); // output_depth
// Construct the MxN sized im2col matrix.
// The rows, M, are sub-ordered B x H x W
- Dims<4> row_dims;
- row_dims.sizes[0] = output_width;
- row_dims.sizes[1] = output_height;
- row_dims.sizes[2] = batches;
- row_dims.sizes[3] = 1;
- ComputeStrides(&row_dims);
-
+ const RuntimeShape row_shape({1, batches, output_height, output_width});
// The columns, N, are sub-ordered Kh x Kw x Din
- Dims<4> col_dims;
- col_dims.sizes[0] = input_depth;
- col_dims.sizes[1] = filter_width;
- col_dims.sizes[2] = filter_height;
- col_dims.sizes[3] = 1;
- ComputeStrides(&col_dims);
-
+ const RuntimeShape col_shape({1, filter_height, filter_width, input_depth});
// Use dimensions M and N to construct dims for indexing directly into im2col
- Dims<4> im2col_dims;
- im2col_dims.sizes[0] = FlatSize(col_dims);
- im2col_dims.sizes[1] = FlatSize(row_dims);
- im2col_dims.sizes[2] = 1;
- im2col_dims.sizes[3] = 1;
- ComputeStrides(&im2col_dims);
+ const RuntimeShape im2col_shape(
+ {1, 1, row_shape.FlatSize(), col_shape.FlatSize()});
// Build the im2col matrix by looping through all the input pixels,
// computing their influence on the output, rather than looping through all
// the output pixels. We therefore must initialize the im2col array to zero.
// This is potentially inefficient because we subsequently overwrite bytes
// set here. However, in practice memset is very fast and its cost is negligible.
- memset(im2col_data, zero_byte, FlatSize(im2col_dims) * sizeof(T));
+ memset(im2col_data, zero_byte, im2col_shape.FlatSize() * sizeof(T));
// Loop through the output batches
for (int batch = 0; batch < batches; ++batch) {
@@ -5903,11 +6210,11 @@ void TransposeIm2col(const T* input_data, const Dims<4>& input_dims,
if ((out_x >= 0) && (out_x < output_width)) {
// Copy the input elements of this pixel
T const* src =
- input_data + Offset(input_dims, 0, in_x, in_y, batch);
+ input_data + Offset(input_shape, batch, in_y, in_x, 0);
+ int row_offset = Offset(row_shape, 0, batch, out_y, out_x);
+ int col_offset = Offset(col_shape, 0, filter_y, filter_x, 0);
T* dst = im2col_data +
- Offset(im2col_dims,
- Offset(col_dims, 0, filter_x, filter_y, 0),
- Offset(row_dims, out_x, out_y, batch, 0), 0, 0);
+ Offset(im2col_shape, 0, 0, row_offset, col_offset);
memcpy(dst, src, input_depth * sizeof(T));
}
}
@@ -5918,31 +6225,71 @@ void TransposeIm2col(const T* input_data, const Dims<4>& input_dims,
}
}
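The row_shape/col_shape/im2col_shape trick above works because Offset on a RuntimeShape is a plain row-major linearization, so a (row, col) pair can be folded directly into the flat im2col index. A sketch of the assumed Offset semantics, shown only for illustration (the real function is defined with the other shape helpers).

// Sketch of the assumed row-major Offset used above for 4-D RuntimeShapes.
inline int OffsetSketch(const RuntimeShape& shape, int i0, int i1, int i2,
                        int i3) {
  return ((i0 * shape.Dims(1) + i1) * shape.Dims(2) + i2) * shape.Dims(3) + i3;
}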
-inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, float* output_data,
- const Dims<4>& output_dims, float* im2col_data,
- const Dims<4>& im2col_dims) {
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <typename T>
+void TransposeIm2col(const T* input_data, const Dims<4>& input_dims,
+ const Dims<4>& filter_dims, int stride_width,
+ int stride_height, int pad_width, int pad_height,
+ const Dims<4>& output_dims, uint8 zero_byte,
+ T* im2col_data) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+
+ TransposeIm2col(op_params, zero_byte, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), DimsToShape(output_dims),
+ im2col_data);
+}
+
+inline void TransposeConv(
+ const ConvParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& filter_shape,
+ const float* filter_data, const RuntimeShape& output_shape,
+ float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) {
gemmlowp::ScopedProfilingLabel label("TransposeConv");
// Note we could use transposed weights with forward conv for unstrided
// cases. But we are already getting good performance with this code as-is.
TFLITE_DCHECK(im2col_data);
- TransposeIm2col(input_data, input_dims, filter_dims, stride_width,
- stride_height, pad_width, pad_height, output_dims, 0,
- im2col_data);
+ TransposeIm2col(params, 0, input_shape, input_data, filter_shape,
+ output_shape, im2col_data);
const auto im2col_matrix_map =
- MapAsMatrixWithFirstDimAsRows(im2col_data, im2col_dims);
+ MapAsMatrixWithLastDimAsRows(im2col_data, im2col_shape);
const auto filter_matrix_map =
- MapAsMatrixWithLastDimAsCols(filter_data, filter_dims);
+ MapAsMatrixWithFirstDimAsCols(filter_data, filter_shape);
auto output_matrix_map =
- MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+ MapAsMatrixWithLastDimAsRows(output_data, output_shape);
Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map);
}
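Note the matrix-map helpers switch from the ...FirstDimAsRows/...LastDimAsCols variants to their RuntimeShape counterparts: the innermost dimension that used to come first in Dims<4> comes last in an NHWC RuntimeShape. A sketch of the row/column split those views are assumed to use, not the actual Eigen-map implementation.

// Sketch only: rows/cols implied by MapAsMatrixWithLastDimAsRows above,
// assuming it mirrors the old Dims<4> helper with the dimension order
// reversed.
inline void LastDimAsRowsDims(const RuntimeShape& shape, int* rows,
                              int* cols) {
  *rows = shape.Dims(shape.DimensionsCount() - 1);
  *cols = shape.FlatSize() / *rows;
}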
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+
+ TransposeConv(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data);
+}
+
} // namespace optimized_ops
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h
index 9aabee5000..bb5d590775 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h
@@ -25,8 +25,9 @@ namespace reference_ops {
inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
const float* filter_data, const Dims<4>& filter_dims,
const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int depth_multiplier,
+ int stride_width, int stride_height,
+ int dilation_width_factor, int dilation_height_factor,
+ int pad_width, int pad_height, int depth_multiplier,
float output_activation_min,
float output_activation_max, float* output_data,
const Dims<4>& output_dims) {
@@ -52,8 +53,9 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
float total = 0.f;
for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
- const int in_x = in_x_origin + filter_x;
- const int in_y = in_y_origin + filter_y;
+ const int in_x = in_x_origin + dilation_width_factor * filter_x;
+ const int in_y =
+ in_y_origin + dilation_height_factor * filter_y;
// If the location is outside the bounds of the input image,
// use zero as a default value.
if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
@@ -81,6 +83,20 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
}
}
+inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride_width, stride_height, 1, 1, pad_width,
+ pad_height, depth_multiplier, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
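The new dilation_width_factor/dilation_height_factor parameters stretch the filter's receptive field without adding weights: with the indexing above, a k-tap filter spans dilation * (k - 1) + 1 input pixels, and the non-dilated overload simply passes factors of 1 to preserve the old behaviour. A small illustration, not part of the patch.

// Illustration only: effective spatial extent of a dilated filter, matching
// the in_x/in_y indexing in the dilated DepthwiseConv above.
inline int EffectiveFilterExtent(int filter_size, int dilation_factor) {
  return dilation_factor * (filter_size - 1) + 1;
}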
// Legacy, for compatibility with old checked-in code.
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
index d57739279f..5e3e8997fc 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
@@ -30,8 +30,9 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
int32 input_offset, const uint8* filter_data,
const Dims<4>& filter_dims, int32 filter_offset,
const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int depth_multiplier,
+ int stride_width, int stride_height,
+ int dilation_width_factor, int dilation_height_factor,
+ int pad_width, int pad_height, int depth_multiplier,
int32 output_offset, int32 output_multiplier,
int output_shift, int32 output_activation_min,
int32 output_activation_max, uint8* output_data,
@@ -58,8 +59,9 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
int32 acc = 0;
for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
- const int in_x = in_x_origin + filter_x;
- const int in_y = in_y_origin + filter_y;
+ const int in_x = in_x_origin + dilation_width_factor * filter_x;
+ const int in_y =
+ in_y_origin + dilation_height_factor * filter_y;
// If the location is outside the bounds of the input image,
// use zero as a default value.
if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
@@ -90,6 +92,24 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
}
}
+inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width,
+ stride_height, 1, 1, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
// Legacy, for compatibility with old checked-in code.
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 0abacf85e1..66f18ec195 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -419,9 +419,9 @@ inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params,
T* output_data) {
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape input_shape =
+ const RuntimeShape input_shape =
RuntimeShape::ExtendedShape(4, unextended_input_shape);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
const int input_depth = input_shape.Dims(3);
@@ -472,9 +472,9 @@ inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params,
T* output_data) {
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape input_shape =
+ const RuntimeShape input_shape =
RuntimeShape::ExtendedShape(4, unextended_input_shape);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
const int input_depth = input_shape.Dims(3);
@@ -1117,7 +1117,7 @@ inline void BroadcastAdd4DSlow(const ArithmeticParams& params,
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
&desc2);
- RuntimeShape extended_output_shape =
+ const RuntimeShape extended_output_shape =
RuntimeShape::ExtendedShape(4, output_shape);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
@@ -1158,7 +1158,7 @@ inline void BroadcastAdd4DSlow(const ArithmeticParams& params,
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
&desc2);
- RuntimeShape extended_output_shape =
+ const RuntimeShape extended_output_shape =
RuntimeShape::ExtendedShape(4, output_shape);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
@@ -1200,7 +1200,7 @@ inline void BroadcastAdd4DSlow(const ArithmeticParams& params,
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
&desc2);
- RuntimeShape extended_output_shape =
+ const RuntimeShape extended_output_shape =
RuntimeShape::ExtendedShape(4, output_shape);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
@@ -1350,7 +1350,7 @@ void BroadcastMul4DSlow(const ArithmeticParams& params,
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
NdArrayDesc<4> desc1;
@@ -1483,7 +1483,7 @@ inline void BroadcastMul4DSlow(const ArithmeticParams& params,
// The input shapes are extended as part of NdArrayDesc initialization.
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
&desc2);
- RuntimeShape extended_output_shape =
+ const RuntimeShape extended_output_shape =
RuntimeShape::ExtendedShape(4, output_shape);
for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
@@ -1579,7 +1579,7 @@ void BroadcastDiv4DSlow(const ArithmeticParams& params,
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
NdArrayDesc<4> desc1;
@@ -1713,7 +1713,7 @@ inline void BroadcastSub4DSlow(const ArithmeticParams& params,
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
&desc2);
- RuntimeShape extended_output_shape =
+ const RuntimeShape extended_output_shape =
RuntimeShape::ExtendedShape(4, output_shape);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
@@ -1754,7 +1754,7 @@ inline void BroadcastSub4DSlow(const ArithmeticParams& params,
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
&desc2);
- RuntimeShape extended_output_shape =
+ const RuntimeShape extended_output_shape =
RuntimeShape::ExtendedShape(4, output_shape);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
@@ -1818,7 +1818,7 @@ inline void BroadcastSub4DSlow(const ArithmeticParams& params,
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
&desc2);
- RuntimeShape extended_output_shape =
+ const RuntimeShape extended_output_shape =
RuntimeShape::ExtendedShape(4, output_shape);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
@@ -1858,7 +1858,7 @@ void BroadcastSub4DSlow(const ArithmeticParams& params,
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
&desc2);
- RuntimeShape extended_output_shape =
+ const RuntimeShape extended_output_shape =
RuntimeShape::ExtendedShape(4, output_shape);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
@@ -1897,7 +1897,7 @@ void Sub(const ArithmeticParams& params, const RuntimeShape& input1_shape,
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
&desc2);
- RuntimeShape extended_output_shape =
+ const RuntimeShape extended_output_shape =
RuntimeShape::ExtendedShape(4, output_shape);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
@@ -3543,11 +3543,11 @@ inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_size_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape input_shape =
+ const RuntimeShape input_shape =
RuntimeShape::ExtendedShape(4, unextended_input_shape);
- RuntimeShape output_size_shape =
+ const RuntimeShape output_size_shape =
RuntimeShape::ExtendedShape(4, unextended_output_size_shape);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
@@ -3606,9 +3606,9 @@ inline void SpaceToBatchND(
const RuntimeShape& unextended_output_shape, T* output_data) {
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape input1_shape =
+ const RuntimeShape input1_shape =
RuntimeShape::ExtendedShape(4, unextended_input1_shape);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
const int depth = input1_shape.Dims(3);
@@ -3663,9 +3663,9 @@ inline void BatchToSpaceND(
const RuntimeShape& unextended_output_shape, T* output_data) {
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape input1_shape =
+ const RuntimeShape input1_shape =
RuntimeShape::ExtendedShape(4, unextended_input1_shape);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
const int output_width = output_shape.Dims(2);
@@ -3719,8 +3719,10 @@ inline void PadImpl(const tflite::PadParams& op_params,
const RuntimeShape& input_shape, const T* input_data,
const P* pad_value_ptr, const RuntimeShape& output_shape,
T* output_data) {
- RuntimeShape ext_input_shape = RuntimeShape::ExtendedShape(4, input_shape);
- RuntimeShape ext_output_shape = RuntimeShape::ExtendedShape(4, output_shape);
+ const RuntimeShape ext_input_shape =
+ RuntimeShape::ExtendedShape(4, input_shape);
+ const RuntimeShape ext_output_shape =
+ RuntimeShape::ExtendedShape(4, output_shape);
TFLITE_DCHECK_LE(op_params.left_padding_count, 4);
TFLITE_DCHECK_LE(op_params.right_padding_count, 4);
@@ -3817,9 +3819,9 @@ inline void StridedSlice(const tflite::StridedSliceParams& op_params,
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape input_shape =
+ const RuntimeShape input_shape =
RuntimeShape::ExtendedShape(4, unextended_input_shape);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
// Reverse and pad to 4 dimensions because that is what the runtime code
@@ -3915,7 +3917,7 @@ template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
const RuntimeShape& input_shape, const T* input_data,
const RuntimeShape& output_shape, T* output_data) {
- RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape);
+ const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape);
// TODO(dkalenichenko): This op only supports 4D tensors or smaller.
TFLITE_DCHECK_LE(op_params.begin_count, 4);
TFLITE_DCHECK_LE(op_params.size_count, 4);
@@ -4141,9 +4143,9 @@ inline void Mean(const tflite::MeanParams& op_params,
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape input_shape =
+ const RuntimeShape input_shape =
RuntimeShape::ExtendedShape(4, unextended_input_shape);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
const int output_batch = output_shape.Dims(0);
@@ -4290,7 +4292,7 @@ void MaximumMinimumBroadcast4DSlow(const RuntimeShape& unextended_input1_shape,
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
NdArrayDesc<4> desc1;
@@ -4577,7 +4579,7 @@ inline void BroadcastComparison4DSlowImpl(
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
NdArrayDesc<4> desc1;
@@ -4636,7 +4638,7 @@ inline void BroadcastComparison4DSlowWithScaling(
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
NdArrayDesc<4> desc1;
@@ -4877,16 +4879,22 @@ inline void Pow(const RuntimeShape& input1_shape, const T* input1_data,
}
template <typename T>
-inline void BroadcastPow4DSlow(const RuntimeShape& input1_shape,
+inline void BroadcastPow4DSlow(const RuntimeShape& unextended_input1_shape,
const T* input1_data,
- const RuntimeShape& input2_shape,
+ const RuntimeShape& unextended_input2_shape,
const T* input2_data,
- const RuntimeShape& output_shape,
+ const RuntimeShape& unextended_output_shape,
T* output_data) {
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ const RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
- &desc2);
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
for (int b = 0; b < output_shape.Dims(0); ++b) {
for (int y = 0; y < output_shape.Dims(1); ++y) {
@@ -4923,7 +4931,7 @@ inline void BroadcastLogical4DSlow(
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
NdArrayDesc<4> desc1;
@@ -4962,7 +4970,7 @@ inline void BroadcastBinaryFunction4DSlow(
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
NdArrayDesc<4> desc1;
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index c4c7cf3842..023707d466 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -26,8 +26,8 @@ enum class FusedActivationFunctionType : uint8 { kNone, kRelu6, kRelu1, kRelu };
enum class PaddingType : uint8 { kNone, kSame, kValid };
struct PaddingValues {
- int8 width;
- int8 height;
+ int16 width;
+ int16 height;
};
// This enumeration allows for non-default formats for the weights array
@@ -734,10 +734,10 @@ struct ConvParams {
PaddingType padding_type;
PaddingValues padding_values;
// TODO(starka): This was just "stride", so check that width+height is OK.
- int8 stride_width;
- int8 stride_height;
- int8 dilation_width_factor;
- int8 dilation_height_factor;
+ int16 stride_width;
+ int16 stride_height;
+ int16 dilation_width_factor;
+ int16 dilation_height_factor;
// uint8 inference params.
// TODO(b/65838351): Use smaller types if appropriate.
int32 input_offset;
@@ -745,8 +745,12 @@ struct ConvParams {
int32 output_offset;
int32 output_multiplier;
int output_shift;
- int32 output_activation_min;
- int32 output_activation_max;
+ // uint8, etc., activation params.
+ int32 quantized_activation_min;
+ int32 quantized_activation_max;
+ // float activation params.
+ float float_activation_min;
+ float float_activation_max;
};
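ConvParams now carries separate quantized and float activation bounds instead of a single output_activation_min/max pair; a kernel is expected to clamp with whichever pair matches its data type. The helpers below are hypothetical and only illustrate the intended use of the new fields (they require <algorithm>).

// Hypothetical helpers (not in the patch): float kernels clamp with the
// float_activation_* fields, uint8 kernels with quantized_activation_*.
inline float ClampFloatActivation(const ConvParams& params, float value) {
  return std::min(std::max(value, params.float_activation_min),
                  params.float_activation_max);
}
inline int32 ClampQuantizedActivation(const ConvParams& params, int32 value) {
  return std::min(std::max(value, params.quantized_activation_min),
                  params.quantized_activation_max);
}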
struct DepthToSpaceParams {
@@ -756,8 +760,8 @@ struct DepthToSpaceParams {
struct DepthwiseParams {
PaddingType padding_type;
PaddingValues padding_values;
- int8 stride;
- int8 depth_multiplier;
+ int16 stride;
+ int16 depth_multiplier;
// uint8 inference params.
// TODO(b/65838351): Use smaller types if appropriate.
int32 input_offset;
@@ -765,8 +769,12 @@ struct DepthwiseParams {
int32 output_offset;
int32 output_multiplier;
int output_shift;
- int32 output_activation_min;
- int32 output_activation_max;
+ // uint8, etc., activation params.
+ int32 quantized_activation_min;
+ int32 quantized_activation_max;
+ // float activation params.
+ float float_activation_min;
+ float float_activation_max;
};
struct DequantizationParams {
@@ -787,13 +795,17 @@ struct FullyConnectedParams {
int32 output_offset;
int32 output_multiplier;
int output_shift;
- int32 output_activation_min;
- int32 output_activation_max;
+ // uint8, etc., activation params.
+ int32 quantized_activation_min;
+ int32 quantized_activation_max;
+ // float activation params.
+ float float_activation_min;
+ float float_activation_max;
FullyConnectedWeightsFormat weights_format;
};
struct GatherParams {
- int8 input_rank;
+ int16 input_rank;
int16 axis;
};
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index c66959fdf4..14296d3a9f 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -118,6 +118,7 @@ TfLiteRegistration* Register_LOGICAL_AND();
TfLiteRegistration* Register_LOGICAL_NOT();
TfLiteRegistration* Register_UNPACK();
TfLiteRegistration* Register_FLOOR_DIV();
+TfLiteRegistration* Register_SQUARE();
TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
context->ReportError(
@@ -243,6 +244,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT());
AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK());
AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV());
+ AddBuiltin(BuiltinOperator_SQUARE, Register_SQUARE());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc
index 9156917140..0fdb0a3935 100644
--- a/tensorflow/contrib/lite/kernels/test_util.cc
+++ b/tensorflow/contrib/lite/kernels/test_util.cc
@@ -74,8 +74,8 @@ void SingleOpModel::SetCustomOp(
CustomOptionsFormat_FLEXBUFFERS));
}
-void SingleOpModel::BuildInterpreter(
- std::vector<std::vector<int>> input_shapes) {
+void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
+ bool allow_fp32_relax_to_fp16) {
auto opcodes = builder_.CreateVector(opcodes_);
auto operators = builder_.CreateVector(operators_);
auto tensors = builder_.CreateVector(tensors_);
@@ -113,6 +113,8 @@ void SingleOpModel::BuildInterpreter(
CHECK(interpreter_->ResizeInputTensor(input_idx, shape) == kTfLiteOk);
}
+ interpreter_->SetAllowFp16PrecisionForFp32(allow_fp32_relax_to_fp16);
+
// Modify delegate with function.
if (apply_delegate_fn_) {
apply_delegate_fn_(interpreter_.get());
diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h
index bedbe93ae6..84deb0e0e8 100644
--- a/tensorflow/contrib/lite/kernels/test_util.h
+++ b/tensorflow/contrib/lite/kernels/test_util.h
@@ -182,7 +182,8 @@ class SingleOpModel {
// Build the interpreter for this model. Also, resize and allocate all
// tensors given the shapes of the inputs.
- void BuildInterpreter(std::vector<std::vector<int>> input_shapes);
+ void BuildInterpreter(std::vector<std::vector<int>> input_shapes,
+ bool allow_fp32_relax_to_fp16 = false);
void Invoke();
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 241865b3d8..6311d60b91 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -177,6 +177,11 @@ TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() {
namespace {
template <class T>
std::vector<int> FlatBufferIntArrayToVector(T* flat_array) {
+ // Models constructed via flatbuffers::Pack convert empty vectors to
+ // nullptr, so treat a null flat_array as an empty tensor shape.
+ if (flat_array == nullptr) {
+ return {};
+ }
std::vector<int> ret(flat_array->Length());
for (int i = 0; i < flat_array->Length(); i++) {
ret[i] = flat_array->Get(i);
diff --git a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h
index 81dd459223..687944023b 100644
--- a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h
+++ b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h
@@ -364,6 +364,9 @@ typedef int (*ANeuralNetworksModel_identifyInputsAndOutputs_fn)(
ANeuralNetworksModel* model, uint32_t inputCount, const uint32_t* inputs,
uint32_t outputCount, const uint32_t* outputs);
+typedef int (*ANeuralNetworksModel_relaxComputationFloat32toFloat16_fn)(
+ ANeuralNetworksModel* model, bool allow);
+
typedef int (*ANeuralNetworksExecution_create_fn)(
ANeuralNetworksCompilation* compilation,
ANeuralNetworksExecution** execution);
@@ -656,6 +659,34 @@ inline int ANeuralNetworksModel_identifyInputsAndOutputs(
}
/**
+ * Specifies whether {@link ANEURALNETWORKS_TENSOR_FLOAT32} is allowed to be
+ * calculated with range and/or precision as low as that of the IEEE 754 16-bit
+ * floating-point format. By default, {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ * must be calculated using at least the range and precision of the IEEE 754
+ * 32-bit floating-point format.
+ *
+ * @param model The model to be modified.
+ * @param allow 'true' indicates {@link ANEURALNETWORKS_TENSOR_FLOAT32} may be
+ * calculated with range and/or precision as low as that of the
+ * IEEE 754 16-bit floating point format. 'false' indicates
+ * {@link ANEURALNETWORKS_TENSOR_FLOAT32} must be calculated using
+ * at least the range and precision of the IEEE 754 32-bit floating
+ * point format.
+ *
+ * Attempting to modify a model once {@link ANeuralNetworksModel_finish} has
+ * been called will return an error.
+ *
+ * Available since API level 28.
+ *
+ * See {@link ANeuralNetworksModel} for information on multithreaded usage.
+ */
+inline int ANeuralNetworksModel_relaxComputationFloat32toFloat16(
+ ANeuralNetworksModel* model, bool allow) {
+ LOAD_FUNCTION(ANeuralNetworksModel_relaxComputationFloat32toFloat16);
+ EXECUTE_FUNCTION_RETURN(model, allow);
+}
+
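For context, the intended call pattern is to request the relaxation after the model's operands and operations have been declared and before ANeuralNetworksModel_finish, exactly as the nnapi_delegate change further down does (guarded by API level >= 28). A minimal usage sketch with error handling elided; the function name and allow_fp16 parameter are illustrative.

// Usage sketch only: relax FP32 computation to FP16 just before finishing
// the model, mirroring the delegate change below.
inline void FinishModelWithOptionalFp16(ANeuralNetworksModel* model,
                                        bool allow_fp16) {
  ANeuralNetworksModel_relaxComputationFloat32toFloat16(model, allow_fp16);
  ANeuralNetworksModel_finish(model);
}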
+/**
* Create a {@link ANeuralNetworksCompilation} to compile the given model.
* This only creates the object. Compilation is only performed once
* {@link ANeuralNetworksCompilation_start} is invoked.
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index 817486e898..f814b90d66 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -672,6 +672,7 @@ TfLiteStatus AddOpsAndParams(
case tflite::BuiltinOperator_UNPACK:
case tflite::BuiltinOperator_FLOOR_DIV:
case tflite::BuiltinOperator_REDUCE_ANY:
+ case tflite::BuiltinOperator_SQUARE:
logError("Op code %d is currently not delegated to NNAPI", builtin);
return kTfLiteError;
break;
@@ -757,6 +758,11 @@ TfLiteStatus NNAPIDelegate::BuildGraph(Interpreter* interpreter) {
reinterpret_cast<const uint32_t*>(augmented_inputs.data()),
static_cast<uint32_t>(augmented_outputs.size()),
reinterpret_cast<const uint32_t*>(augmented_outputs.data())));
+
+ if (GetAndroidSdkVersionCached() >= 28) {
+ CHECK_NN(ANeuralNetworksModel_relaxComputationFloat32toFloat16(
+ nn_model_, interpreter->GetAllowFp16PrecisionForFp32()));
+ }
CHECK_NN(ANeuralNetworksModel_finish(nn_model_));
}
if (!nn_compiled_model_) {
diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py
index 1c5516ae7c..1f48a826d4 100644
--- a/tensorflow/contrib/lite/python/convert.py
+++ b/tensorflow/contrib/lite/python/convert.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import enum # pylint: disable=g-bad-import-order
+
import os as _os
import platform as _platform
import subprocess as _subprocess
@@ -30,7 +32,6 @@ from tensorflow.python.platform import resource_loader as _resource_loader
from tensorflow.python.util import deprecation
from tensorflow.python.util.lazy_loader import LazyLoader
-
# Lazy load since some of the performance benchmark skylark rules
# break dependencies.
_toco_python = LazyLoader(
@@ -52,6 +53,31 @@ if _toco_from_proto_bin and not _os.path.exists(_toco_from_proto_bin):
_toco_from_proto_bin = "toco_from_protos"
+class ConverterMode(enum.Enum):
+ """Enum class defining the converters available to generate TFLite models.
+
+ WARNING: Experimental interface, subject to change.
+ """
+ # Convert model using TOCO such that all ops are TensorFlow Lite native ops.
+ #
+ # This is the only supported mode for any models that contain operations that
+ # cannot be resolved in TensorFlow.
+ DEFAULT = "DEFAULT"
+
+ # Convert model using TOCO such that only unsupported operations are
+ # represented as TensorFlow ops.
+ # WARNING: Experimental interface, subject to change.
+ TOCO_EXTENDED = "TOCO_EXTENDED"
+
+ # Convert model using TOCO such that all operations are represented as
+ # TensorFlow ops.
+ # WARNING: Experimental interface, subject to change.
+ TOCO_EXTENDED_ALL = "TOCO_EXTENDED_ALL"
+
+ def __str__(self):
+ return self.value
+
+
def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
"""Convert `input_data_str` according to model and toco parameters.
@@ -128,7 +154,8 @@ def build_toco_convert_protos(input_tensors,
change_concat_input_ranges=False,
post_training_quantize=False,
dump_graphviz_dir=None,
- dump_graphviz_video=False):
+ dump_graphviz_video=False,
+ converter_mode=ConverterMode.DEFAULT):
"""Builds protocol buffers describing a conversion of a model using TOCO.
Typically this is to convert from TensorFlow GraphDef to TFLite, in which
@@ -183,6 +210,8 @@ def build_toco_convert_protos(input_tensors,
output file. (default None)
dump_graphviz_video: Boolean indicating whether to dump the graph after
every graph transformation. (default False)
+ converter_mode: Experimental flag, subject to change. ConverterMode
+ indicating which converter to use. (default ConverterMode.DEFAULT)
Returns:
model_flags, toco_flags: two protocol buffers describing the conversion
@@ -211,6 +240,11 @@ def build_toco_convert_protos(input_tensors,
if dump_graphviz_dir:
toco.dump_graphviz_dir = dump_graphviz_dir
toco.dump_graphviz_include_video = dump_graphviz_video
+ if converter_mode == ConverterMode.TOCO_EXTENDED:
+ toco.allow_eager_ops = True
+ elif converter_mode == ConverterMode.TOCO_EXTENDED_ALL:
+ toco.allow_eager_ops = True
+ toco.force_eager_ops = True
model = _model_flags_pb2.ModelFlags()
model.change_concat_input_ranges = change_concat_input_ranges
@@ -301,9 +335,8 @@ def toco_convert_impl(input_data, input_tensors, output_tensors, *args,
Raises:
Defined in `build_toco_convert_protos`.
"""
- model_flags, toco_flags = build_toco_convert_protos(input_tensors,
- output_tensors,
- *args, **kwargs)
+ model_flags, toco_flags = build_toco_convert_protos(
+ input_tensors, output_tensors, *args, **kwargs)
data = toco_convert_protos(model_flags.SerializeToString(),
toco_flags.SerializeToString(),
input_data.SerializeToString())
diff --git a/tensorflow/contrib/lite/python/convert_test.py b/tensorflow/contrib/lite/python/convert_test.py
index 59f537b82a..40a8b5fafb 100644
--- a/tensorflow/contrib/lite/python/convert_test.py
+++ b/tensorflow/contrib/lite/python/convert_test.py
@@ -188,7 +188,7 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase):
return output
output = array_ops.identity(_swish(image, swish_scale), name="ModelOutput")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# check if identities have been put into the graph (2 input, 1 output,
# and 1 final output).
self.assertEqual(self._countIdentities(sess.graph_def.node), 4)
@@ -215,7 +215,7 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase):
output = array_ops.identity(_scaled_and_bias_and_identity(a, x, b),
name="ModelOutput")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# make sure one identity for each input (3) and output (2) => 3 + 2 = 5
# +1 for the final output
self.assertEqual(self._countIdentities(sess.graph_def.node), 6)
@@ -242,7 +242,7 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase):
output = array_ops.identity(
math_ops.add(_double_values(a), _double_values(b)), name="ModelOutput")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# make sure one identity for each input (2) and output (2) => 2 + 2
# +1 for the final output
self.assertEqual(self._countIdentities(sess.graph_def.node), 5)
@@ -279,7 +279,7 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase):
aggregate=op_hint.OpHint.AGGREGATE_STACK)
res = math_ops.add(math_ops.mul(a, b), math_ops.mul(c, b))
custom.add_outputs([res])
- with self.test_session():
+ with self.cached_session():
self.assertEqual(self._get_input_index(a), 0)
self.assertEqual(self._get_sort_index(a), 0)
self.assertEqual(self._get_input_index(b), 1)
@@ -294,7 +294,7 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase):
b = custom.add_input(b) # should auto assign 0
a = custom.add_input(a, index_override=1)
c = custom.add_input(c) # should auto assign 2
- with self.test_session():
+ with self.cached_session():
self.assertEqual(self._get_input_index(a), 1)
self.assertEqual(self._get_input_index(b), 0)
self.assertEqual(self._get_input_index(c), 2)
@@ -320,7 +320,7 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase):
curr = array_ops.stack([c0, c1])
output = array_ops.identity(curr, name="FINAL_OUTPUT")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
graph_def=sess.graph_def)
self.assertCountEqual(
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index 44dfb97b84..2be24455d8 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -40,6 +40,7 @@ from google.protobuf import text_format as _text_format
from google.protobuf.message import DecodeError
from tensorflow.contrib.lite.python import lite_constants as constants
from tensorflow.contrib.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import
+from tensorflow.contrib.lite.python.convert import ConverterMode
from tensorflow.contrib.lite.python.convert import tensor_name as _tensor_name
from tensorflow.contrib.lite.python.convert import toco_convert # pylint: disable=unused-import
from tensorflow.contrib.lite.python.convert import toco_convert_graph_def as _toco_convert_graph_def
@@ -113,6 +114,8 @@ class TocoConverter(object):
output file. (default None)
dump_graphviz_video: Boolean indicating whether to dump the graph after
every graph transformation. (default False)
+ converter_mode: Experimental flag, subject to change. ConverterMode
+ indicating which converter to use. (default ConverterMode.DEFAULT)
Example usage:
@@ -179,6 +182,7 @@ class TocoConverter(object):
self.post_training_quantize = False
self.dump_graphviz_dir = None
self.dump_graphviz_video = False
+ self.converter_mode = ConverterMode.DEFAULT
# Attributes are used by models that cannot be loaded into TensorFlow.
if not self._has_valid_tensors():
@@ -389,6 +393,7 @@ class TocoConverter(object):
ValueError:
Input shape is not specified.
None value for dimension in input_tensor.
+ ConverterMode option is unsupported for the model.
"""
# Checks dimensions in input tensor.
if self._has_valid_tensors():
@@ -439,12 +444,18 @@ class TocoConverter(object):
# Converts model.
if self._has_valid_tensors():
+ converter_kwargs["converter_mode"] = self.converter_mode
result = _toco_convert_impl(
input_data=self._graph_def,
input_tensors=self._input_tensors,
output_tensors=self._output_tensors,
**converter_kwargs)
else:
+ # Graphs without valid tensors cannot be loaded into tf.Session since they
+ # contain TFLite operation(s) that cannot be resolved in TensorFlow.
+ if self.converter_mode != ConverterMode.DEFAULT:
+ raise ValueError("This model can only be converted with the default "
+ "converter.")
result = _toco_convert_graph_def(
input_data=self._graph_def,
input_arrays_with_shape=self._input_arrays_with_shape,
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py
index 3f8ea433ff..f112ed5cdd 100644
--- a/tensorflow/contrib/lite/python/lite_test.py
+++ b/tensorflow/contrib/lite/python/lite_test.py
@@ -402,6 +402,28 @@ class FromSessionTest(test_util.TensorFlowTestCase):
# Ensure that the quantized weights tflite model is smaller.
self.assertTrue(len(quantized_tflite) < len(float_tflite))
+ def testExtendedMode(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ out_tensor = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter.converter_mode = lite.ConverterMode.TOCO_EXTENDED_ALL
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Ensures the model contains TensorFlow ops.
+ # TODO(nupurgarg): Check values once there is a Python delegate interface.
+ interpreter = Interpreter(model_content=tflite_model)
+ with self.assertRaises(RuntimeError) as error:
+ interpreter.allocate_tensors()
+ self.assertIn(
+ 'Regular TensorFlow ops are not supported by this interpreter. Make '
+ 'sure you invoke the Eager delegate before inference.',
+ str(error.exception))
+
class FromFrozenGraphFile(test_util.TensorFlowTestCase):
diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py
index cc08ed3fe9..c0ff7f37f9 100644
--- a/tensorflow/contrib/lite/python/tflite_convert.py
+++ b/tensorflow/contrib/lite/python/tflite_convert.py
@@ -140,8 +140,11 @@ def _convert_model(flags):
if flags.change_concat_input_ranges:
converter.change_concat_input_ranges = (
flags.change_concat_input_ranges == "TRUE")
+
if flags.allow_custom_ops:
converter.allow_custom_ops = flags.allow_custom_ops
+ if flags.converter_mode:
+ converter.converter_mode = flags.converter_mode
if flags.post_training_quantize:
converter.post_training_quantize = flags.post_training_quantize
@@ -363,6 +366,8 @@ def run_main(_):
help=("Boolean to change behavior of min/max ranges for inputs and "
"outputs of the concat operator for quantized models. Changes the "
"ranges of concat operator overlap when true. (default False)"))
+
+ # Permitted ops flags.
parser.add_argument(
"--allow_custom_ops",
action="store_true",
@@ -371,6 +376,12 @@ def run_main(_):
"created for any op that is unknown. The developer will need to "
"provide these to the TensorFlow Lite runtime with a custom "
"resolver. (default False)"))
+ parser.add_argument(
+ "--converter_mode",
+ type=lite.ConverterMode,
+ choices=list(lite.ConverterMode),
+ help=("Experimental flag, subject to change. ConverterMode indicating "
+ "which converter to use. (default ConverterMode.DEFAULT)"))
# Logging flags.
parser.add_argument(
diff --git a/tensorflow/contrib/lite/schema/BUILD b/tensorflow/contrib/lite/schema/BUILD
index 28a7e50003..55bf2c48b9 100644
--- a/tensorflow/contrib/lite/schema/BUILD
+++ b/tensorflow/contrib/lite/schema/BUILD
@@ -56,6 +56,20 @@ flatbuffer_cc_library(
srcs = ["schema.fbs"],
)
+# Generic schema for inference on device (reflection support makes it bigger).
+flatbuffer_cc_library(
+ name = "schema_fbs_with_reflection",
+ srcs = ["schema.fbs"],
+ flatc_args = [
+ "--reflect-types",
+ "--reflect-names",
+ "--no-union-value-namespacing",
+ "--gen-object-api",
+ ],
+ gen_reflections = True,
+ out_prefix = "reflection/",
+)
+
# Schema test to make sure we don't introduce backward incompatible changes
# to schemas.
cc_test(
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index cf66403ec9..f0db22d581 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -173,6 +173,7 @@ enum BuiltinOperator : byte {
REDUCE_MIN = 89,
FLOOR_DIV = 90,
REDUCE_ANY = 91,
+ SQUARE = 92,
}
// Options for the builtin operators.
@@ -242,6 +243,7 @@ union BuiltinOptions {
LogicalNotOptions,
UnpackOptions,
FloorDivOptions,
+ SquareOptions,
}
enum Padding : byte { SAME, VALID }
@@ -274,11 +276,15 @@ table Pool2DOptions {
}
table DepthwiseConv2DOptions {
+ // Parameters for DepthwiseConv version 1 or above.
padding:Padding;
stride_w:int;
stride_h:int;
depth_multiplier:int;
fused_activation_function:ActivationFunctionType;
+ // Parameters for DepthwiseConv version 2 or above.
+ dilation_w_factor:int = 1;
+ dilation_h_factor:int = 1;
}
table ConcatEmbeddingsOptions {
@@ -579,6 +585,9 @@ table UnpackOptions {
table FloorDivOptions {
}
+table SquareOptions {
+}
+
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index 6d9630d75e..8c086a5e67 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -226,6 +226,9 @@ struct UnpackOptionsT;
struct FloorDivOptions;
struct FloorDivOptionsT;
+struct SquareOptions;
+struct SquareOptionsT;
+
struct OperatorCode;
struct OperatorCodeT;
@@ -383,11 +386,12 @@ enum BuiltinOperator {
BuiltinOperator_REDUCE_MIN = 89,
BuiltinOperator_FLOOR_DIV = 90,
BuiltinOperator_REDUCE_ANY = 91,
+ BuiltinOperator_SQUARE = 92,
BuiltinOperator_MIN = BuiltinOperator_ADD,
- BuiltinOperator_MAX = BuiltinOperator_REDUCE_ANY
+ BuiltinOperator_MAX = BuiltinOperator_SQUARE
};
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[91] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[92] {
static BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@@ -479,7 +483,8 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[91] {
BuiltinOperator_UNPACK,
BuiltinOperator_REDUCE_MIN,
BuiltinOperator_FLOOR_DIV,
- BuiltinOperator_REDUCE_ANY
+ BuiltinOperator_REDUCE_ANY,
+ BuiltinOperator_SQUARE
};
return values;
}
@@ -578,6 +583,7 @@ inline const char **EnumNamesBuiltinOperator() {
"REDUCE_MIN",
"FLOOR_DIV",
"REDUCE_ANY",
+ "SQUARE",
nullptr
};
return names;
@@ -655,11 +661,12 @@ enum BuiltinOptions {
BuiltinOptions_LogicalNotOptions = 63,
BuiltinOptions_UnpackOptions = 64,
BuiltinOptions_FloorDivOptions = 65,
+ BuiltinOptions_SquareOptions = 66,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_FloorDivOptions
+ BuiltinOptions_MAX = BuiltinOptions_SquareOptions
};
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[66] {
+inline BuiltinOptions (&EnumValuesBuiltinOptions())[67] {
static BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -726,7 +733,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[66] {
BuiltinOptions_LogicalAndOptions,
BuiltinOptions_LogicalNotOptions,
BuiltinOptions_UnpackOptions,
- BuiltinOptions_FloorDivOptions
+ BuiltinOptions_FloorDivOptions,
+ BuiltinOptions_SquareOptions
};
return values;
}
@@ -799,6 +807,7 @@ inline const char **EnumNamesBuiltinOptions() {
"LogicalNotOptions",
"UnpackOptions",
"FloorDivOptions",
+ "SquareOptions",
nullptr
};
return names;
@@ -1073,6 +1082,10 @@ template<> struct BuiltinOptionsTraits<FloorDivOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_FloorDivOptions;
};
+template<> struct BuiltinOptionsTraits<SquareOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_SquareOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -1624,6 +1637,14 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_FloorDivOptions ?
reinterpret_cast<const FloorDivOptionsT *>(value) : nullptr;
}
+ SquareOptionsT *AsSquareOptions() {
+ return type == BuiltinOptions_SquareOptions ?
+ reinterpret_cast<SquareOptionsT *>(value) : nullptr;
+ }
+ const SquareOptionsT *AsSquareOptions() const {
+ return type == BuiltinOptions_SquareOptions ?
+ reinterpret_cast<const SquareOptionsT *>(value) : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -2318,12 +2339,16 @@ struct DepthwiseConv2DOptionsT : public flatbuffers::NativeTable {
int32_t stride_h;
int32_t depth_multiplier;
ActivationFunctionType fused_activation_function;
+ int32_t dilation_w_factor;
+ int32_t dilation_h_factor;
DepthwiseConv2DOptionsT()
: padding(Padding_SAME),
stride_w(0),
stride_h(0),
depth_multiplier(0),
- fused_activation_function(ActivationFunctionType_NONE) {
+ fused_activation_function(ActivationFunctionType_NONE),
+ dilation_w_factor(1),
+ dilation_h_factor(1) {
}
};
@@ -2334,7 +2359,9 @@ struct DepthwiseConv2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab
VT_STRIDE_W = 6,
VT_STRIDE_H = 8,
VT_DEPTH_MULTIPLIER = 10,
- VT_FUSED_ACTIVATION_FUNCTION = 12
+ VT_FUSED_ACTIVATION_FUNCTION = 12,
+ VT_DILATION_W_FACTOR = 14,
+ VT_DILATION_H_FACTOR = 16
};
Padding padding() const {
return static_cast<Padding>(GetField<int8_t>(VT_PADDING, 0));
@@ -2351,6 +2378,12 @@ struct DepthwiseConv2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab
ActivationFunctionType fused_activation_function() const {
return static_cast<ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
}
+ int32_t dilation_w_factor() const {
+ return GetField<int32_t>(VT_DILATION_W_FACTOR, 1);
+ }
+ int32_t dilation_h_factor() const {
+ return GetField<int32_t>(VT_DILATION_H_FACTOR, 1);
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<int8_t>(verifier, VT_PADDING) &&
@@ -2358,6 +2391,8 @@ struct DepthwiseConv2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab
VerifyField<int32_t>(verifier, VT_STRIDE_H) &&
VerifyField<int32_t>(verifier, VT_DEPTH_MULTIPLIER) &&
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
+ VerifyField<int32_t>(verifier, VT_DILATION_W_FACTOR) &&
+ VerifyField<int32_t>(verifier, VT_DILATION_H_FACTOR) &&
verifier.EndTable();
}
DepthwiseConv2DOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -2383,6 +2418,12 @@ struct DepthwiseConv2DOptionsBuilder {
void add_fused_activation_function(ActivationFunctionType fused_activation_function) {
fbb_.AddElement<int8_t>(DepthwiseConv2DOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0);
}
+ void add_dilation_w_factor(int32_t dilation_w_factor) {
+ fbb_.AddElement<int32_t>(DepthwiseConv2DOptions::VT_DILATION_W_FACTOR, dilation_w_factor, 1);
+ }
+ void add_dilation_h_factor(int32_t dilation_h_factor) {
+ fbb_.AddElement<int32_t>(DepthwiseConv2DOptions::VT_DILATION_H_FACTOR, dilation_h_factor, 1);
+ }
explicit DepthwiseConv2DOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -2401,8 +2442,12 @@ inline flatbuffers::Offset<DepthwiseConv2DOptions> CreateDepthwiseConv2DOptions(
int32_t stride_w = 0,
int32_t stride_h = 0,
int32_t depth_multiplier = 0,
- ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) {
+ ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE,
+ int32_t dilation_w_factor = 1,
+ int32_t dilation_h_factor = 1) {
DepthwiseConv2DOptionsBuilder builder_(_fbb);
+ builder_.add_dilation_h_factor(dilation_h_factor);
+ builder_.add_dilation_w_factor(dilation_w_factor);
builder_.add_depth_multiplier(depth_multiplier);
builder_.add_stride_h(stride_h);
builder_.add_stride_w(stride_w);
@@ -5803,6 +5848,46 @@ inline flatbuffers::Offset<FloorDivOptions> CreateFloorDivOptions(
flatbuffers::Offset<FloorDivOptions> CreateFloorDivOptions(flatbuffers::FlatBufferBuilder &_fbb, const FloorDivOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct SquareOptionsT : public flatbuffers::NativeTable {
+ typedef SquareOptions TableType;
+ SquareOptionsT() {
+ }
+};
+
+struct SquareOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef SquareOptionsT NativeTableType;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ SquareOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(SquareOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<SquareOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const SquareOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct SquareOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit SquareOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ SquareOptionsBuilder &operator=(const SquareOptionsBuilder &);
+ flatbuffers::Offset<SquareOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<SquareOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<SquareOptions> CreateSquareOptions(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ SquareOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<SquareOptions> CreateSquareOptions(flatbuffers::FlatBufferBuilder &_fbb, const SquareOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@@ -6131,6 +6216,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const FloorDivOptions *builtin_options_as_FloorDivOptions() const {
return builtin_options_type() == BuiltinOptions_FloorDivOptions ? static_cast<const FloorDivOptions *>(builtin_options()) : nullptr;
}
+ const SquareOptions *builtin_options_as_SquareOptions() const {
+ return builtin_options_type() == BuiltinOptions_SquareOptions ? static_cast<const SquareOptions *>(builtin_options()) : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -6422,6 +6510,10 @@ template<> inline const FloorDivOptions *Operator::builtin_options_as<FloorDivOp
return builtin_options_as_FloorDivOptions();
}
+template<> inline const SquareOptions *Operator::builtin_options_as<SquareOptions>() const {
+ return builtin_options_as_SquareOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -6996,6 +7088,8 @@ inline void DepthwiseConv2DOptions::UnPackTo(DepthwiseConv2DOptionsT *_o, const
{ auto _e = stride_h(); _o->stride_h = _e; };
{ auto _e = depth_multiplier(); _o->depth_multiplier = _e; };
{ auto _e = fused_activation_function(); _o->fused_activation_function = _e; };
+ { auto _e = dilation_w_factor(); _o->dilation_w_factor = _e; };
+ { auto _e = dilation_h_factor(); _o->dilation_h_factor = _e; };
}
inline flatbuffers::Offset<DepthwiseConv2DOptions> DepthwiseConv2DOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DepthwiseConv2DOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -7011,13 +7105,17 @@ inline flatbuffers::Offset<DepthwiseConv2DOptions> CreateDepthwiseConv2DOptions(
auto _stride_h = _o->stride_h;
auto _depth_multiplier = _o->depth_multiplier;
auto _fused_activation_function = _o->fused_activation_function;
+ auto _dilation_w_factor = _o->dilation_w_factor;
+ auto _dilation_h_factor = _o->dilation_h_factor;
return tflite::CreateDepthwiseConv2DOptions(
_fbb,
_padding,
_stride_w,
_stride_h,
_depth_multiplier,
- _fused_activation_function);
+ _fused_activation_function,
+ _dilation_w_factor,
+ _dilation_h_factor);
}
inline ConcatEmbeddingsOptionsT *ConcatEmbeddingsOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@@ -8661,6 +8759,29 @@ inline flatbuffers::Offset<FloorDivOptions> CreateFloorDivOptions(flatbuffers::F
_fbb);
}
+inline SquareOptionsT *SquareOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new SquareOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void SquareOptions::UnPackTo(SquareOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<SquareOptions> SquareOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SquareOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateSquareOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<SquareOptions> CreateSquareOptions(flatbuffers::FlatBufferBuilder &_fbb, const SquareOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SquareOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return tflite::CreateSquareOptions(
+ _fbb);
+}
+
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@@ -9110,6 +9231,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const FloorDivOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_SquareOptions: {
+ auto ptr = reinterpret_cast<const SquareOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default: return false;
}
}
@@ -9388,6 +9513,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const FloorDivOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_SquareOptions: {
+ auto ptr = reinterpret_cast<const SquareOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default: return nullptr;
}
}
@@ -9654,6 +9783,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const FloorDivOptionsT *>(value);
return CreateFloorDivOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_SquareOptions: {
+ auto ptr = reinterpret_cast<const SquareOptionsT *>(value);
+ return CreateSquareOptions(_fbb, ptr, _rehasher).Union();
+ }
default: return 0;
}
}
@@ -9920,6 +10053,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new FloorDivOptionsT(*reinterpret_cast<FloorDivOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_SquareOptions: {
+ value = new SquareOptionsT(*reinterpret_cast<SquareOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
@@ -10252,6 +10389,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
+ case BuiltinOptions_SquareOptions: {
+ auto ptr = reinterpret_cast<SquareOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default: break;
}
value = nullptr;
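
The generated API above gives DepthwiseConv2DOptions two new optional fields that default to 1, so buffers serialized by older converters still read back as undilated convolutions. A minimal sketch of exercising the new fields through the generated header (assuming the header and flatbuffers are on the include path):

// Minimal sketch: build a DepthwiseConv2DOptions table with the new dilation
// fields and read it back. Fields omitted by older writers default to 1.
#include <cassert>

#include "flatbuffers/flatbuffers.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"

int main() {
  flatbuffers::FlatBufferBuilder fbb;
  auto options = tflite::CreateDepthwiseConv2DOptions(
      fbb, tflite::Padding_SAME, /*stride_w=*/1, /*stride_h=*/1,
      /*depth_multiplier=*/1, tflite::ActivationFunctionType_NONE,
      /*dilation_w_factor=*/2, /*dilation_h_factor=*/2);
  fbb.Finish(options);

  const auto* parsed = flatbuffers::GetRoot<tflite::DepthwiseConv2DOptions>(
      fbb.GetBufferPointer());
  assert(parsed->dilation_w_factor() == 2);
  assert(parsed->dilation_h_factor() == 2);
  return 0;
}
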
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index aad1ecaeb6..a4736bfee9 100644
--- a/tensorflow/contrib/lite/testing/BUILD
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -7,7 +7,7 @@ licenses(["notice"]) # Apache 2.0
load(
"//tensorflow/contrib/lite:build_def.bzl",
"gen_zip_test",
- "generated_test_models",
+ "generated_test_models_all",
)
load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite")
load(
@@ -29,6 +29,7 @@ load(
"--unzip_binary_path=/usr/bin/unzip",
],
}),
+ conversion_mode = conversion_mode,
data = [
":zip_%s" % test_name,
],
@@ -36,7 +37,7 @@ load(
tags = [
"gen_zip_test",
"no_oss",
- "tflite_not_portable",
+ "tflite_not_portable_intentional",
],
test_name = test_name,
deps = [
@@ -59,7 +60,7 @@ load(
"//tensorflow/core:android_tensorflow_test_lib",
],
}),
-) for test_name in generated_test_models()]
+) for conversion_mode, test_name in generated_test_models_all()]
test_suite(
name = "generated_zip_tests",
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 32f02a4f6c..3754b58b23 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -80,7 +80,10 @@ parser.add_argument(
"--save_graphdefs",
action="store_true",
help="Include intermediate graphdefs in the output zip files.")
-
+parser.add_argument(
+ "--run_with_extended",
+ action="store_true",
+ help="Whether the TFLite Extended converter is being used.")
RANDOM_SEED = 342
TEST_INPUT_DEPTH = 3
@@ -320,10 +323,11 @@ def toco_convert(graph_def_str, input_tensors, output_tensors,
output tflite model, log_txt from conversion
or None, log_txt if it did not convert properly.
"""
+ input_arrays = [x[0] for x in input_tensors]
data_types = [_TF_TYPE_INFO[x[2]][1] for x in input_tensors]
opts = toco_options(
data_types=data_types,
- input_arrays=[x[0] for x in input_tensors],
+ input_arrays=input_arrays,
shapes=[x[1] for x in input_tensors],
output_arrays=output_tensors,
extra_toco_options=extra_toco_options)
@@ -335,6 +339,11 @@ def toco_convert(graph_def_str, input_tensors, output_tensors,
graphdef_file.flush()
# TODO(aselle): Switch this to subprocess at some point.
+ if "pb2lite" in bin_path and FLAGS.run_with_extended:
+ opts = ("--input_arrays={0} --output_arrays={1}".format(
+ ",".join(input_arrays), ",".join(output_tensors)))
+ elif FLAGS.run_with_extended:
+ opts += " --allow_eager_ops --force_eager_ops"
cmd = ("%s --input_file=%s --output_file=%s %s > %s 2>&1" %
(bin_path, graphdef_file.name, output_file.name, opts,
stdout_file.name))
@@ -1425,6 +1434,7 @@ def make_depthwiseconv_tests(zip_path):
"input_shape": [[1, 3, 4, 3], [1, 10, 10, 3]],
"filter_size": [[1, 1], [1, 2], [3, 3]],
"strides": [[1, 1, 1, 1], [1, 3, 3, 1]],
+ "dilations": [[1, 1, 1, 1], [1, 3, 2, 1], [1, 2, 2, 1]],
"channel_multiplier": [1, 2],
"rate": [[1, 1]],
"padding": ["SAME", "VALID"],
@@ -1435,6 +1445,7 @@ def make_depthwiseconv_tests(zip_path):
"input_shape": [[1, 3, 4, 3]],
"filter_size": [[1, 1]],
"strides": [[1, 1, 2, 1]], # TF needs [1, x, x, 1]
+ "dilations": [[1, 1, 1, 1], [1, 2, 2, 1]],
"channel_multiplier": [2],
"rate": [[2, 2]], # Only [1, 1] is supported
"padding": ["SAME"],
@@ -1502,7 +1513,7 @@ def make_split_tests(zip_path):
dtype=tf.float32, name="input", shape=parameters["input_shape"])
out = tf.split(
input_tensor, parameters["num_or_size_splits"], parameters["axis"])
- return [input_tensor], out
+ return [input_tensor], [out[0]]
def build_inputs(parameters, sess, inputs, outputs):
values = [create_tensor_data(np.float32, parameters["input_shape"])]
@@ -2510,10 +2521,12 @@ def make_topk_tests(zip_path):
shape=parameters["input_shape"])
if parameters["input_k"] is not None:
k = tf.placeholder(dtype=tf.int32, name="input_k", shape=[])
+ inputs = [input_value, k]
else:
k = tf.constant(3, name="k")
+ inputs = [input_value]
out = tf.nn.top_k(input_value, k)
- return [input_value, k], [out[1]]
+ return inputs, [out[1]]
def build_inputs(parameters, sess, inputs, outputs):
input_value = create_tensor_data(parameters["input_dtype"],
@@ -2871,6 +2884,11 @@ def make_rsqrt_tests(zip_path):
return _make_elementwise_tests(tf.rsqrt)(zip_path)
+def make_square_tests(zip_path):
+ """Make a set of tests to do square."""
+ return _make_elementwise_tests(tf.square)(zip_path)
+
+
def make_where_tests(zip_path):
"""Make a set of tests to do where."""
@@ -3208,7 +3226,7 @@ def make_unpack_tests(zip_path):
input_tensor = tf.placeholder(
dtype=tf.float32, name=("input"), shape=parameters["base_shape"])
outs = tf.unstack(input_tensor, axis=get_valid_axis(parameters))
- return [input_tensor], outs
+ return [input_tensor], [outs[0]]
def build_inputs(parameters, sess, inputs, outputs):
input_value = create_tensor_data(np.float32, shape=parameters["base_shape"])
@@ -3286,7 +3304,11 @@ def main(unused_args):
out = FLAGS.zip_to_output
bin_path = FLAGS.toco
- test_function = ("make_%s_tests" % out.replace(".zip", ""))
+ # Some zip filenames contain a postfix identifying the conversion mode. The
+ # list of valid conversion modes is defined in
+ # generated_test_conversion_modes() in build_def.bzl.
+ test_function = ("make_%s_tests" % (out.replace(".zip", "").replace(
+ "pb2lite", "").replace("toco-extended", "").rstrip("_")))
if test_function not in globals():
raise RuntimeError("Can't find a test function to create %r. Tried %r" %
(out, test_function))
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index bea90f1ce8..96b88b60fc 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -347,6 +347,7 @@ tf_cc_test(
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
+ "//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
"@com_google_googletest//:gtest_main",
],
@@ -407,8 +408,11 @@ tf_cc_binary(
":toco_port",
":toco_tooling",
":types_proto_cc",
- "//tensorflow/core:lib",
"@com_google_absl//absl/strings",
+ "//tensorflow/core:lib",
+ # We cannot embed the core:ops dependency directly into :toco_tooling as
+ # it can conflict with downstream deps when toco is used as a library.
+ "//tensorflow/core:ops",
],
)
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 9bc23c4b3c..efc1007925 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -58,6 +58,7 @@ using tensorflow::DT_STRING;
using tensorflow::DT_UINT8;
using tensorflow::GraphDef;
using tensorflow::NodeDef;
+using tensorflow::OpRegistry;
using tensorflow::TensorProto;
using tensorflow::TensorShapeProto;
@@ -1079,6 +1080,25 @@ tensorflow::Status ConvertUnsupportedOperator(
} else if (HasAttr(node, "Tout")) {
const auto& output_type = GetDataTypeAttr(node, "Tout");
op->output_data_types.push_back(ConvertDataType(output_type));
+ } else {
+ const tensorflow::OpDef* op_def = nullptr;
+ if (OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
+ for (const auto& output_arg : op_def->output_arg()) {
+ if (HasAttr(node, output_arg.type_attr())) {
+ op->output_data_types.push_back(
+ ConvertDataType(GetDataTypeAttr(node, output_arg.type_attr())));
+ } else {
+ LOG(INFO) << "Op node missing output type attribute: " << node.name();
+ op->output_data_types.clear();
+ break;
+ }
+ }
+ }
+ if (op->output_data_types.empty()) {
+ // TODO(b/113613439): Figure out how to propagate types for custom ops
+ // that have no OpDef.
+ LOG(INFO) << "Unable to determine output type for op: " << node.op();
+ }
}
if (HasAttr(node, kAttrOutputShapes)) {
const auto& output_shapes = GetListAttr(node, kAttrOutputShapes);
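
For unsupported ops, the fallback above consults the global OpRegistry: if the op has a registered OpDef, each output's type attribute is read off the NodeDef; otherwise the output types are left empty to be resolved later. A standalone sketch of the same lookup (InferOutputTypes is an illustrative helper, not part of toco):

// Illustrative helper: infer a node's output data types by combining the
// registered OpDef's output_arg type attributes with the attribute values
// carried on the NodeDef itself.
#include <vector>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.pb.h"

std::vector<tensorflow::DataType> InferOutputTypes(
    const tensorflow::NodeDef& node) {
  std::vector<tensorflow::DataType> types;
  const tensorflow::OpDef* op_def = nullptr;
  if (!tensorflow::OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
    return types;  // No OpDef registered: leave the types undetermined.
  }
  for (const auto& output_arg : op_def->output_arg()) {
    const auto it = node.attr().find(output_arg.type_attr());
    if (it == node.attr().end()) {
      types.clear();  // Missing type annotation: give up rather than guess.
      break;
    }
    types.push_back(it->second.type());
  }
  return types;
}
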
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
index 90e6f698ef..da248826a7 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/core/status.h"
@@ -48,6 +49,17 @@ Status ImportTensorFlowNode(const NodeDef&, const TensorFlowImportFlags&,
namespace {
+Status ImportNode(const NodeDef& node, Model* model) {
+ const auto converter = internal::GetTensorFlowNodeConverterMap();
+ return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), model,
+ converter);
+}
+
+Status ImportNode(const NodeDef& node) {
+ Model model;
+ return ImportNode(node, &model);
+}
+
class ShapeImportTest : public ::testing::TestWithParam<tensorflow::DataType> {
protected:
ShapeImportTest() {}
@@ -108,12 +120,24 @@ class ShapeImportTest : public ::testing::TestWithParam<tensorflow::DataType> {
SetAttrValue(t, &value_attr);
(*node->mutable_attr())["value"] = value_attr;
}
+};
+
+class TypeImportTest : public ::testing::TestWithParam<
+ std::pair<tensorflow::DataType, ArrayDataType>> {
+ protected:
+ TypeImportTest() {}
+
+ void BuildUnaryNode(const std::string& op_name, tensorflow::DataType dtype,
+ NodeDef* node) {
+ node->set_op(op_name);
+ node->set_name("Node1");
+
+ node->add_input();
+ node->set_input(0, "Node0");
- Status ImportNode(const NodeDef& node) {
- Model model;
- const auto converter = internal::GetTensorFlowNodeConverterMap();
- return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), &model,
- converter);
+ AttrValue dtype_attr;
+ SetAttrValue(dtype, &dtype_attr);
+ (*node->mutable_attr())["T"] = dtype_attr;
}
};
@@ -166,5 +190,47 @@ TEST_P(ShapeImportTest, ValidShapeButZeroElements) {
INSTANTIATE_TEST_CASE_P(ValidShapeButZeroElements, ShapeImportTest,
::testing::ValuesIn(TestTypes()));
+std::vector<std::pair<tensorflow::DataType, ArrayDataType>> UnaryTestTypes() {
+ return {{DT_FLOAT, ArrayDataType::kFloat},
+ {DT_INT32, ArrayDataType::kInt32},
+ {DT_INT64, ArrayDataType::kInt64}};
+}
+
+TEST_P(TypeImportTest, BasicTypeInference) {
+ NodeDef node;
+ BuildUnaryNode("Atan", GetParam().first, &node);
+
+ Model model;
+ EXPECT_TRUE(ImportNode(node, &model).ok());
+
+ ASSERT_THAT(model.operators.size(), ::testing::Ge(1));
+ ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported);
+ const TensorFlowUnsupportedOperator* op =
+ static_cast<const TensorFlowUnsupportedOperator*>(
+ model.operators[0].get());
+ ASSERT_THAT(op->output_data_types, ::testing::ElementsAre(GetParam().second));
+}
+INSTANTIATE_TEST_CASE_P(BasicTypeInference, TypeImportTest,
+ ::testing::ValuesIn(UnaryTestTypes()));
+
+TEST(ImportTest, FailedTypeInference) {
+ // Create a unary op with no Type ("T") annotation.
+ NodeDef node;
+ node.set_op("Atan");
+ node.set_name("Node1");
+ node.add_input();
+ node.set_input(0, "Node0");
+
+ Model model;
+ EXPECT_TRUE(ImportNode(node, &model).ok());
+
+ ASSERT_THAT(model.operators.size(), ::testing::Ge(1));
+ ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported);
+ const TensorFlowUnsupportedOperator* op =
+ static_cast<const TensorFlowUnsupportedOperator*>(
+ model.operators[0].get());
+ ASSERT_TRUE(op->output_data_types.empty());
+}
+
} // namespace
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 2e100e37f6..164b70f2df 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -477,6 +477,11 @@ struct DepthwiseConvOperator : Operator {
int stride_height = 0;
int stride_width = 0;
int depth_multiplier = 0;
+ // A dilation_rate of 0 is invalid, and these fields are optional attributes.
+ // They are therefore initialized to 1 so that the default (undilated) conv
+ // behavior applies when the attribute is not present.
+ int dilation_width_factor = 1;
+ int dilation_height_factor = 1;
};
// Depth-to-space transform operator.
diff --git a/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py b/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py
index 3761e0095e..75c1c8970c 100644
--- a/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py
+++ b/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py
@@ -50,7 +50,7 @@ class TocoFromProtosTest(googletest.TestCase):
toco_flags.output_format = toco_flags_pb2.TFLITE
toco_flags.inference_input_type = types_pb2.FLOAT
toco_flags.inference_type = types_pb2.FLOAT
- toco_flags.allow_custom_ops = True;
+ toco_flags.allow_custom_ops = True
model_flags = model_flags_pb2.ModelFlags()
input_array = model_flags.input_arrays.add()
input_array.name = TensorName(in_tensor)
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index eb0f7c443a..1061e7c7c4 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -107,7 +107,8 @@ class DepthwiseConvolution
ActivationFunction::Serialize(op.fused_activation_function);
return ::tflite::CreateDepthwiseConv2DOptions(
*builder, padding, op.stride_width, op.stride_height,
- op.depth_multiplier, activation_function);
+ op.depth_multiplier, activation_function, op.dilation_width_factor,
+ op.dilation_height_factor);
}
void ReadOptions(const TfLiteOptions& options,
@@ -118,9 +119,18 @@ class DepthwiseConvolution
op->depth_multiplier = options.depth_multiplier();
op->fused_activation_function =
ActivationFunction::Deserialize(options.fused_activation_function());
+ op->dilation_width_factor = options.dilation_w_factor();
+ op->dilation_height_factor = options.dilation_h_factor();
}
- int GetVersion(const Operator& op) const override { return 1; }
+ int GetVersion(const Operator& op) const override {
+ const auto& conv_op = static_cast<const DepthwiseConvOperator&>(op);
+ if (conv_op.dilation_width_factor != 1 ||
+ conv_op.dilation_height_factor != 1) {
+ return 2;
+ }
+ return 1;
+ }
};
class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions,
@@ -1488,6 +1498,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
"SQRT", OperatorType::kSqrt));
ops.push_back(MakeUnique<SimpleOperator<TensorFlowRsqrtOperator>>(
"RSQRT", OperatorType::kRsqrt));
+ ops.push_back(MakeUnique<SimpleOperator<TensorFlowSquareOperator>>(
+ "SQUARE", OperatorType::kSquare));
return ops;
}
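
The serializer now emits DEPTHWISE_CONV_2D op version 2 only when a model actually uses dilation, so undilated models keep producing version 1 and stay loadable by older runtimes. A small restatement of that rule in isolation (DepthwiseConvParams and DepthwiseConvVersion are illustrative names):

// Illustrative restatement of the version-selection rule added above: any
// dilation factor other than 1 requires the version-2 operator.
struct DepthwiseConvParams {
  int dilation_width_factor = 1;
  int dilation_height_factor = 1;
};

int DepthwiseConvVersion(const DepthwiseConvParams& params) {
  return (params.dilation_width_factor != 1 ||
          params.dilation_height_factor != 1)
             ? 2
             : 1;
}
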
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
index 519a3a4e01..72e50a9aed 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -144,6 +144,8 @@ TEST_F(OperatorTest, SimpleOperators) {
CheckSimpleOperator<LogicalNotOperator>("LOGICAL_NOT",
OperatorType::kLogicalNot);
CheckSimpleOperator<FloorDivOperator>("FLOOR_DIV", OperatorType::kFloorDiv);
+ CheckSimpleOperator<TensorFlowSquareOperator>("SQUARE",
+ OperatorType::kSquare);
}
TEST_F(OperatorTest, BuiltinAdd) {
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
index 02039922b4..ef4f0fa80d 100644
--- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
@@ -232,6 +232,46 @@ uint64_t BenchmarkTfLiteModel::ComputeInputBytes() {
return total_input_bytes;
}
+void BenchmarkTfLiteModel::PrepareInputsAndOutputs() {
+ auto interpreter_inputs = interpreter->inputs();
+ // Set the values of the input tensors.
+ for (int j = 0; j < inputs.size(); ++j) {
+ const InputLayerInfo& input = inputs[j];
+ int i = interpreter_inputs[j];
+ TfLiteTensor* t = interpreter->tensor(i);
+ std::vector<int> sizes = input.shape;
+
+ // TODO(ahentz): below we ignore the 0-th dimension (number of batches).
+ if (t->type == kTfLiteFloat32) {
+ FillRandomValue<float>(
+ interpreter->typed_tensor<float>(i),
+ std::vector<int>(sizes.begin() + 1, sizes.end()),
+ []() { return static_cast<float>(rand()) / RAND_MAX - 0.5f; });
+ } else if (t->type == kTfLiteInt32) {
+ // TODO(yunluli): This is currently only used for handling embedding input
+ // for speech models. Generalize if necessary.
+ FillRandomValue<int32_t>(
+ interpreter->typed_tensor<int32_t>(i),
+ std::vector<int32_t>(sizes.begin() + 1, sizes.end()),
+ []() { return static_cast<int32_t>(rand()) % 100; });
+ } else if (t->type == kTfLiteUInt8) {
+ FillRandomValue<uint8_t>(
+ interpreter->typed_tensor<uint8_t>(i),
+ std::vector<int>(sizes.begin() + 1, sizes.end()),
+ []() { return static_cast<uint8_t>(rand()) % 255; });
+ } else if (t->type == kTfLiteString) {
+ tflite::DynamicBuffer buffer;
+ FillRandomString(&buffer, sizes, []() {
+ return "we're have some friends over saturday to hang out in the yard";
+ });
+ buffer.WriteToTensor(interpreter->tensor(i));
+ } else {
+ TFLITE_LOG(FATAL) << "Don't know how to populate tensor " << t->name
+ << " of type " << t->type;
+ }
+ }
+}
+
void BenchmarkTfLiteModel::Init() {
std::string graph = params_.Get<std::string>("graph");
model = tflite::FlatBufferModel::BuildFromFile(graph.c_str());
@@ -305,36 +345,6 @@ void BenchmarkTfLiteModel::Init() {
if (interpreter->AllocateTensors() != kTfLiteOk) {
TFLITE_LOG(FATAL) << "Failed to allocate tensors!";
}
-
- // Set the values of the input tensors.
- for (int j = 0; j < inputs.size(); ++j) {
- const InputLayerInfo& input = inputs[j];
- int i = interpreter_inputs[j];
- TfLiteTensor* t = interpreter->tensor(i);
- std::vector<int> sizes = input.shape;
-
- // TODO(ahentz): below we ignore the O-th dimension (number of batches).
- if (t->type == kTfLiteFloat32) {
- FillRandomValue<float>(
- interpreter->typed_tensor<float>(i),
- std::vector<int>(sizes.begin() + 1, sizes.end()),
- []() { return static_cast<float>(rand()) / RAND_MAX - 0.5f; });
- } else if (t->type == kTfLiteUInt8) {
- FillRandomValue<uint8_t>(
- interpreter->typed_tensor<uint8_t>(i),
- std::vector<int>(sizes.begin() + 1, sizes.end()),
- []() { return static_cast<uint8_t>(rand()) % 255; });
- } else if (t->type == kTfLiteString) {
- tflite::DynamicBuffer buffer;
- FillRandomString(&buffer, sizes, []() {
- return "we're have some friends over saturday to hang out in the yard";
- });
- buffer.WriteToTensor(interpreter->tensor(i));
- } else {
- TFLITE_LOG(FATAL) << "Don't know how to populate tensor " << t->name
- << " of type " << t->type;
- }
- }
}
void BenchmarkTfLiteModel::RunImpl() {
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
index 4c4320a998..8541512bc8 100644
--- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
@@ -69,6 +69,9 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
std::vector<int> shape;
};
+ protected:
+ void PrepareInputsAndOutputs() override;
+
private:
#ifdef TFLITE_EXTENDED
std::unique_ptr<EagerDelegate> delegate_;
diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc
index b863108aa4..d02d78bf53 100644
--- a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc
+++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc
@@ -206,6 +206,14 @@ std::vector<TensorInfo> GetQuantizableTensorsFromOperator(
continue;
}
+ // Some tensors may have a null buffer vector, indicating an intermediate
+ // array.
+ if (model->buffers[tensor->buffer]->data.data() == nullptr) {
+ LOG(INFO) << "Skipping quantization of tensor " << tensor->name
+ << " because it has no allocated buffer.";
+ continue;
+ }
+
TensorInfo tensor_info;
tensor_info.eval_hybrid = eval_hybrid;
tensor_info.op_input_idx = op_input_idx;
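
The new guard above skips tensors whose backing buffer carries no data, since those are intermediate activations rather than constant weights. A hedged sketch of the same check against the flatbuffer object API (HasConstantData is an illustrative helper):

// Illustrative helper: a tensor is only a quantization candidate if the
// buffer it points at actually carries constant data.
#include "tensorflow/contrib/lite/schema/schema_generated.h"

bool HasConstantData(const tflite::ModelT& model, const tflite::TensorT& tensor) {
  const auto& buffer = model.buffers[tensor.buffer];
  return buffer != nullptr && !buffer->data.empty();
}
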
diff --git a/tensorflow/contrib/lite/tools/visualize.py b/tensorflow/contrib/lite/tools/visualize.py
index 597dede63b..d7eea79399 100644
--- a/tensorflow/contrib/lite/tools/visualize.py
+++ b/tensorflow/contrib/lite/tools/visualize.py
@@ -202,7 +202,7 @@ class TensorMapper(object):
html += str(i) + " "
html += tensor["name"] + " "
html += str(tensor["type"]) + " "
- html += repr(tensor["shape"]) + "<br>"
+ html += (repr(tensor["shape"]) if "shape" in tensor else "[]") + "<br>"
html += "</span>"
html += repr(x)
html += "</span>"
diff --git a/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb b/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb
new file mode 100644
index 0000000000..a96e2c4e1b
--- /dev/null
+++ b/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb
@@ -0,0 +1,702 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "6Y8E0lw5eYWm"
+ },
+ "source": [
+ "# Post Training Quantization"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "CIGrZZPTZVeO"
+ },
+ "source": [
+ "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
+ " \u003ctd\u003e\n",
+ " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
+ " \u003c/td\u003e\n",
+ " \u003ctd\u003e\n",
+ " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
+ " \u003c/td\u003e\n",
+ "\u003c/table\u003e"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "BTC1rDAuei_1"
+ },
+ "source": [
+ "## Overview\n",
+ "\n",
+ "[TensorFlow Lite](https://www.tensorflow.org/mobile/tflite/) now supports\n",
+ "converting weights to 8 bit precision as part of model conversion from\n",
+ "tensorflow graphdefs to TFLite's flat buffer format. Weight quantization\n",
+ "achieves a 4x reduction in the model size. In addition, TFLite supports on the\n",
+ "fly quantization and dequantization of activations to allow for:\n",
+ "\n",
+ "1. Using quantized kernels for faster implementation when available.\n",
+ "\n",
+ "2. Mixing of floating-point kernels with quantized kernels for different parts\n",
+ " of the graph.\n",
+ "\n",
+ "Note that the activations are always stored in floating point. For ops that\n",
+ "support quantized kernels, the activations are quantized to 8 bits of precision\n",
+ "dynamically prior to processing and are de-quantized to float precision after\n",
+ "processing. Depending on the model being converted, this can give a speedup over\n",
+ "pure floating point computation.\n",
+ "\n",
+ "In contrast to\n",
+ "[quantization aware training](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/quantize)\n",
+ ", the weights are quantized post training and the activations are quantized dynamically \n",
+ "at inference in this method.\n",
+ "Therefore, the model weights are not retrained to compensate for quantization\n",
+ "induced errors. It is important to check the accuracy of the quantized model to\n",
+ "ensure that the degradation is acceptable.\n",
+ "\n",
+ "In this tutorial, we train an MNIST model from scratch, check its accuracy in\n",
+ "tensorflow and then convert the saved model into a Tensorflow Lite flatbuffer\n",
+ "with weight quantization. We finally check the\n",
+ "accuracy of the converted model and compare it to the original saved model. We\n",
+ "run the training script mnist.py from\n",
+ "[Tensorflow official mnist tutorial](https://github.com/tensorflow/models/tree/master/official/mnist).\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "2XsEP17Zelz9"
+ },
+ "source": [
+ "## Building an MNIST model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "dDqqUIZjZjac"
+ },
+ "source": [
+ "### Setup"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "gyqAw1M9lyab"
+ },
+ "outputs": [],
+ "source": [
+ "! pip uninstall -y tensorflow\n",
+ "! pip install -U tf-nightly"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "WsN6s5L1ieNl"
+ },
+ "outputs": [],
+ "source": [
+ "import tensorflow as tf\n",
+ "tf.enable_eager_execution()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "00U0taBoe-w7"
+ },
+ "outputs": [],
+ "source": [
+ "! git clone --depth 1 https://github.com/tensorflow/models"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "4XZPtSh-fUOc"
+ },
+ "outputs": [],
+ "source": [
+ "import sys\n",
+ "import os\n",
+ "\n",
+ "if sys.version_info.major \u003e= 3:\n",
+ " import pathlib\n",
+ "else:\n",
+ " import pathlib2 as pathlib\n",
+ "\n",
+ "# Add `models` to the python path.\n",
+ "models_path = os.path.join(os.getcwd(), \"models\")\n",
+ "sys.path.append(models_path)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "eQ6Q0qqKZogR"
+ },
+ "source": [
+ "### Train and export the model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "eMsw_6HujaqM"
+ },
+ "outputs": [],
+ "source": [
+ "saved_models_root = \"/tmp/mnist_saved_model\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "hWSAjQWagIHl"
+ },
+ "outputs": [],
+ "source": [
+ "# The above path addition is not visible to subprocesses, add the path for the subprocess as well.\n",
+ "# Note: channels_last is required here or the conversion may fail. \n",
+ "!PYTHONPATH={models_path} python models/official/mnist/mnist.py --train_epochs=1 --export_dir {saved_models_root} --data_format=channels_last"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "5NMaNZQCkW9X"
+ },
+ "source": [
+ "For the example, we only trained the model for a single epoch, so it only trains to ~96% accuracy.\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "xl8_fzVAZwOh"
+ },
+ "source": [
+ "### Convert to a TFLite model\n",
+ "\n",
+ "The `savedmodel` directory is named with a timestamp. Select the most recent one: "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "Xp5oClaZkbtn"
+ },
+ "outputs": [],
+ "source": [
+ "saved_model_dir = str(sorted(pathlib.Path(saved_models_root).glob(\"*\"))[-1])\n",
+ "saved_model_dir"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "AT8BgkKmljOy"
+ },
+ "source": [
+ "Using the python `TocoConverter`, the saved model can be converted into a TFLite model.\n",
+ "\n",
+ "First load the model using the `TocoConverter`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "_i8B2nDZmAgQ"
+ },
+ "outputs": [],
+ "source": [
+ "import tensorflow as tf\n",
+ "tf.enable_eager_execution()\n",
+ "converter = tf.contrib.lite.TocoConverter.from_saved_model(saved_model_dir)\n",
+ "tflite_model = converter.convert()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "F2o2ZfF0aiCx"
+ },
+ "source": [
+ "Write it out to a tflite file:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "vptWZq2xnclo"
+ },
+ "outputs": [],
+ "source": [
+ "tflite_models_dir = pathlib.Path(\"/tmp/mnist_tflite_models/\")\n",
+ "tflite_models_dir.mkdir(exist_ok=True, parents=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "Ie9pQaQrn5ue"
+ },
+ "outputs": [],
+ "source": [
+ "tflite_model_file = tflite_models_dir/\"mnist_model.tflite\"\n",
+ "tflite_model_file.write_bytes(tflite_model)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "7BONhYtYocQY"
+ },
+ "source": [
+ "To quantize the model on export, set the `post_training_quantize` flag:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "g8PUvLWDlmmz"
+ },
+ "outputs": [],
+ "source": [
+ "# Note: If you don't have a recent tf-nightly installed, the\n",
+ "# \"post_training_quantize\" line will have no effect.\n",
+ "tf.logging.set_verbosity(tf.logging.INFO)\n",
+ "converter.post_training_quantize = True\n",
+ "tflite_quant_model = converter.convert()\n",
+ "tflite_model_quant_file = tflite_models_dir/\"mnist_model_quant.tflite\"\n",
+ "tflite_model_quant_file.write_bytes(tflite_quant_model)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "PhMmUTl4sbkz"
+ },
+ "source": [
+ "Note how the resulting file, with `post_training_quantize` set, is approximately `1/4` the size."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "JExfcfLDscu4"
+ },
+ "outputs": [],
+ "source": [
+ "!ls -lh {tflite_models_dir}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "L8lQHMp_asCq"
+ },
+ "source": [
+ "## Run the TFLite models"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "-5l6-ciItvX6"
+ },
+ "source": [
+ "We can run the TensorFlow Lite model using the python TensorFlow Lite\n",
+ "Interpreter. \n",
+ "\n",
+ "### load the test data\n",
+ "\n",
+ "First let's load the mnist test data to feed to it:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "eTIuU07NuKFL"
+ },
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()\n",
+ "images, labels = tf.to_float(mnist_test[0])/255.0, mnist_test[1]\n",
+ "\n",
+ "# Note: If you change the batch size, then use \n",
+ "# `tf.contrib.lite.Interpreter.resize_tensor_input` to also change it for\n",
+ "# the interpreter.\n",
+ "mnist_ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "Ap_jE7QRvhPf"
+ },
+ "source": [
+ "### Load the model into an interpreter"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "Jn16Rc23zTss"
+ },
+ "outputs": [],
+ "source": [
+ "interpreter = tf.contrib.lite.Interpreter(model_path=str(tflite_model_file))\n",
+ "interpreter.allocate_tensors()\n",
+ "input_index = interpreter.get_input_details()[0][\"index\"]\n",
+ "output_index = interpreter.get_output_details()[0][\"index\"]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "J8Pztk1mvNVL"
+ },
+ "outputs": [],
+ "source": [
+ "tf.logging.set_verbosity(tf.logging.DEBUG)\n",
+ "interpreter_quant = tf.contrib.lite.Interpreter(model_path=str(tflite_model_quant_file))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "Afl6yGvWyqAr"
+ },
+ "outputs": [],
+ "source": [
+ "interpreter_quant.allocate_tensors()\n",
+ "input_index = interpreter_quant.get_input_details()[0][\"index\"]\n",
+ "output_index = interpreter_quant.get_output_details()[0][\"index\"]\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "2opUt_JTdyEu"
+ },
+ "source": [
+ "### Test the model on one image"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "AKslvo2kwWac"
+ },
+ "outputs": [],
+ "source": [
+ "for img, label in mnist_ds.take(1):\n",
+ " break\n",
+ "\n",
+ "interpreter.set_tensor(input_index, img)\n",
+ "interpreter.invoke()\n",
+ "predictions = interpreter.get_tensor(output_index)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "XZClM2vo3_bm"
+ },
+ "outputs": [],
+ "source": [
+ "import matplotlib.pylab as plt\n",
+ "\n",
+ "plt.imshow(img[0])\n",
+ "template = \"True:{true}, predicted:{predict}\"\n",
+ "_ = plt.title(template.format(true= str(label[0].numpy()),\n",
+ " predict=str(predictions[0,0])))\n",
+ "plt.grid(False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "LwN7uIdCd8Gw"
+ },
+ "source": [
+ "### Evaluate the models"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "05aeAuWjvjPx"
+ },
+ "outputs": [],
+ "source": [
+ "def eval_model(interpreter, mnist_ds):\n",
+ " total_seen = 0\n",
+ " num_correct = 0\n",
+ "\n",
+ " for img, label in mnist_ds:\n",
+ " total_seen += 1\n",
+ " interpreter.set_tensor(input_index, img)\n",
+ " interpreter.invoke()\n",
+ " predictions = interpreter.get_tensor(output_index)\n",
+ " if predictions == label.numpy():\n",
+ " num_correct += 1\n",
+ "\n",
+ " if total_seen % 500 == 0:\n",
+ " print(\"Accuracy after %i images: %f\" %\n",
+ " (total_seen, float(num_correct) / float(total_seen)))\n",
+ "\n",
+ " return float(num_correct) / float(total_seen)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "DqXBnDfJ7qxL"
+ },
+ "outputs": [],
+ "source": [
+ "print(eval_model(interpreter_quant, mnist_ds))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "Km3cY9ry8ZlG"
+ },
+ "source": [
+ "We can repeat the evaluation on the weight quantized model to obtain:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "-9cnwiPp6EGm"
+ },
+ "outputs": [],
+ "source": [
+ "print(eval_model(interpreter_quant, mnist_ds))\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "L7lfxkor8pgv"
+ },
+ "source": [
+ "\n",
+ "In this example, we have compressed model with no difference in the accuracy."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "M0o1FtmWeKZm"
+ },
+ "source": [
+ "\n",
+ "\n",
+ "## Optimizing an existing model\n",
+ "\n",
+ "We now consider another example. Resnets with pre-activation layers (Resnet-v2) are widely used for vision applications.\n",
+ " Pre-trained frozen graph for resnet-v2-101 is available at the\n",
+ " [Tensorflow Lite model repository](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/models.md).\n",
+ "\n",
+ "We can convert the frozen graph to a TFLite flatbuffer with quantization by:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "v5p5VcNPjILQ"
+ },
+ "outputs": [],
+ "source": [
+ "archive_path = tf.keras.utils.get_file(\"resnet_v2_101.tgz\", \"https://storage.googleapis.com/download.tensorflow.org/models/tflite_11_05_08/resnet_v2_101.tgz\", extract=True)\n",
+ "archive_path = pathlib.Path(archive_path)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "-sxnXQuC4ThD"
+ },
+ "source": [
+ "The `info.txt` file lists the input and output names. You can also find them using TensorBoard to visually inspect the graph."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "g_Q_OMEJ4LIc"
+ },
+ "outputs": [],
+ "source": [
+ "! cat {archive_path}/resnet_v2_101_299_info.txt"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "ujCAFhqm-C6H"
+ },
+ "outputs": [],
+ "source": [
+ "graph_def_file = pathlib.Path(archive_path).parent/\"resnet_v2_101_299_frozen.pb\"\n",
+ "input_arrays = [\"input\"] \n",
+ "output_arrays = [\"output\"]\n",
+ "converter = tf.contrib.lite.TocoConverter.from_frozen_graph(\n",
+ " str(graph_def_file), input_arrays, output_arrays, input_shapes={\"input\":[1,299,299,3]})\n",
+ "converter.post_training_quantize = True\n",
+ "resnet_tflite_file = graph_def_file.parent/\"resnet_v2_101_quantized.tflite\"\n",
+ "resnet_tflite_file.write_bytes(converter.convert())\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "vhOjeg1x9Knp"
+ },
+ "outputs": [],
+ "source": [
+ "archive_dir = str(archive_path.parent)\n",
+ "!ls -lh {archive_dir}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "qqHLaqFMCjRZ"
+ },
+ "source": [
+ "\n",
+ "The model size reduces from 171 MB to 43 MB.\n",
+ "The accuracy of this model on imagenet can be evaluated using the scripts provided for [TFLite accuracy measurement](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/accuracy/ilsvrc).\n",
+ "\n",
+ "The optimized model top-1 accuracy is 76.8, the same as the floating point model."
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "collapsed_sections": [],
+ "name": "post-training-quant.ipynb",
+ "private_outputs": true,
+ "provenance": [],
+ "toc_visible": true,
+ "version": "0.3.2"
+ },
+ "kernelspec": {
+ "display_name": "Python 2",
+ "name": "python2"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py
index 0a54bb1f5e..89b538d1ba 100644
--- a/tensorflow/contrib/lookup/lookup_ops_test.py
+++ b/tensorflow/contrib/lookup/lookup_ops_test.py
@@ -44,7 +44,7 @@ from tensorflow.python.training.checkpointable import util as checkpointable
class HashTableOpTest(test.TestCase):
def testHashTable(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -68,7 +68,7 @@ class HashTableOpTest(test.TestCase):
self.assertItemsEqual([0, 1, 2], exported_values_tensor.eval())
def testHashTableFindHighRank(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -86,7 +86,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([[0, 1], [-1, -1]], result)
def testHashTableInitWithPythonArrays(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = ["brain", "salad", "surgery"]
values = [0, 1, 2]
@@ -105,7 +105,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testHashTableInitWithNumPyArrays(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = np.array(["brain", "salad", "surgery"], dtype=np.str)
values = np.array([0, 1, 2], dtype=np.int64)
@@ -122,7 +122,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testMultipleHashTables(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -150,7 +150,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], out3)
def testHashTableWithTensorDefault(self):
- with self.test_session():
+ with self.cached_session():
default_val = constant_op.constant(-1, dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -165,7 +165,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testHashTableWithSparseTensorInput(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_val = constant_op.constant(-1, dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -188,7 +188,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual(sp_shape, out_shape)
def testSignatureMismatch(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -210,7 +210,7 @@ class HashTableOpTest(test.TestCase):
lookup.KeyValueTensorInitializer(keys, values), "UNK")
def testDTypes(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
with self.assertRaises(TypeError):
lookup.HashTable(
@@ -218,7 +218,7 @@ class HashTableOpTest(test.TestCase):
dtypes.int64), default_val)
def testNotInitialized(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
table = lookup.HashTable(
lookup.KeyValueTensorInitializer(
@@ -232,7 +232,7 @@ class HashTableOpTest(test.TestCase):
output.eval()
def testInitializeTwice(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -244,7 +244,7 @@ class HashTableOpTest(test.TestCase):
table.init.run()
def testInitializationWithInvalidDimensions(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64)
@@ -283,7 +283,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual(3, table.size().eval())
def testHashTableInt32String(self):
- with self.test_session():
+ with self.cached_session():
default_val = "n/a"
keys = constant_op.constant([0, 1, 2], dtypes.int32)
values = constant_op.constant(["brain", "salad", "surgery"])
@@ -301,7 +301,7 @@ class HashTableOpTest(test.TestCase):
class MutableHashTableOpTest(test.TestCase):
def testMutableHashTable(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -470,7 +470,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual([b"-", b"a", b"b"], output.eval())
def testMutableHashTableOfTensors(self):
- with self.test_session():
+ with self.cached_session():
default_val = constant_op.constant([-1, -1], dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
@@ -500,7 +500,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual([[4, 5], [2, 3], [0, 1]], sorted_values)
def testMutableHashTableExportInsert(self):
- with self.test_session():
+ with self.cached_session():
default_val = constant_op.constant([-1, -1], dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
@@ -531,7 +531,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual(expected_output, output2.eval())
def testMutableHashTableOfTensorsInvalidShape(self):
- with self.test_session():
+ with self.cached_session():
default_val = constant_op.constant([-1, -1], dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"])
table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
@@ -563,7 +563,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual(3, table.size().eval())
def testMutableHashTableInvalidDefaultValue(self):
- with self.test_session():
+ with self.cached_session():
default_val = constant_op.constant([[-1, -1]], dtypes.int64)
table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
default_val)
@@ -571,7 +571,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual(0, table.size().eval())
def testMutableHashTableDuplicateInsert(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery", "brain"])
values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
@@ -589,7 +589,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual([3, 1, -1], result)
def testMutableHashTableFindHighRank(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -608,7 +608,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual([[0, 1], [-1, -1]], result)
def testMutableHashTableInsertHighRank(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]])
values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64)
@@ -625,7 +625,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, 3, -1], result)
def testMutableHashTableOfTensorsFindHighRank(self):
- with self.test_session():
+ with self.cached_session():
default_val = constant_op.constant([-1, -1, -1], dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([[0, 1, 2], [2, 3, 4], [4, 5, 6]],
@@ -646,7 +646,7 @@ class MutableHashTableOpTest(test.TestCase):
[[[0, 1, 2], [2, 3, 4]], [[-1, -1, -1], [-1, -1, -1]]], result)
def testMultipleMutableHashTables(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -676,7 +676,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], out3)
def testMutableHashTableWithTensorDefault(self):
- with self.test_session():
+ with self.cached_session():
default_val = constant_op.constant(-1, dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -693,7 +693,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testSignatureMismatch(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -734,7 +734,7 @@ class MutableHashTableOpTest(test.TestCase):
lookup.MutableHashTable(dtypes.string, dtypes.int64, "UNK")
def testMutableHashTableStringFloat(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1.5
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1.1, 2.2], dtypes.float32)
@@ -752,7 +752,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllClose([0, 1.1, default_val], result)
def testMutableHashTableIntFloat(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1.0
keys = constant_op.constant([3, 7, 0], dtypes.int64)
values = constant_op.constant([7.5, -1.2, 9.9], dtypes.float32)
@@ -770,7 +770,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllClose([-1.2, 9.9, default_val], result)
def testMutableHashTableInt64String(self):
- with self.test_session():
+ with self.cached_session():
default_val = "n/a"
keys = constant_op.constant([0, 1, 2], dtypes.int64)
values = constant_op.constant(["brain", "salad", "surgery"])
@@ -791,7 +791,7 @@ class MutableHashTableOpTest(test.TestCase):
class MutableDenseHashTableOpTest(test.TestCase):
def testBasic(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup.MutableDenseHashTable(
@@ -809,7 +809,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testBasicBool(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([True, True, True], dtypes.bool)
table = lookup.MutableDenseHashTable(
@@ -827,7 +827,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([True, True, False], result)
def testLookupUnknownShape(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup.MutableDenseHashTable(
@@ -843,7 +843,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testMapStringToFloat(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant(["a", "b", "c"], dtypes.string)
values = constant_op.constant([0.0, 1.1, 2.2], dtypes.float32)
default_value = constant_op.constant(-1.5, dtypes.float32)
@@ -866,7 +866,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
def testMapInt64ToFloat(self):
for float_dtype in [dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([0.0, 1.1, 2.2], float_dtype)
default_value = constant_op.constant(-1.5, float_dtype)
@@ -885,7 +885,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllClose([0, 1.1, -1.5], result)
def testVectorValues(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9]],
dtypes.int64)
@@ -918,7 +918,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
result)
def testVectorKeys(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([[0, 1], [1, 2], [1, 3]], dtypes.int64)
values = constant_op.constant([10, 11, 12], dtypes.int64)
empty_key = constant_op.constant([0, 3], dtypes.int64)
@@ -949,7 +949,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([10, 11, -1], result)
def testResize(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup.MutableDenseHashTable(
@@ -977,7 +977,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([-1, 0, 1, 3, 4, 5, 6, 7, -1], output.eval())
def testExport(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([1, 2, 3], dtypes.int64)
table = lookup.MutableDenseHashTable(
@@ -1238,7 +1238,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1, 2, -1], output.eval())
def testReprobe(self):
- with self.test_session():
+ with self.cached_session():
# Insert 6 keys into a table with 8 buckets.
# The values are chosen to make sure collisions occur when using GCC STL
keys = constant_op.constant([11, 12, 13, 19, 20, 21], dtypes.int64)
@@ -1263,7 +1263,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([-1, 51, 52, 53, -1, 54, 55, 56, -1], result)
def testCustomEmptyKey(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([11, 0, 13], dtypes.int64)
values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup.MutableDenseHashTable(
@@ -1281,7 +1281,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testErrors(self):
- with self.test_session():
+ with self.cached_session():
table = lookup.MutableDenseHashTable(
dtypes.int64, dtypes.int64, default_value=-1, empty_key=0)
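The MutableDenseHashTable hunks above apply the same one-line swap to graph-mode tests. A minimal, self-contained sketch of the resulting shape, assuming the contrib lookup API exercised in these tests (the class and method names below are illustrative, not part of the patch):

from tensorflow.contrib import lookup
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import test


class CachedSessionDenseTableSketch(test.TestCase):

  def testDenseTableUnderCachedSession(self):
    # cached_session() supplies the per-test default session in which the
    # insert op runs and the lookup tensor is evaluated.
    with self.cached_session():
      keys = constant_op.constant([11, 12, 13], dtypes.int64)
      values = constant_op.constant([0, 1, 2], dtypes.int64)
      table = lookup.MutableDenseHashTable(
          dtypes.int64, dtypes.int64, default_value=-1, empty_key=0)
      table.insert(keys, values).run()
      output = table.lookup(constant_op.constant([11, 14], dtypes.int64))
      # 11 was inserted with value 0; 14 is absent and maps to the default.
      self.assertAllEqual([0, -1], output.eval())


if __name__ == "__main__":
  test.main()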
@@ -1328,7 +1328,7 @@ class IndexTableFromFile(test.TestCase):
def test_string_index_table_from_file(self):
vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -1339,7 +1339,7 @@ class IndexTableFromFile(test.TestCase):
def test_string_index_table_from_file_tensor_filename(self):
vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
- with self.test_session():
+ with self.cached_session():
vocabulary_file = constant_op.constant(vocabulary_file)
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1)
@@ -1353,7 +1353,7 @@ class IndexTableFromFile(test.TestCase):
def test_string_index_table_from_file_placeholder_filename(self):
vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
- with self.test_session():
+ with self.cached_session():
vocabulary_placeholder = array_ops.placeholder(dtypes.string, [])
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_placeholder, num_oov_buckets=1)
@@ -1370,7 +1370,7 @@ class IndexTableFromFile(test.TestCase):
def test_int32_index_table_from_file(self):
vocabulary_file = self._createVocabFile(
"f2i_vocab2.txt", values=("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1,
key_dtype=dtypes.int32)
@@ -1384,7 +1384,7 @@ class IndexTableFromFile(test.TestCase):
def test_int64_index_table_from_file(self):
vocabulary_file = self._createVocabFile(
"f2i_vocab3.txt", values=("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1,
key_dtype=dtypes.int64)
@@ -1398,7 +1398,7 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_with_default_value(self):
default_value = -42
vocabulary_file = self._createVocabFile("f2i_vocab4.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, default_value=default_value)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -1409,7 +1409,7 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_with_oov_buckets(self):
vocabulary_file = self._createVocabFile("f2i_vocab5.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1000)
ids = table.lookup(
@@ -1439,7 +1439,7 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_with_vocab_size_too_small(self):
vocabulary_file = self._createVocabFile("f2i_vocab6.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=2)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -1451,7 +1451,7 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_with_vocab_size_too_large(self):
vocabulary_file = self._createVocabFile("f2i_vocab7.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=4)
self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
@@ -1466,7 +1466,7 @@ class IndexTableFromFile(test.TestCase):
vocabulary_file=vocabulary_file,
vocab_size=0)
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=3)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -1478,7 +1478,7 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_with_invalid_hashers(self):
vocabulary_file = self._createVocabFile("invalid_hasher.txt")
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(TypeError):
lookup.index_table_from_file(
vocabulary_file=vocabulary_file,
@@ -1499,21 +1499,21 @@ class IndexTableFromFile(test.TestCase):
class KeyValueTensorInitializerTest(test.TestCase):
def test_string(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
init = lookup.KeyValueTensorInitializer(
("brain", "salad", "surgery"), (0, 1, 2), dtypes.string, dtypes.int64)
table = lookup.HashTable(init, default_value=-1)
table.init.run()
def test_int64(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
init = lookup.KeyValueTensorInitializer(
(42, 1, -1000), (0, 1, 2), dtypes.int64, dtypes.int64)
table = lookup.HashTable(init, default_value=-1)
table.init.run()
def test_int32(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
init = lookup.KeyValueTensorInitializer(
(42, 1, -1000), (0, 1, 2), dtypes.int32, dtypes.int64)
table = lookup.HashTable(init, default_value=-1)
@@ -1542,7 +1542,7 @@ class IndexTableFromTensor(test.TestCase):
self.assertAllEqual((1, 2, 3), self.evaluate(ids))
def test_int32_index_table_from_tensor_with_tensor_init(self):
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_tensor(
mapping=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int32)
ids = table.lookup(
@@ -1553,7 +1553,7 @@ class IndexTableFromTensor(test.TestCase):
self.assertAllEqual((1, 2, 3), ids.eval())
def test_int64_index_table_from_tensor_with_tensor_init(self):
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_tensor(
mapping=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int64)
ids = table.lookup(
@@ -1565,7 +1565,7 @@ class IndexTableFromTensor(test.TestCase):
def test_index_table_from_tensor_with_default_value(self):
default_value = -42
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_tensor(
mapping=["brain", "salad", "surgery"], default_value=default_value)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -1575,12 +1575,12 @@ class IndexTableFromTensor(test.TestCase):
self.assertAllEqual((1, 2, default_value), ids.eval())
def test_index_table_from_tensor_missing_mapping(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError, "mapping must be specified"):
lookup.index_table_from_tensor(mapping=None, num_oov_buckets=1)
def test_index_table_from_tensor_empty_mapping(self):
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_tensor(
mapping=np.array([], dtype=np.str_), num_oov_buckets=1)
ids = table.lookup(constant_op.constant(["salad", "surgery", "brain"]))
@@ -1590,7 +1590,7 @@ class IndexTableFromTensor(test.TestCase):
lookup_ops.tables_initializer().run()
def test_index_table_from_tensor_with_invalid_hashers(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(TypeError):
lookup.index_table_from_tensor(
mapping=["brain", "salad", "surgery"],
@@ -1609,7 +1609,7 @@ class IndexTableFromTensor(test.TestCase):
class StringToIndexTest(test.TestCase):
def test_string_to_index(self):
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
feats = constant_op.constant(["salad", "surgery", "tarkus"])
indices = lookup.string_to_index(feats, mapping=mapping_strings)
@@ -1620,7 +1620,7 @@ class StringToIndexTest(test.TestCase):
self.assertAllEqual((1, 2, -1), indices.eval())
def test_duplicate_entries(self):
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["hello", "hello"])
feats = constant_op.constant(["hello", "hola"])
_ = lookup.string_to_index(feats, mapping=mapping_strings)
@@ -1630,7 +1630,7 @@ class StringToIndexTest(test.TestCase):
def test_string_to_index_with_default_value(self):
default_value = -42
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
feats = constant_op.constant(["salad", "surgery", "tarkus"])
indices = lookup.string_to_index(
@@ -1651,7 +1651,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table(self):
vocabulary_file = self._createVocabFile("i2f_vocab1.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_to_string_table_from_file(
vocabulary_file=vocabulary_file)
features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64))
@@ -1663,7 +1663,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table_with_default_value(self):
default_value = b"NONE"
vocabulary_file = self._createVocabFile("f2i_vocab2.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_to_string_table_from_file(
vocabulary_file=vocabulary_file, default_value=default_value)
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
@@ -1675,7 +1675,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table_with_vocab_size_too_small(self):
default_value = b"NONE"
vocabulary_file = self._createVocabFile("f2i_vocab2.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_to_string_table_from_file(
vocabulary_file=vocabulary_file,
vocab_size=2,
@@ -1688,7 +1688,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table_with_vocab_size_too_large(self):
vocabulary_file = self._createVocabFile("f2i_vocab6.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_to_string_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=4)
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
@@ -1700,7 +1700,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table_with_vocab_size(self):
vocabulary_file = self._createVocabFile("f2i_vocab7.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_to_string_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=3)
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
@@ -1713,7 +1713,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
class IndexToStringTableFromTensorTest(test.TestCase):
def test_index_to_string_table_from_tensor(self):
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
table = lookup.index_to_string_table_from_tensor(
mapping=mapping_strings)
@@ -1727,7 +1727,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
features.eval())
def test_duplicate_entries(self):
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["hello", "hello"])
table = lookup.index_to_string_table_from_tensor(
mapping=mapping_strings)
@@ -1738,7 +1738,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
def test_index_to_string_with_default_value(self):
default_value = b"NONE"
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
table = lookup.index_to_string_table_from_tensor(
mapping=mapping_strings, default_value=default_value)
@@ -1754,7 +1754,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
class IndexToStringTest(test.TestCase):
def test_index_to_string(self):
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
indices = constant_op.constant([0, 1, 2, 3], dtypes.int64)
feats = lookup.index_to_string(indices, mapping=mapping_strings)
@@ -1766,7 +1766,7 @@ class IndexToStringTest(test.TestCase):
feats.eval())
def test_duplicate_entries(self):
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["hello", "hello"])
indices = constant_op.constant([0, 1, 4], dtypes.int64)
feats = lookup.index_to_string(indices, mapping=mapping_strings)
@@ -1778,7 +1778,7 @@ class IndexToStringTest(test.TestCase):
def test_index_to_string_with_default_value(self):
default_value = b"NONE"
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
indices = constant_op.constant([1, 2, 4], dtypes.int64)
feats = lookup.index_to_string(
@@ -1818,7 +1818,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
vocabulary_file = self._createVocabFile(
"one_column_int64.txt", values=("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
default_value = -1
table = lookup.HashTable(
lookup.TextFileInitializer(vocabulary_file, dtypes.int64,
@@ -1837,7 +1837,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInitializeIndexTable(self):
vocabulary_file = self._createVocabFile("one_column_2.txt")
- with self.test_session():
+ with self.cached_session():
default_value = "UNK"
key_index = lookup.TextFileIndex.LINE_NUMBER
value_index = lookup.TextFileIndex.WHOLE_LINE
@@ -1858,7 +1858,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
with open(vocabulary_file, "w") as f:
f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", "2\tsurgery\t6"]) + "\n")
- with self.test_session():
+ with self.cached_session():
default_value = -1
key_index = 1
value_index = 2
@@ -1880,7 +1880,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
with open(vocabulary_file, "w") as f:
f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", "2\tsurgery\t6"]) + "\n")
- with self.test_session():
+ with self.cached_session():
default_value = -1
key_index = 2
value_index = 1
@@ -1894,7 +1894,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInvalidDataType(self):
vocabulary_file = self._createVocabFile("one_column_3.txt")
- with self.test_session():
+ with self.cached_session():
default_value = "UNK"
key_index = lookup.TextFileIndex.WHOLE_LINE
value_index = lookup.TextFileIndex.LINE_NUMBER
@@ -1907,7 +1907,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInvalidIndex(self):
vocabulary_file = self._createVocabFile("one_column_4.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
key_index = 1 # second column of the line
value_index = lookup.TextFileIndex.LINE_NUMBER
@@ -1922,7 +1922,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInitializeSameTableWithMultipleNodes(self):
vocabulary_file = self._createVocabFile("one_column_5.txt")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shared_name = "shared-one-columm"
default_value = -1
table1 = lookup.HashTable(
@@ -1961,7 +1961,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], out3)
def testInitializeTableWithNoFilename(self):
- with self.test_session():
+ with self.cached_session():
default_value = -1
with self.assertRaises(ValueError):
lookup.HashTable(
@@ -1971,7 +1971,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
default_value)
def testInitializeWithVocabSize(self):
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
vocabulary_file1 = self._createVocabFile("one_column6.txt")
@@ -2022,7 +2022,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testFeedVocabularyName(self):
vocabulary_file = self._createVocabFile("feed_vocabulary.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
table = lookup.HashTable(
lookup.TextFileInitializer("old_file.txt", dtypes.string,
@@ -2049,7 +2049,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInvalidFilenames(self):
vocabulary_file = self._createVocabFile("filename_shape.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
# Invalid data type
@@ -2072,7 +2072,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testIdToStringTable(self):
vocab_file = self._createVocabFile("feat_to_id_1.txt")
- with self.test_session():
+ with self.cached_session():
default_value = "UNK"
vocab_size = 3
table = lookup.HashTable(
@@ -2090,7 +2090,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testStringToIdTable(self):
vocab_file = self._createVocabFile("feat_to_id_2.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
table = lookup.HashTable(
@@ -2108,7 +2108,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInt64ToIdTable(self):
vocab_file = self._createVocabFile(
"feat_to_id_3.txt", values=("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
table = lookup.HashTable(
@@ -2133,7 +2133,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testStringIdTableWithHashBuckets(self):
vocab_file = self._createVocabFile("feat_to_id_1.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -2154,7 +2154,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testInt32IdTableWithHashBuckets(self):
vocab_file = self._createVocabFile("feat_to_id_2.txt", ("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -2176,7 +2176,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testInt64IdTableWithHashBuckets(self):
vocab_file = self._createVocabFile("feat_to_id_3.txt", ("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -2196,7 +2196,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
self.assertEquals(vocab_size + oov_buckets, table.size().eval())
def testStringIdTableWithOnlyHashBucket(self):
- with self.test_session():
+ with self.cached_session():
oov_buckets = 5
# Set a table that only uses hash buckets, for each input value returns
@@ -2217,7 +2217,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
self.assertEquals(oov_buckets, table.size().eval())
def testInt32IdTableWithOnlyHashBucket(self):
- with self.test_session():
+ with self.cached_session():
oov_buckets = 5
# Set a table that only uses hash buckets, for each input value returns
@@ -2239,20 +2239,20 @@ class IdTableWithHashBucketsTest(test.TestCase):
self.assertEquals(oov_buckets, table.size().eval())
def testFloat64IdTableWithOnlyHashBucket(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, "Invalid key_dtype"):
lookup.IdTableWithHashBuckets(
None, num_oov_buckets=5, key_dtype=dtypes.float64)
def testBoolIdTableWithOnlyHashBucket(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, "Invalid key_dtype"):
lookup.IdTableWithHashBuckets(
None, num_oov_buckets=5, key_dtype=dtypes.bool)
def testIdTableWithHashBucketsWithMultipleInitializers(self):
vocab_file = self._createVocabFile("feat_to_id_4.txt")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_value = -1
vocab_size = 3
oov_buckets = 3
@@ -2294,7 +2294,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testIdTableWithHashBucketsInitializationAcrossSessions(self):
vocab_file = self._createVocabFile("feat_to_id_5.txt")
shared_name = "across-sessions"
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -2316,7 +2316,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
self.assertAllEqual([0, 1, 2, 3], out1.eval())
self.assertEquals(vocab_size + oov_buckets, table1.size().eval())
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -2340,7 +2340,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testIdTableWithHashBucketsWithMultipleInitializersDifferentDefault(self):
vocab_file = self._createVocabFile("feat_to_id_6.txt")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_value1 = -1
vocab_size = 3
oov_buckets = 0
@@ -2378,7 +2378,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
vocab_file = self._createVocabFile("feat_to_id_7.txt")
input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
input_shape = [4, 4]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sp_features = sparse_tensor.SparseTensor(
constant_op.constant(input_indices, dtypes.int64),
constant_op.constant(["brain", "salad", "brain", "surgery", "tarkus"],
@@ -2407,7 +2407,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testInt32SparseTensor(self):
input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
input_shape = [4, 4]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sp_features = sparse_tensor.SparseTensor(
constant_op.constant(input_indices, dtypes.int64),
constant_op.constant([42, 1, 42, -1000, 11], dtypes.int32),
@@ -2436,7 +2436,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testInt64SparseTensor(self):
input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
input_shape = [4, 4]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sp_features = sparse_tensor.SparseTensor(
constant_op.constant(input_indices, dtypes.int64),
constant_op.constant([42, 1, 42, -1000, 11], dtypes.int64),
@@ -2464,7 +2464,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testIdTableWithHashBucketsWithInvalidHashers(self):
vocab_file = self._createVocabFile("feat_to_id_4.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
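The lookup-table hunks end here; each change is the same swap from self.test_session() to self.cached_session() inside graph-mode tests. A minimal sketch of the pattern for the initialized HashTable used above, assuming the TF 1.x tf.test.TestCase API and the contrib lookup module (names are illustrative, not part of the patch):

from tensorflow.contrib import lookup
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import test


class CachedSessionLookupSketch(test.TestCase):

  def testLookupUnderCachedSession(self):
    # The table is built in the default graph and initialized and evaluated
    # inside the per-test cached session, mirroring the hunks above.
    with self.cached_session():
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([0, 1, 2], dtypes.int64)
      table = lookup.HashTable(
          lookup.KeyValueTensorInitializer(keys, values), default_value=-1)
      table.init.run()
      output = table.lookup(constant_op.constant(["salad", "tank"]))
      # "salad" maps to 1; "tank" is out of vocabulary and gets the default.
      self.assertAllEqual([1, -1], output.eval())


if __name__ == "__main__":
  test.main()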
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops_test.py b/tensorflow/contrib/losses/python/losses/loss_ops_test.py
index 2a442a8fc8..c0aec09778 100644
--- a/tensorflow/contrib/losses/python/losses/loss_ops_test.py
+++ b/tensorflow/contrib/losses/python/losses/loss_ops_test.py
@@ -43,68 +43,68 @@ class AbsoluteDifferenceLossTest(test.TestCase):
self._labels = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
def testValueErrorThrownWhenWeightIsNone(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.absolute_difference(
self._predictions, self._predictions, weights=None)
def testAllCorrectNoLossWeight(self):
loss = loss_ops.absolute_difference(self._predictions, self._predictions)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
def testNonZeroLoss(self):
loss = loss_ops.absolute_difference(self._predictions, self._labels)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(5.5, loss.eval(), 3)
def testNonZeroLossWithPythonScalarWeight(self):
weights = 2.3
loss = loss_ops.absolute_difference(self._predictions, self._labels,
weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(5.5 * weights, loss.eval(), 3)
def testNonZeroLossWithScalarTensorWeight(self):
weights = 2.3
loss = loss_ops.absolute_difference(self._predictions, self._labels,
constant_op.constant(weights))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(5.5 * weights, loss.eval(), 3)
def testNonZeroLossWithOneDimBatchSpecificWeights(self):
weights = constant_op.constant([1.2, 0.0], shape=[2,])
loss = loss_ops.absolute_difference(self._predictions, self._labels,
weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(5.6, loss.eval(), 3)
def testNonZeroLossWithTwoDimBatchSpecificWeights(self):
weights = constant_op.constant([1.2, 0.0], shape=[2, 1])
loss = loss_ops.absolute_difference(self._predictions, self._labels,
weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(5.6, loss.eval(), 3)
def testNonZeroLossWithSampleSpecificWeights(self):
weights = constant_op.constant([3, 6, 5, 0, 4, 2], shape=[2, 3])
loss = loss_ops.absolute_difference(self._predictions, self._labels,
weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(16.6, loss.eval(), 3)
def testNonZeroLossWithSampleSpecificWeightsMostZero(self):
weights = constant_op.constant([0, 0, 0, 0, 0, 2], shape=[2, 3])
loss = loss_ops.absolute_difference(self._predictions, self._labels,
weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(6.0, loss.eval(), 3)
def testLossWithSampleSpecificWeightsAllZero(self):
weights = array_ops.zeros((2, 3))
loss = loss_ops.absolute_difference(self._predictions, self._labels,
weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
@@ -117,12 +117,12 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
labels = constant_op.constant([[1, 0, 0],
[0, 1, 0],
[0, 0, 1]])
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.softmax_cross_entropy(logits, labels, weights=None)
def testAllCorrect(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[10.0, 0.0, 0.0],
[0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
@@ -141,7 +141,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[1, 0, 0],
[0, 1, 0]])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.softmax_cross_entropy(logits, labels)
self.assertEquals(loss.op.name, 'softmax_cross_entropy_loss/value')
self.assertAlmostEqual(loss.eval(), 10.0, 3)
@@ -154,7 +154,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[1, 0, 0],
[0, 1, 0]])
weights = 2.3
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
@@ -166,7 +166,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[1, 0, 0],
[0, 1, 0]])
weights = 2.3
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.softmax_cross_entropy(logits, labels,
constant_op.constant(weights))
self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
@@ -179,7 +179,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[1, 0, 0],
[0, 1, 0]])
weights = constant_op.constant([1.2, 3.4, 5.6], shape=[3])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3)
@@ -191,7 +191,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[1, 0, 0],
[0, 1, 0]])
weights = constant_op.constant([0, 0, 0], shape=[3])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual(0.0, loss.eval(), 3)
@@ -203,12 +203,12 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[1, 0, 0],
[0, 1, 0]])
weights = constant_op.constant([1.2, 0, 0], shape=[3])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual(12.0, loss.eval(), 3)
def testSoftmaxWithMeasurementSpecificWeightsRaisesException(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -223,7 +223,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
loss_ops.softmax_cross_entropy(logits, labels, weights=weights).eval()
def testSoftmaxLabelSmoothing(self):
- with self.test_session():
+ with self.cached_session():
# Softmax Cross Entropy Loss is:
# -\sum_i p_i \log q_i
# where for a softmax activation
@@ -253,7 +253,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
weights = [2.3, 2.4, 2.5]
weights_placeholder = array_ops.placeholder(dtypes.float32, shape=[None])
loss = loss_ops.softmax_cross_entropy(logits, labels, weights_placeholder)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, {weights_placeholder: weights})
self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
@@ -268,7 +268,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
weights_placeholder = array_ops.placeholder(
dtypes.float32, shape=[None, None])
loss = loss_ops.softmax_cross_entropy(logits, labels, weights_placeholder)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, {weights_placeholder: weights})
self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
@@ -280,12 +280,12 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[0], [1], [2]])
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.sparse_softmax_cross_entropy(logits, labels, weights=None)
def testAllCorrectInt32Labels(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[10.0, 0.0, 0.0],
[0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
@@ -295,7 +295,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
self.assertAlmostEqual(loss.eval(), 0.0, 3)
def testAllCorrectInt64Labels(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[10.0, 0.0, 0.0],
[0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
@@ -305,7 +305,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
self.assertAlmostEqual(loss.eval(), 0.0, 3)
def testAllCorrectNonColumnLabels(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[10.0, 0.0, 0.0],
[0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
@@ -320,7 +320,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]], dtype=dtypes.int32)
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(logits, labels)
self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
self.assertAlmostEqual(loss.eval(), 10.0, 3)
@@ -331,7 +331,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]], dtype=dtypes.int64)
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(logits, labels)
self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
self.assertAlmostEqual(loss.eval(), 10.0, 3)
@@ -342,7 +342,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([2, 0, 1])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(logits, labels)
self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
self.assertAlmostEqual(loss.eval(), 10.0, 3)
@@ -353,7 +353,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = 2.3
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
@@ -363,7 +363,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = 2.3
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(
logits, labels, constant_op.constant(weights))
self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
@@ -374,7 +374,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = constant_op.constant([1.2, 3.4, 5.6], shape=[3])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3)
@@ -384,7 +384,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = constant_op.constant([[1.2], [3.4], [5.6]])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3)
@@ -394,7 +394,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = constant_op.constant([0, 0, 0], shape=[3])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual(0.0, loss.eval(), 3)
@@ -404,12 +404,12 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = constant_op.constant([1.2, 0, 0], shape=[3])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual(12.0, loss.eval(), 3)
def testMeasurementSpecificWeightsRaisesException(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -422,7 +422,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
def testInconsistentWeightSizeRaisesException(self):
"""The weight tensor has incorrect number of elements."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -435,7 +435,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
def testInconsistentLabelSizeRaisesException(self):
"""The label tensor has incorrect number of elements."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -448,7 +448,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
def testInconsistentWeightShapeRaisesException(self):
"""The weight tensor has incorrect shape."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0, -100.0],
[-100.0, -100.0, 100.0, -100.0],
@@ -462,7 +462,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
def testInconsistentLabelShapeRaisesException(self):
"""The label tensor has incorrect shape."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0, -100.0],
[-100.0, -100.0, 100.0, -100.0],
@@ -484,7 +484,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
dtypes.float32, shape=[None])
loss = loss_ops.sparse_softmax_cross_entropy(
logits, labels, weights_placeholder)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, {weights_placeholder: weights})
self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
@@ -498,7 +498,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
dtypes.float32, shape=[None, None])
loss = loss_ops.sparse_softmax_cross_entropy(
logits, labels, weights_placeholder)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, {weights_placeholder: weights})
self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
@@ -506,7 +506,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
class SigmoidCrossEntropyLossTest(test.TestCase):
def testAllCorrectSigmoid(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -522,7 +522,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
loss = loss_ops.sigmoid_cross_entropy(logits, labels, weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss,
feed_dict={
logits: np.ones((32, 1)),
@@ -537,7 +537,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
loss = loss_ops.sigmoid_cross_entropy(logits, labels, weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss,
feed_dict={
logits: np.ones((32, 2)),
@@ -546,7 +546,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
self.assertAlmostEqual(0.313, loss, 3)
def testAllWrongSigmoid(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -558,7 +558,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
self.assertAlmostEqual(loss.eval(), 600.0 / 9.0, 3)
def testAllWrongSigmoidWithMeasurementSpecificWeights(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -582,11 +582,11 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
loss = loss_ops.sigmoid_cross_entropy(logits, labels)
self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value')
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(loss.eval(), 0.0, 3)
def testSigmoidLabelSmoothingCorrect(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0]])
labels = constant_op.constant([[1, 0, 1]])
# Sigmoid cross entropy loss is:
@@ -608,7 +608,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
self.assertAlmostEqual(loss.eval(), expected_value, 3)
def testSigmoidLabelSmoothingEqualsSoftmaxTwoLabel(self):
- with self.test_session():
+ with self.cached_session():
label_smoothing = 0.1
sigmoid_logits = constant_op.constant([[100.0, -100.0, -100.0]])
sigmoid_labels = constant_op.constant([[1, 0, 1]])
@@ -641,33 +641,33 @@ class LogLossTest(test.TestCase):
self._labels = constant_op.constant(labels)
def testValueErrorThrownWhenWeightIsNone(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.log_loss(self._labels, self._labels, weights=None)
def testAllCorrectNoLossWeight(self):
loss = loss_ops.log_loss(self._labels, self._labels)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
def testAllCorrectNoLossWeightWithPlaceholder(self):
tf_predictions = array_ops.placeholder(
dtypes.float32, shape=self._np_labels.shape)
loss = loss_ops.log_loss(tf_predictions, self._labels)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(
0.0, loss.eval(feed_dict={tf_predictions: self._np_labels}), 3)
def testNonZeroLoss(self):
loss = loss_ops.log_loss(self._predictions, self._labels)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(self._expected_losses) / 6.0,
loss.eval(), 3)
def testNonZeroLossWithPythonScalarWeight(self):
weights = 2.3
loss = loss_ops.log_loss(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
loss.eval(), 3)
@@ -675,7 +675,7 @@ class LogLossTest(test.TestCase):
weights = 2.3
loss = loss_ops.log_loss(self._predictions, self._labels,
constant_op.constant(weights))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
loss.eval(), 3)
@@ -685,7 +685,7 @@ class LogLossTest(test.TestCase):
weights = 2.3
loss = loss_ops.log_loss(tf_predictions, self._labels,
constant_op.constant(weights))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
loss, 3)
@@ -695,7 +695,7 @@ class LogLossTest(test.TestCase):
weights = 2.3
loss = loss_ops.log_loss(tf_predictions, self._labels,
constant_op.constant(weights))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
loss, 3)
@@ -706,7 +706,7 @@ class LogLossTest(test.TestCase):
self._expected_losses,
np.asarray([1.2, 1.2, 1.2, 3.4, 3.4, 3.4]).reshape((2, 3)))
loss = loss_ops.log_loss(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(expected_losses) / 6.0, loss.eval(), 3)
def testNonZeroLossWithOneDimBatchSpecificWeightsSomeZero(self):
@@ -715,7 +715,7 @@ class LogLossTest(test.TestCase):
np.asarray([1.2, 1.2, 1.2, 0, 0, 0]).reshape(
(2, 3)))
loss = loss_ops.log_loss(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(expected_losses) / 3.0, loss.eval(), 3)
def testNonZeroLossWithTwoDimBatchSpecificWeightsSomeZero(self):
@@ -724,12 +724,12 @@ class LogLossTest(test.TestCase):
np.asarray([1.2, 1.2, 1.2, 0, 0, 0]).reshape(
(2, 3)))
loss = loss_ops.log_loss(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(expected_losses) / 3.0, loss.eval(), 3)
def testWeightsWithSameNumDimsButWrongShapeThrowsException(self):
weights = constant_op.constant(np.random.normal(size=(2, 4)), shape=[2, 4])
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.log_loss(self._predictions, self._labels, weights)
@@ -742,7 +742,7 @@ class LogLossTest(test.TestCase):
self._labels,
constant_op.constant(
weights, shape=(2, 3)))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(expected_losses) / 5.0, loss.eval(), 3)
def testNonZeroLossWithMeasurementSpecificWeightsWithPlaceholder(self):
@@ -756,7 +756,7 @@ class LogLossTest(test.TestCase):
constant_op.constant(
weights, shape=(2, 3)))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
self.assertAlmostEqual(-np.sum(expected_losses) / 5.0, loss, 3)
@@ -769,7 +769,7 @@ class LogLossTest(test.TestCase):
self._labels,
constant_op.constant(
weights, shape=(2, 3)))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(expected_losses), loss.eval(), 3)
def testNonZeroLossWithSampleSpecificWeightsMostZeroWithPlaceholder(self):
@@ -780,35 +780,35 @@ class LogLossTest(test.TestCase):
tf_weights = constant_op.constant(weights, shape=(2, 3))
loss = loss_ops.log_loss(tf_predictions, self._labels, tf_weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
self.assertAlmostEqual(-np.sum(expected_losses), loss, 3)
def testLossWithSampleSpecificWeightsAllZero(self):
tf_weights = array_ops.zeros(shape=(2, 3))
loss = loss_ops.log_loss(self._predictions, self._labels, tf_weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
class HingeLossTest(test.TestCase):
def testIncompatibleShapes(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[-1.0], [2.1]])
labels = constant_op.constant([0.0, 1.0])
with self.assertRaises(ValueError):
_ = loss_ops.hinge_loss(logits, labels).eval()
def testAllOutsideMargin(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([1.2, -1.4, -1.0, 2.1])
labels = constant_op.constant([1.0, 0.0, 0.0, 1.0])
loss = loss_ops.hinge_loss(logits, labels)
self.assertAllClose(loss.eval(), [0.0, 0.0, 0.0, 0.0], atol=1e-3)
def testSomeInsideMargin(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[-0.7], [-1.4], [1.4], [0.6]])
labels = constant_op.constant([[0.0], [0.0], [1.0], [1.0]])
loss = loss_ops.hinge_loss(logits, labels)
@@ -817,7 +817,7 @@ class HingeLossTest(test.TestCase):
self.assertAllClose(loss.eval(), [[0.3], [0.0], [0.0], [0.4]], atol=1e-3)
def testSomeMisclassified(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[[1.2], [0.4], [-1.0], [-1.1]]])
labels = constant_op.constant([[[1.0], [0.0], [0.0], [1.0]]])
loss = loss_ops.hinge_loss(logits, labels)
@@ -834,62 +834,62 @@ class MeanSquaredErrorTest(test.TestCase):
self._labels = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
def testValueErrorThrownWhenWeightIsNone(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.mean_squared_error(
self._predictions, self._predictions, weights=None)
def testAllCorrectNoLossWeight(self):
loss = loss_ops.mean_squared_error(self._predictions, self._predictions)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
def testNonZeroLoss(self):
loss = loss_ops.mean_squared_error(self._predictions, self._labels)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(49.5, loss.eval(), 3)
def testNonZeroLossWithPythonScalarWeight(self):
weights = 2.3
loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(49.5 * weights, loss.eval(), 3)
def testNonZeroLossWithScalarTensorWeight(self):
weights = 2.3
loss = loss_ops.mean_squared_error(self._predictions, self._labels,
constant_op.constant(weights))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(49.5 * weights, loss.eval(), 3)
def testNonZeroLossWithOneDimBatchSpecificWeights(self):
weights = constant_op.constant([1.2, 3.4], shape=[2,])
loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(767.8 / 6.0, loss.eval(), 3)
def testNonZeroLossWithTwoDimBatchSpecificWeights(self):
weights = constant_op.constant([1.2, 3.4], shape=[2, 1])
loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(767.8 / 6.0, loss.eval(), 3)
def testNonZeroLossWithSampleSpecificWeights(self):
weights = constant_op.constant([3, 6, 5, 0, 4, 2], shape=[2, 3])
loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(587 / 5.0, loss.eval(), 3)
def testNonZeroLossWithSampleSpecificWeightsMostZero(self):
weights = constant_op.constant([0, 0, 0, 0, 0, 2], shape=[2, 3])
loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(18.0, loss.eval(), 3)
def testLossWithSampleSpecificWeightsAllZero(self):
weights = array_ops.zeros((2, 3))
loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
@@ -914,7 +914,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
self._expected_losses = np.divide(total, 9.0)
def testValueErrorThrownWhenWeightIsNone(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.mean_pairwise_squared_error(
predictions=constant_op.constant(self._labels),
@@ -925,14 +925,14 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
loss = loss_ops.mean_pairwise_squared_error(
predictions=constant_op.constant(self._labels),
labels=constant_op.constant(self._labels))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
def testNonZeroLoss(self):
loss = loss_ops.mean_pairwise_squared_error(
predictions=constant_op.constant(self._predictions),
labels=constant_op.constant(self._labels))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(np.sum(self._expected_losses), loss.eval(), 3)
def testGradientWithZeroWeight(self):
@@ -954,7 +954,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
init_op = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for grad, _ in gradients_to_variables:
np_grad = sess.run(grad)
@@ -966,7 +966,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
predictions=constant_op.constant(self._predictions),
labels=constant_op.constant(self._labels),
weights=weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(weights * np.sum(self._expected_losses),
loss.eval(), 3)
@@ -976,7 +976,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
predictions=constant_op.constant(self._predictions),
labels=constant_op.constant(self._labels),
weights=constant_op.constant(weights))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(weights * np.sum(self._expected_losses),
loss.eval(), 3)
@@ -986,7 +986,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
predictions=constant_op.constant(self._predictions),
labels=constant_op.constant(self._labels),
weights=constant_op.constant(weights))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0, loss.eval(), 3)
def testNonZeroLossWithScalarTensorWeightWithPlaceholder(self):
@@ -998,7 +998,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
predictions=tf_predictions,
labels=tf_labels,
weights=constant_op.constant(weights))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss,
feed_dict={
tf_predictions: self._predictions,
@@ -1015,7 +1015,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
labels=constant_op.constant(self._labels),
weights=constant_op.constant(
weights, shape=[2]))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(np.sum(expected_losses), loss.eval(), 3)
def testZeroLossWithOneDimBatchZeroWeights(self):
@@ -1025,7 +1025,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
labels=constant_op.constant(self._labels),
weights=constant_op.constant(
weights, shape=[2]))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0, loss.eval(), 3)
def testNonZeroLossWithOneDimBatchSpecificWeightsAndPlaceholders(self):
@@ -1041,7 +1041,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
weights=constant_op.constant(
weights, shape=[2]))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss,
feed_dict={
tf_predictions: self._predictions,
@@ -1056,7 +1056,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
labels=constant_op.constant(self._labels),
weights=constant_op.constant(
weights, shape=[2]))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
def testLossIsAssociativeAcrossBatchElements(self):
@@ -1087,7 +1087,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
predictions=array_ops.concat([predictions0, predictions1], 0),
labels=array_ops.concat([labels0, labels1], 0))
- with self.test_session() as session:
+ with self.cached_session() as session:
loss0, loss1, loss0_1 = session.run([loss0, loss1, loss0_1])
self.assertTrue(loss0 > 0)
@@ -1115,7 +1115,7 @@ class CosineDistanceLossTest(test.TestCase):
[0, 1, 0]]).reshape((3, 2, 3))
def testValueErrorThrownWhenWeightIsNone(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.cosine_distance(
predictions=constant_op.constant(self._labels),
@@ -1128,7 +1128,7 @@ class CosineDistanceLossTest(test.TestCase):
predictions=constant_op.constant(self._labels),
labels=constant_op.constant(self._labels),
dim=2)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0, loss.eval(), 5)
def testPartiallyCorrectWithIntegerValues(self):
@@ -1136,7 +1136,7 @@ class CosineDistanceLossTest(test.TestCase):
predictions=constant_op.constant(self._predictions),
labels=constant_op.constant(self._labels),
dim=2)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(1, loss.eval(), 5)
def testPartiallyCorrectFloatingPointValues(self):
@@ -1154,7 +1154,7 @@ class CosineDistanceLossTest(test.TestCase):
labels, shape=(3, 1, 3), dtype=dtypes.float32)
loss = loss_ops.cosine_distance(tf_preds, tf_labels, dim=2)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(1.0, loss.eval(), 5)
def testSampleSpecificWeights(self):
@@ -1163,7 +1163,7 @@ class CosineDistanceLossTest(test.TestCase):
labels=constant_op.constant(self._labels),
dim=2,
weights=constant_op.constant([1, 0, 0]))
- with self.test_session():
+ with self.cached_session():
self.assertEqual(1.0, loss.eval())
def testMeasurementSpecificWeights(self):
@@ -1173,12 +1173,12 @@ class CosineDistanceLossTest(test.TestCase):
dim=2,
weights=constant_op.constant(
[1, 0, 0, 1, 1, 1], shape=(3, 2)))
- with self.test_session():
+ with self.cached_session():
self.assertEqual(3.0 / 4.0, loss.eval())
def testValueErrorThrownWithShapelessPlaceholder(self):
tf_predictions = array_ops.placeholder(dtypes.float32)
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.cosine_distance(
predictions=tf_predictions,
@@ -1196,7 +1196,7 @@ class CosineDistanceLossTest(test.TestCase):
dim=2,
weights=constant_op.constant(
[1, 0, 0, 1, 1, 1], shape=(3, 2)))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, feed_dict={tf_predictions: self._predictions})
self.assertEqual(3.0 / 4.0, loss)
@@ -1206,7 +1206,7 @@ class CosineDistanceLossTest(test.TestCase):
labels=constant_op.constant(self._labels),
dim=2,
weights=array_ops.zeros((3,)))
- with self.test_session():
+ with self.cached_session():
self.assertEqual(0, loss.eval())
def testZeroLossWhenAllMeasurementSpecificWeightsAreZero(self):
@@ -1215,7 +1215,7 @@ class CosineDistanceLossTest(test.TestCase):
labels=constant_op.constant(self._labels),
dim=2,
weights=array_ops.zeros((3, 2)))
- with self.test_session():
+ with self.cached_session():
self.assertEqual(0, loss.eval())
@@ -1228,7 +1228,7 @@ class ComputeWeightedLossTest(test.TestCase):
self.assertFalse(loss_ops.get_losses())
loss = loss_ops.compute_weighted_loss(losses)
self.assertTrue(loss_ops.get_losses())
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(losses.eval(), [0.0, 1.4, 0.0, 2.1], atol=1e-3)
self.assertAllClose(loss.eval(), 3.5 / 4.0, atol=1e-3)
@@ -1243,7 +1243,7 @@ class AddLossTest(test.TestCase):
loss_ops.add_loss(math_ops.reduce_mean(losses))
self.assertTrue(loss_ops.get_losses())
total_loss = loss_ops.get_total_loss()
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(losses.eval(), [[0.0, 1.4, 0.0, 2.1]], atol=1e-3)
self.assertAllClose(total_loss.eval(), 3.5 / 4.0, atol=1e-3)
@@ -1254,7 +1254,7 @@ class AddLossTest(test.TestCase):
self.assertFalse(loss_ops.get_losses())
loss_ops.add_loss(math_ops.reduce_mean(losses), loss_collection=None)
self.assertFalse(loss_ops.get_losses())
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(losses.eval(), [[0.0, 1.4, 0.0, 2.1]], atol=1e-3)
def testNoCollectLosses(self):
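The hunks above uniformly replace self.test_session() with self.cached_session() in the contrib losses tests. As a minimal, self-contained sketch of the new pattern (the LossSketchTest class and its values are illustrative, not part of this commit), the same style of assertion looks like:

from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test
from tensorflow.contrib.losses.python.losses import loss_ops


class LossSketchTest(test.TestCase):

  def testZeroLossWhenLabelsMatchPredictions(self):
    predictions = constant_op.constant([[0.0, 1.0, 2.0]])
    labels = constant_op.constant([[0.0, 1.0, 2.0]])
    loss = loss_ops.mean_squared_error(predictions, labels)
    # cached_session() reuses a single session for the test's graph,
    # where test_session() would construct a fresh one on each call.
    with self.cached_session():
      self.assertAlmostEqual(0.0, loss.eval(), 3)


if __name__ == '__main__':
  test.main()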
diff --git a/tensorflow/contrib/makefile/proto_text_cc_files.txt b/tensorflow/contrib/makefile/proto_text_cc_files.txt
index 7d26429f9c..9ea94c7433 100644
--- a/tensorflow/contrib/makefile/proto_text_cc_files.txt
+++ b/tensorflow/contrib/makefile/proto_text_cc_files.txt
@@ -1,62 +1,61 @@
-tensorflow/tools/proto_text/gen_proto_text_functions_lib.cc
-tensorflow/tools/proto_text/gen_proto_text_functions.cc
tensorflow/core/framework/resource_handle.cc
+tensorflow/core/lib/core/arena.cc
+tensorflow/core/lib/core/coding.cc
+tensorflow/core/lib/core/status.cc
+tensorflow/core/lib/core/threadpool.cc
+tensorflow/core/lib/hash/crc32c.cc
+tensorflow/core/lib/hash/crc32c_accelerate.cc
+tensorflow/core/lib/hash/hash.cc
+tensorflow/core/lib/histogram/histogram.cc
+tensorflow/core/lib/io/block.cc
+tensorflow/core/lib/io/block_builder.cc
+tensorflow/core/lib/io/buffered_inputstream.cc
+tensorflow/core/lib/io/compression.cc
+tensorflow/core/lib/io/format.cc
+tensorflow/core/lib/io/inputbuffer.cc
+tensorflow/core/lib/io/inputstream_interface.cc
+tensorflow/core/lib/io/iterator.cc
+tensorflow/core/lib/io/path.cc
+tensorflow/core/lib/io/random_inputstream.cc
+tensorflow/core/lib/io/record_reader.cc
+tensorflow/core/lib/io/record_writer.cc
+tensorflow/core/lib/io/table.cc
+tensorflow/core/lib/io/table_builder.cc
+tensorflow/core/lib/io/two_level_iterator.cc
+tensorflow/core/lib/io/zlib_compression_options.cc
+tensorflow/core/lib/io/zlib_inputstream.cc
+tensorflow/core/lib/io/zlib_outputbuffer.cc
+tensorflow/core/lib/random/distribution_sampler.cc
+tensorflow/core/lib/random/random.cc
+tensorflow/core/lib/random/simple_philox.cc
+tensorflow/core/lib/random/weighted_picker.cc
+tensorflow/core/lib/strings/numbers.cc
+tensorflow/core/lib/strings/ordered_code.cc
+tensorflow/core/lib/strings/proto_text_util.cc
+tensorflow/core/lib/strings/scanner.cc
+tensorflow/core/lib/strings/str_util.cc
+tensorflow/core/lib/strings/strcat.cc
+tensorflow/core/lib/strings/stringprintf.cc
+tensorflow/core/lib/wav/wav_io.cc
+tensorflow/core/platform/cpu_info.cc
+tensorflow/core/platform/default/logging.cc
+tensorflow/core/platform/default/mutex.cc
tensorflow/core/platform/default/protobuf.cc
-tensorflow/core/platform/tracing.cc
-tensorflow/core/platform/tensor_coding.cc
-tensorflow/core/platform/protobuf_util.cc
-tensorflow/core/platform/posix/posix_file_system.cc
-tensorflow/core/platform/posix/port.cc
-tensorflow/core/platform/posix/error.cc
-tensorflow/core/platform/posix/env.cc
-tensorflow/core/platform/posix/load_library.cc
-tensorflow/core/platform/posix/env_time.cc
-tensorflow/core/platform/file_system.cc
-tensorflow/core/platform/file_system_helper.cc
+tensorflow/core/platform/default/tracing.cc
+tensorflow/core/platform/denormal.cc
tensorflow/core/platform/env.cc
tensorflow/core/platform/env_time.cc
+tensorflow/core/platform/file_system.cc
+tensorflow/core/platform/file_system_helper.cc
+tensorflow/core/platform/posix/env.cc
+tensorflow/core/platform/posix/env_time.cc
+tensorflow/core/platform/posix/error.cc
+tensorflow/core/platform/posix/load_library.cc
+tensorflow/core/platform/posix/port.cc
+tensorflow/core/platform/posix/posix_file_system.cc
+tensorflow/core/platform/protobuf_util.cc
tensorflow/core/platform/setround.cc
-tensorflow/core/platform/denormal.cc
-tensorflow/core/platform/default/tracing.cc
-tensorflow/core/platform/default/mutex.cc
-tensorflow/core/platform/default/logging.cc
-tensorflow/core/platform/cpu_info.cc
-tensorflow/core/lib/wav/wav_io.cc
-tensorflow/core/lib/strings/stringprintf.cc
-tensorflow/core/lib/strings/strcat.cc
-tensorflow/core/lib/strings/str_util.cc
-tensorflow/core/lib/strings/scanner.cc
-tensorflow/core/lib/strings/proto_text_util.cc
-tensorflow/core/lib/strings/ordered_code.cc
-tensorflow/core/lib/strings/numbers.cc
-tensorflow/core/lib/random/weighted_picker.cc
-tensorflow/core/lib/random/simple_philox.cc
-tensorflow/core/lib/random/random.cc
-tensorflow/core/lib/random/distribution_sampler.cc
-tensorflow/core/lib/io/zlib_outputbuffer.cc
-tensorflow/core/lib/io/zlib_inputstream.cc
-tensorflow/core/lib/io/zlib_compression_options.cc
-tensorflow/core/lib/io/two_level_iterator.cc
-tensorflow/core/lib/io/table_builder.cc
-tensorflow/core/lib/io/table.cc
-tensorflow/core/lib/io/record_writer.cc
-tensorflow/core/lib/io/record_reader.cc
-tensorflow/core/lib/io/random_inputstream.cc
-tensorflow/core/lib/io/path.cc
-tensorflow/core/lib/io/iterator.cc
-tensorflow/core/lib/io/inputstream_interface.cc
-tensorflow/core/lib/io/inputbuffer.cc
-tensorflow/core/lib/io/format.cc
-tensorflow/core/lib/io/compression.cc
-tensorflow/core/lib/io/buffered_inputstream.cc
-tensorflow/core/lib/io/block_builder.cc
-tensorflow/core/lib/io/block.cc
-tensorflow/core/lib/histogram/histogram.cc
-tensorflow/core/lib/hash/hash.cc
-tensorflow/core/lib/hash/crc32c.cc
-tensorflow/core/lib/hash/crc32c_accelerate.cc
-tensorflow/core/lib/core/threadpool.cc
-tensorflow/core/lib/core/stringpiece.cc
-tensorflow/core/lib/core/status.cc
-tensorflow/core/lib/core/coding.cc
-tensorflow/core/lib/core/arena.cc
+tensorflow/core/platform/tensor_coding.cc
+tensorflow/core/platform/tracing.cc
+tensorflow/tools/proto_text/gen_proto_text_functions.cc
+tensorflow/tools/proto_text/gen_proto_text_functions_lib.cc
diff --git a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
index 938c4a53ab..1d6d9a60e5 100644
--- a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
+++ b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
@@ -1,41 +1,42 @@
-tensorflow/core/util/test_log.pb.cc
-tensorflow/core/util/saved_tensor_slice.pb.cc
-tensorflow/core/util/memmapped_file_system.pb.cc
-tensorflow/core/util/event.pb.cc
-tensorflow/core/protobuf/tensorflow_server.pb.cc
-tensorflow/core/protobuf/saver.pb.cc
-tensorflow/core/protobuf/queue_runner.pb.cc
-tensorflow/core/protobuf/named_tensor.pb.cc
-tensorflow/core/protobuf/meta_graph.pb.cc
+tensorflow/core/example/example.pb.cc
+tensorflow/core/example/feature.pb.cc
+tensorflow/core/framework/allocation_description.pb.cc
+tensorflow/core/framework/api_def.pb.cc
+tensorflow/core/framework/attr_value.pb.cc
+tensorflow/core/framework/cost_graph.pb.cc
+tensorflow/core/framework/device_attributes.pb.cc
+tensorflow/core/framework/function.pb.cc
+tensorflow/core/framework/graph.pb.cc
+tensorflow/core/framework/graph_transfer_info.pb.cc
+tensorflow/core/framework/kernel_def.pb.cc
+tensorflow/core/framework/log_memory.pb.cc
+tensorflow/core/framework/model.pb.cc
+tensorflow/core/framework/node_def.pb.cc
+tensorflow/core/framework/op_def.pb.cc
+tensorflow/core/framework/remote_fused_graph_execute_info.pb.cc
+tensorflow/core/framework/resource_handle.pb.cc
+tensorflow/core/framework/step_stats.pb.cc
+tensorflow/core/framework/summary.pb.cc
+tensorflow/core/framework/tensor.pb.cc
+tensorflow/core/framework/tensor_description.pb.cc
+tensorflow/core/framework/tensor_shape.pb.cc
+tensorflow/core/framework/tensor_slice.pb.cc
+tensorflow/core/framework/types.pb.cc
+tensorflow/core/framework/variable.pb.cc
+tensorflow/core/framework/versions.pb.cc
+tensorflow/core/grappler/costs/op_performance_data.pb.cc
+tensorflow/core/lib/core/error_codes.pb.cc
tensorflow/core/protobuf/cluster.pb.cc
tensorflow/core/protobuf/config.pb.cc
-tensorflow/core/protobuf/rewriter_config.pb.cc
tensorflow/core/protobuf/debug.pb.cc
tensorflow/core/protobuf/device_properties.pb.cc
-tensorflow/core/lib/core/error_codes.pb.cc
-tensorflow/core/framework/versions.pb.cc
-tensorflow/core/framework/variable.pb.cc
-tensorflow/core/framework/types.pb.cc
-tensorflow/core/framework/tensor_slice.pb.cc
-tensorflow/core/framework/tensor_shape.pb.cc
-tensorflow/core/framework/tensor_description.pb.cc
-tensorflow/core/framework/tensor.pb.cc
-tensorflow/core/framework/summary.pb.cc
-tensorflow/core/framework/step_stats.pb.cc
-tensorflow/core/framework/resource_handle.pb.cc
-tensorflow/core/framework/remote_fused_graph_execute_info.pb.cc
-tensorflow/core/framework/api_def.pb.cc
-tensorflow/core/framework/op_def.pb.cc
-tensorflow/core/framework/node_def.pb.cc
-tensorflow/core/framework/log_memory.pb.cc
-tensorflow/core/framework/kernel_def.pb.cc
-tensorflow/core/framework/graph_transfer_info.pb.cc
-tensorflow/core/framework/graph.pb.cc
-tensorflow/core/framework/function.pb.cc
-tensorflow/core/framework/device_attributes.pb.cc
-tensorflow/core/framework/cost_graph.pb.cc
-tensorflow/core/framework/attr_value.pb.cc
-tensorflow/core/framework/allocation_description.pb.cc
-tensorflow/core/example/feature.pb.cc
-tensorflow/core/example/example.pb.cc
-tensorflow/core/grappler/costs/op_performance_data.pb.cc
+tensorflow/core/protobuf/meta_graph.pb.cc
+tensorflow/core/protobuf/named_tensor.pb.cc
+tensorflow/core/protobuf/queue_runner.pb.cc
+tensorflow/core/protobuf/rewriter_config.pb.cc
+tensorflow/core/protobuf/saver.pb.cc
+tensorflow/core/protobuf/tensorflow_server.pb.cc
+tensorflow/core/util/event.pb.cc
+tensorflow/core/util/memmapped_file_system.pb.cc
+tensorflow/core/util/saved_tensor_slice.pb.cc
+tensorflow/core/util/test_log.pb.cc
diff --git a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
index aa91b2f954..884461ecae 100644
--- a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
+++ b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
@@ -1,42 +1,44 @@
-tensorflow/core/util/test_log.pb.h
-tensorflow/core/util/saved_tensor_slice.pb.h
-tensorflow/core/util/memmapped_file_system.pb.h
-tensorflow/core/util/event.pb.h
-tensorflow/core/protobuf/tensorflow_server.pb.h
-tensorflow/core/protobuf/saver.pb.h
-tensorflow/core/protobuf/queue_runner.pb.h
-tensorflow/core/protobuf/named_tensor.pb.h
-tensorflow/core/protobuf/meta_graph.pb.h
+tensorflow/core/example/example.pb.h
+tensorflow/core/example/feature.pb.h
+tensorflow/core/framework/allocation_description.pb.h
+tensorflow/core/framework/api_def.pb.h
+tensorflow/core/framework/attr_value.pb.h
+tensorflow/core/framework/cost_graph.pb.h
+tensorflow/core/framework/device_attributes.pb.h
+tensorflow/core/framework/function.pb.h
+tensorflow/core/framework/graph.pb.h
+tensorflow/core/framework/graph_transfer_info.pb.h
+tensorflow/core/framework/kernel_def.pb.h
+tensorflow/core/framework/log_memory.pb.h
+tensorflow/core/framework/model.pb.h
+tensorflow/core/framework/node_def.pb.h
+tensorflow/core/framework/op_def.pb.h
+tensorflow/core/framework/remote_fused_graph_execute_info.pb.h
+tensorflow/core/framework/resource_handle.pb.h
+tensorflow/core/framework/step_stats.pb.h
+tensorflow/core/framework/summary.pb.h
+tensorflow/core/framework/tensor.pb.h
+tensorflow/core/framework/tensor_description.pb.h
+tensorflow/core/framework/tensor_shape.pb.h
+tensorflow/core/framework/tensor_slice.pb.h
+tensorflow/core/framework/types.pb.h
+tensorflow/core/framework/variable.pb.h
+tensorflow/core/framework/versions.pb.h
+tensorflow/core/grappler/costs/op_performance_data.pb.h
+tensorflow/core/lib/core/error_codes.pb.h
tensorflow/core/protobuf/cluster.pb.h
tensorflow/core/protobuf/config.pb.h
tensorflow/core/protobuf/debug.pb.h
tensorflow/core/protobuf/device_properties.pb.h
+tensorflow/core/protobuf/meta_graph.pb.h
+tensorflow/core/protobuf/named_tensor.pb.h
+tensorflow/core/protobuf/queue_runner.pb.h
tensorflow/core/protobuf/rewriter_config.pb.h
+tensorflow/core/protobuf/saver.pb.h
tensorflow/core/protobuf/tensor_bundle.pb.h
-tensorflow/core/lib/core/error_codes.pb.h
-tensorflow/core/framework/versions.pb.h
-tensorflow/core/framework/variable.pb.h
-tensorflow/core/framework/types.pb.h
-tensorflow/core/framework/tensor_slice.pb.h
-tensorflow/core/framework/tensor_shape.pb.h
-tensorflow/core/framework/tensor_description.pb.h
-tensorflow/core/framework/tensor.pb.h
-tensorflow/core/framework/summary.pb.h
-tensorflow/core/framework/step_stats.pb.h
-tensorflow/core/framework/resource_handle.pb.h
-tensorflow/core/framework/remote_fused_graph_execute_info.pb.h
-tensorflow/core/framework/api_def.pb.h
-tensorflow/core/framework/op_def.pb.h
-tensorflow/core/framework/node_def.pb.h
-tensorflow/core/framework/log_memory.pb.h
-tensorflow/core/framework/kernel_def.pb.h
-tensorflow/core/framework/graph_transfer_info.pb.h
-tensorflow/core/framework/graph.pb.h
-tensorflow/core/framework/function.pb.h
-tensorflow/core/framework/device_attributes.pb.h
-tensorflow/core/framework/cost_graph.pb.h
-tensorflow/core/framework/attr_value.pb.h
-tensorflow/core/framework/allocation_description.pb.h
-tensorflow/core/example/feature.pb.h
-tensorflow/core/example/example.pb.h
-tensorflow/core/grappler/costs/op_performance_data.pb.h
+tensorflow/core/protobuf/tensorflow_server.pb.h
+tensorflow/core/util/event.pb.h
+tensorflow/core/util/memmapped_file_system.pb.h
+tensorflow/core/util/saved_tensor_slice.pb.h
+tensorflow/core/util/test_log.pb.h
+
diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt
index 66a3315700..08de54b8e1 100644
--- a/tensorflow/contrib/makefile/tf_op_files.txt
+++ b/tensorflow/contrib/makefile/tf_op_files.txt
@@ -4,218 +4,19 @@ tensorflow/contrib/boosted_trees/ops/quantile_ops.cc
tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc
tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc
tensorflow/contrib/boosted_trees/ops/training_ops.cc
-tensorflow/core/kernels/xent_op.cc
-tensorflow/core/kernels/where_op.cc
-tensorflow/core/kernels/variable_ops.cc
-tensorflow/core/kernels/unpack_op.cc
-tensorflow/core/kernels/unique_op.cc
-tensorflow/core/kernels/transpose_op.cc
-tensorflow/core/kernels/transpose_functor_cpu.cc
-tensorflow/core/kernels/training_op_helpers.cc
-tensorflow/core/kernels/training_ops.cc
-tensorflow/core/kernels/topk_op.cc
-tensorflow/core/kernels/tile_functor_cpu.cc
-tensorflow/core/kernels/tile_ops.cc
-tensorflow/core/kernels/tile_ops_cpu_impl_1.cc
-tensorflow/core/kernels/tile_ops_cpu_impl_2.cc
-tensorflow/core/kernels/tile_ops_cpu_impl_3.cc
-tensorflow/core/kernels/tile_ops_cpu_impl_4.cc
-tensorflow/core/kernels/tile_ops_cpu_impl_5.cc
-tensorflow/core/kernels/tile_ops_cpu_impl_6.cc
-tensorflow/core/kernels/tile_ops_cpu_impl_7.cc
-tensorflow/core/kernels/tensor_array_ops.cc
-tensorflow/core/kernels/tensor_array.cc
-tensorflow/core/kernels/strided_slice_op_inst_7.cc
-tensorflow/core/kernels/strided_slice_op_inst_6.cc
-tensorflow/core/kernels/strided_slice_op_inst_5.cc
-tensorflow/core/kernels/strided_slice_op_inst_4.cc
-tensorflow/core/kernels/strided_slice_op_inst_3.cc
-tensorflow/core/kernels/strided_slice_op_inst_2.cc
-tensorflow/core/kernels/strided_slice_op_inst_1.cc
-tensorflow/core/kernels/strided_slice_op_inst_0.cc
-tensorflow/core/kernels/strided_slice_op.cc
-tensorflow/core/kernels/stack_ops.cc
-tensorflow/core/kernels/split_op.cc
-tensorflow/core/kernels/split_v_op.cc
-tensorflow/core/kernels/split_lib_cpu.cc
-tensorflow/core/kernels/spectrogram_op.cc
-tensorflow/core/kernels/spectrogram.cc
-tensorflow/core/kernels/sparse_to_dense_op.cc
-tensorflow/core/kernels/sparse_matmul_op.cc
-tensorflow/core/kernels/sparse_fill_empty_rows_op.cc
-tensorflow/core/kernels/sparse_reshape_op.c
-tensorflow/core/kernels/segment_reduction_ops.cc
-tensorflow/core/kernels/softsign_op.cc
-tensorflow/core/kernels/softplus_op.cc
-tensorflow/core/kernels/softmax_op.cc
-tensorflow/core/kernels/slice_op_cpu_impl_1.cc
-tensorflow/core/kernels/slice_op_cpu_impl_2.cc
-tensorflow/core/kernels/slice_op_cpu_impl_3.cc
-tensorflow/core/kernels/slice_op_cpu_impl_4.cc
-tensorflow/core/kernels/slice_op_cpu_impl_5.cc
-tensorflow/core/kernels/slice_op_cpu_impl_6.cc
-tensorflow/core/kernels/slice_op_cpu_impl_7.cc
-tensorflow/core/kernels/slice_op.cc
-tensorflow/core/kernels/shape_ops.cc
-tensorflow/core/kernels/session_ops.cc
-tensorflow/core/kernels/sequence_ops.cc
-tensorflow/core/kernels/sendrecv_ops.cc
-tensorflow/core/kernels/scatter_op.cc
-tensorflow/core/kernels/scatter_functor.cc
-tensorflow/core/kernels/scatter_nd_op_cpu_impl_0.cc
-tensorflow/core/kernels/scatter_nd_op_cpu_impl_1.cc
-tensorflow/core/kernels/scatter_nd_op_cpu_impl_2.cc
-tensorflow/core/kernels/scatter_nd_op_cpu_impl_3.cc
-tensorflow/core/kernels/scatter_nd_op_cpu_impl_4.cc
-tensorflow/core/kernels/scatter_nd_op_cpu_impl_5.cc
-tensorflow/core/kernels/scatter_nd_op_cpu_impl_6.cc
-tensorflow/core/kernels/scatter_nd_op_cpu_impl_7.cc
-tensorflow/core/kernels/scatter_nd_op.cc
-tensorflow/core/kernels/save_restore_tensor.cc
-tensorflow/core/kernels/save_restore_v2_ops.cc
-tensorflow/core/kernels/save_op.cc
-tensorflow/core/kernels/string_join_op.cc
-tensorflow/core/kernels/reverse_sequence_op.cc
-tensorflow/core/kernels/reverse_op.cc
-tensorflow/core/kernels/restore_op.cc
-tensorflow/core/kernels/resize_nearest_neighbor_op.cc
-tensorflow/core/kernels/resize_bilinear_op.cc
-tensorflow/core/kernels/reshape_util.cc
-tensorflow/core/kernels/reshape_op.cc
-tensorflow/core/kernels/relu_op.cc
-tensorflow/core/kernels/reduction_ops_sum.cc
-tensorflow/core/kernels/reduction_ops_prod.cc
-tensorflow/core/kernels/reduction_ops_min.cc
-tensorflow/core/kernels/reduction_ops_mean.cc
-tensorflow/core/kernels/reduction_ops_max.cc
-tensorflow/core/kernels/reduction_ops_common.cc
-tensorflow/core/kernels/reduction_ops_any.cc
-tensorflow/core/kernels/reduction_ops_all.cc
-tensorflow/core/kernels/roll_op.cc
-tensorflow/core/kernels/queue_op.cc
-tensorflow/core/kernels/queue_ops.cc
-tensorflow/core/kernels/queue_base.cc
-tensorflow/core/kernels/pooling_ops_common.cc
-tensorflow/core/kernels/padding_fifo_queue_op.cc
-tensorflow/core/kernels/padding_fifo_queue.cc
-tensorflow/core/kernels/pad_op.cc
-tensorflow/core/kernels/pack_op.cc
-tensorflow/core/kernels/ops_util.cc
-tensorflow/core/kernels/one_hot_op.cc
-tensorflow/core/kernels/non_max_suppression_op.cc
-tensorflow/core/kernels/no_op.cc
-tensorflow/core/kernels/mirror_pad_op.cc
-tensorflow/core/kernels/mirror_pad_op_cpu_impl_1.cc
-tensorflow/core/kernels/mirror_pad_op_cpu_impl_2.cc
-tensorflow/core/kernels/mirror_pad_op_cpu_impl_3.cc
-tensorflow/core/kernels/mirror_pad_op_cpu_impl_4.cc
-tensorflow/core/kernels/mirror_pad_op_cpu_impl_5.cc
-tensorflow/core/kernels/mfcc_op.cc
-tensorflow/core/kernels/mfcc_mel_filterbank.cc
-tensorflow/core/kernels/mfcc_dct.cc
-tensorflow/core/kernels/mfcc.cc
-tensorflow/core/kernels/maxpooling_op.cc
-tensorflow/core/kernels/matmul_op.cc
-tensorflow/core/kernels/lrn_op.cc
-tensorflow/core/kernels/logging_ops.cc
-tensorflow/core/kernels/initializable_lookup_table.c
-tensorflow/core/kernels/lookup_table_init_op.cc
-tensorflow/core/kernels/lookup_table_op.cc
-tensorflow/core/kernels/lookup_util.cc
-tensorflow/core/kernels/inplace_ops.cc
-tensorflow/core/kernels/in_topk_op.cc
-tensorflow/core/kernels/immutable_constant_op.cc
-tensorflow/core/kernels/identity_op.cc
-tensorflow/core/kernels/identity_n_op.cc
-tensorflow/core/kernels/gather_op.cc
-tensorflow/core/kernels/gather_functor.cc
-tensorflow/core/kernels/gather_nd_op.cc
-tensorflow/core/kernels/gather_nd_op_cpu_impl_0.cc
-tensorflow/core/kernels/gather_nd_op_cpu_impl_1.cc
-tensorflow/core/kernels/gather_nd_op_cpu_impl_2.cc
-tensorflow/core/kernels/gather_nd_op_cpu_impl_3.cc
-tensorflow/core/kernels/gather_nd_op_cpu_impl_4.cc
-tensorflow/core/kernels/gather_nd_op_cpu_impl_5.cc
-tensorflow/core/kernels/gather_nd_op_cpu_impl_6.cc
-tensorflow/core/kernels/gather_nd_op_cpu_impl_7.cc
-tensorflow/core/kernels/fused_batch_norm_op.cc
-tensorflow/core/kernels/function_ops.cc
-tensorflow/core/kernels/fill_functor.cc
-tensorflow/core/kernels/fifo_queue.cc
-tensorflow/core/kernels/fifo_queue_op.cc
-tensorflow/core/kernels/fake_quant_ops.cc
-tensorflow/core/kernels/example_parsing_ops.cc
-tensorflow/core/kernels/encode_wav_op.cc
-tensorflow/core/kernels/dynamic_stitch_op.cc
-tensorflow/core/kernels/dynamic_partition_op.cc
-tensorflow/core/kernels/decode_bmp_op.cc
-tensorflow/core/kernels/depthtospace_op.cc
-tensorflow/core/kernels/data_format_ops.cc
-tensorflow/core/kernels/spacetodepth_op.cc
-tensorflow/core/kernels/dense_update_functor.cc
-tensorflow/core/kernels/dense_update_ops.cc
-tensorflow/core/kernels/deep_conv2d.cc
-tensorflow/core/kernels/decode_wav_op.cc
-tensorflow/core/kernels/xsmm_conv2d.cc
-tensorflow/core/kernels/cwise_ops_common.cc
-tensorflow/core/kernels/cwise_op_tanh.cc
-tensorflow/core/kernels/cwise_op_pow.cc
-tensorflow/core/kernels/cwise_op_sub.cc
-tensorflow/core/kernels/cwise_op_squared_difference.cc
-tensorflow/core/kernels/cwise_op_square.cc
-tensorflow/core/kernels/cwise_op_sqrt.cc
-tensorflow/core/kernels/cwise_op_sigmoid.cc
-tensorflow/core/kernels/cwise_op_sign.cc
-tensorflow/core/kernels/cwise_op_select.cc
-tensorflow/core/kernels/cwise_op_round.cc
-tensorflow/core/kernels/cwise_op_rsqrt.cc
-tensorflow/core/kernels/cwise_op_reciprocal.cc
-tensorflow/core/kernels/cwise_op_neg.cc
-tensorflow/core/kernels/cwise_op_mul_2.cc
-tensorflow/core/kernels/cwise_op_mul_1.cc
-tensorflow/core/kernels/cwise_op_minimum.cc
-tensorflow/core/kernels/cwise_op_maximum.cc
-tensorflow/core/kernels/cwise_op_logical_not.cc
-tensorflow/core/kernels/cwise_op_logical_and.cc
-tensorflow/core/kernels/cwise_op_logical_or.cc
-tensorflow/core/kernels/cwise_op_log.cc
-tensorflow/core/kernels/cwise_op_less.cc
-tensorflow/core/kernels/cwise_op_less_equal.cc
-tensorflow/core/kernels/cwise_op_isnan.cc
-tensorflow/core/kernels/cwise_op_isfinite.cc
-tensorflow/core/kernels/cwise_op_invert.cc
-tensorflow/core/kernels/cwise_op_greater_equal.cc
-tensorflow/core/kernels/cwise_op_greater.cc
-tensorflow/core/kernels/cwise_op_floor_div.cc
-tensorflow/core/kernels/cwise_op_floor_mod.cc
-tensorflow/core/kernels/cwise_op_floor.cc
-tensorflow/core/kernels/cwise_op_exp.cc
-tensorflow/core/kernels/cwise_op_equal_to_2.cc
-tensorflow/core/kernels/cwise_op_equal_to_1.cc
-tensorflow/core/kernels/cwise_op_not_equal_to_2.cc
-tensorflow/core/kernels/cwise_op_not_equal_to_1.cc
-tensorflow/core/kernels/cwise_op_div.cc
-tensorflow/core/kernels/cwise_op_bitwise_xor.cc
-tensorflow/core/kernels/cwise_op_bitwise_or.cc
-tensorflow/core/kernels/cwise_op_bitwise_and.cc
-tensorflow/core/kernels/cwise_op_left_shift.cc
-tensorflow/core/kernels/cwise_op_right_shift.cc
-tensorflow/core/kernels/cwise_op_add_2.cc
-tensorflow/core/kernels/cwise_op_add_1.cc
-tensorflow/core/kernels/cwise_op_abs.cc
-tensorflow/core/kernels/ctc_decoder_ops.cc
-tensorflow/core/kernels/crop_and_resize_op.cc
-tensorflow/core/kernels/conv_ops_using_gemm.cc
-tensorflow/core/kernels/conv_ops_fused.cc
-tensorflow/core/kernels/conv_ops.cc
-tensorflow/core/kernels/conv_grad_filter_ops.cc
-tensorflow/core/kernels/conv_grad_input_ops.cc
-tensorflow/core/kernels/conv_grad_ops.cc
-tensorflow/core/kernels/control_flow_ops.cc
-tensorflow/core/kernels/constant_op.cc
-tensorflow/core/kernels/concat_op.cc
-tensorflow/core/kernels/concat_lib_cpu.cc
-tensorflow/core/kernels/check_numerics_op.cc
+tensorflow/core/kernels/aggregate_ops.cc
+tensorflow/core/kernels/argmax_op.cc
+tensorflow/core/kernels/avgpooling_op.cc
+tensorflow/core/kernels/batch_matmul_op_real.cc
+tensorflow/core/kernels/batch_norm_op.cc
+tensorflow/core/kernels/batchtospace_op.cc
+tensorflow/core/kernels/bcast_ops.cc
+tensorflow/core/kernels/bias_op.cc
+tensorflow/core/kernels/boosted_trees/prediction_ops.cc
+tensorflow/core/kernels/boosted_trees/resource_ops.cc
+tensorflow/core/kernels/boosted_trees/resources.cc
+tensorflow/core/kernels/boosted_trees/stats_ops.cc
+tensorflow/core/kernels/boosted_trees/training_ops.cc
tensorflow/core/kernels/cast_op.cc
tensorflow/core/kernels/cast_op_impl_bfloat.cc
tensorflow/core/kernels/cast_op_impl_bool.cc
@@ -232,20 +33,131 @@ tensorflow/core/kernels/cast_op_impl_uint16.cc
tensorflow/core/kernels/cast_op_impl_uint32.cc
tensorflow/core/kernels/cast_op_impl_uint64.cc
tensorflow/core/kernels/cast_op_impl_uint8.cc
-tensorflow/core/kernels/boosted_trees/prediction_ops.cc
-tensorflow/core/kernels/boosted_trees/resource_ops.cc
-tensorflow/core/kernels/boosted_trees/resources.cc
-tensorflow/core/kernels/boosted_trees/stats_ops.cc
-tensorflow/core/kernels/boosted_trees/training_ops.cc
-tensorflow/core/kernels/bias_op.cc
-tensorflow/core/kernels/bcast_ops.cc
-tensorflow/core/kernels/batch_norm_op.cc
-tensorflow/core/kernels/avgpooling_op.cc
-tensorflow/core/kernels/argmax_op.cc
-tensorflow/core/kernels/aggregate_ops.cc
+tensorflow/core/kernels/check_numerics_op.cc
+tensorflow/core/kernels/concat_lib_cpu.cc
+tensorflow/core/kernels/concat_op.cc
+tensorflow/core/kernels/constant_op.cc
+tensorflow/core/kernels/control_flow_ops.cc
+tensorflow/core/kernels/conv_grad_filter_ops.cc
+tensorflow/core/kernels/conv_grad_input_ops.cc
+tensorflow/core/kernels/conv_grad_ops.cc
+tensorflow/core/kernels/conv_ops.cc
+tensorflow/core/kernels/conv_ops_fused.cc
+tensorflow/core/kernels/conv_ops_using_gemm.cc
+tensorflow/core/kernels/crop_and_resize_op.cc
+tensorflow/core/kernels/ctc_decoder_ops.cc
+tensorflow/core/kernels/cwise_op_abs.cc
+tensorflow/core/kernels/cwise_op_add_1.cc
+tensorflow/core/kernels/cwise_op_add_2.cc
+tensorflow/core/kernels/cwise_op_bitwise_and.cc
+tensorflow/core/kernels/cwise_op_bitwise_or.cc
+tensorflow/core/kernels/cwise_op_bitwise_xor.cc
+tensorflow/core/kernels/cwise_op_div.cc
+tensorflow/core/kernels/cwise_op_equal_to_1.cc
+tensorflow/core/kernels/cwise_op_equal_to_2.cc
+tensorflow/core/kernels/cwise_op_exp.cc
+tensorflow/core/kernels/cwise_op_floor.cc
+tensorflow/core/kernels/cwise_op_floor_div.cc
+tensorflow/core/kernels/cwise_op_floor_mod.cc
+tensorflow/core/kernels/cwise_op_greater.cc
+tensorflow/core/kernels/cwise_op_greater_equal.cc
+tensorflow/core/kernels/cwise_op_invert.cc
+tensorflow/core/kernels/cwise_op_isfinite.cc
+tensorflow/core/kernels/cwise_op_isnan.cc
+tensorflow/core/kernels/cwise_op_left_shift.cc
+tensorflow/core/kernels/cwise_op_less.cc
+tensorflow/core/kernels/cwise_op_less_equal.cc
+tensorflow/core/kernels/cwise_op_log.cc
+tensorflow/core/kernels/cwise_op_logical_and.cc
+tensorflow/core/kernels/cwise_op_logical_not.cc
+tensorflow/core/kernels/cwise_op_logical_or.cc
+tensorflow/core/kernels/cwise_op_maximum.cc
+tensorflow/core/kernels/cwise_op_minimum.cc
+tensorflow/core/kernels/cwise_op_mul_1.cc
+tensorflow/core/kernels/cwise_op_mul_2.cc
+tensorflow/core/kernels/cwise_op_neg.cc
+tensorflow/core/kernels/cwise_op_not_equal_to_1.cc
+tensorflow/core/kernels/cwise_op_not_equal_to_2.cc
+tensorflow/core/kernels/cwise_op_pow.cc
+tensorflow/core/kernels/cwise_op_reciprocal.cc
+tensorflow/core/kernels/cwise_op_right_shift.cc
+tensorflow/core/kernels/cwise_op_round.cc
+tensorflow/core/kernels/cwise_op_rsqrt.cc
+tensorflow/core/kernels/cwise_op_select.cc
+tensorflow/core/kernels/cwise_op_sigmoid.cc
+tensorflow/core/kernels/cwise_op_sign.cc
+tensorflow/core/kernels/cwise_op_sqrt.cc
+tensorflow/core/kernels/cwise_op_square.cc
+tensorflow/core/kernels/cwise_op_squared_difference.cc
+tensorflow/core/kernels/cwise_op_sub.cc
+tensorflow/core/kernels/cwise_op_tanh.cc
+tensorflow/core/kernels/cwise_ops_common.cc
+tensorflow/core/kernels/data_format_ops.cc
+tensorflow/core/kernels/decode_bmp_op.cc
+tensorflow/core/kernels/decode_proto_op.cc
+tensorflow/core/kernels/decode_wav_op.cc
+tensorflow/core/kernels/deep_conv2d.cc
+tensorflow/core/kernels/dense_update_functor.cc
+tensorflow/core/kernels/dense_update_ops.cc
+tensorflow/core/kernels/depthtospace_op.cc
tensorflow/core/kernels/depthwise_conv_op.cc
tensorflow/core/kernels/dequantize_op.cc
+tensorflow/core/kernels/dynamic_partition_op.cc
+tensorflow/core/kernels/dynamic_stitch_op.cc
+tensorflow/core/kernels/encode_proto_op.cc
+tensorflow/core/kernels/encode_wav_op.cc
+tensorflow/core/kernels/example_parsing_ops.cc
+tensorflow/core/kernels/fake_quant_ops.cc
+tensorflow/core/kernels/fifo_queue.cc
+tensorflow/core/kernels/fifo_queue_op.cc
+tensorflow/core/kernels/fill_functor.cc
+tensorflow/core/kernels/function_ops.cc
+tensorflow/core/kernels/fused_batch_norm_op.cc
+tensorflow/core/kernels/gather_functor.cc
+tensorflow/core/kernels/gather_nd_op.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_0.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_1.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_2.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_3.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_4.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_5.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_6.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_7.cc
+tensorflow/core/kernels/gather_op.cc
+tensorflow/core/kernels/identity_n_op.cc
+tensorflow/core/kernels/identity_op.cc
+tensorflow/core/kernels/immutable_constant_op.cc
+tensorflow/core/kernels/in_topk_op.cc
+tensorflow/core/kernels/initializable_lookup_table.c
+tensorflow/core/kernels/inplace_ops.cc
+tensorflow/core/kernels/listdiff_op.cc
+tensorflow/core/kernels/logging_ops.cc
+tensorflow/core/kernels/lookup_table_init_op.cc
+tensorflow/core/kernels/lookup_table_op.cc
+tensorflow/core/kernels/lookup_util.cc
+tensorflow/core/kernels/lrn_op.cc
+tensorflow/core/kernels/matmul_op.cc
+tensorflow/core/kernels/maxpooling_op.cc
tensorflow/core/kernels/meta_support.cc
+tensorflow/core/kernels/mfcc.cc
+tensorflow/core/kernels/mfcc_dct.cc
+tensorflow/core/kernels/mfcc_mel_filterbank.cc
+tensorflow/core/kernels/mfcc_op.cc
+tensorflow/core/kernels/mirror_pad_op.cc
+tensorflow/core/kernels/mirror_pad_op_cpu_impl_1.cc
+tensorflow/core/kernels/mirror_pad_op_cpu_impl_2.cc
+tensorflow/core/kernels/mirror_pad_op_cpu_impl_3.cc
+tensorflow/core/kernels/mirror_pad_op_cpu_impl_4.cc
+tensorflow/core/kernels/mirror_pad_op_cpu_impl_5.cc
+tensorflow/core/kernels/no_op.cc
+tensorflow/core/kernels/non_max_suppression_op.cc
+tensorflow/core/kernels/one_hot_op.cc
+tensorflow/core/kernels/ops_util.cc
+tensorflow/core/kernels/pack_op.cc
+tensorflow/core/kernels/pad_op.cc
+tensorflow/core/kernels/padding_fifo_queue.cc
+tensorflow/core/kernels/padding_fifo_queue_op.cc
+tensorflow/core/kernels/pooling_ops_common.cc
tensorflow/core/kernels/population_count_op.cc
tensorflow/core/kernels/quantization_utils.cc
tensorflow/core/kernels/quantize_down_and_shrink_range.cc
@@ -262,46 +174,135 @@ tensorflow/core/kernels/quantized_mul_op.cc
tensorflow/core/kernels/quantized_pooling_ops.cc
tensorflow/core/kernels/quantized_reshape_op.cc
tensorflow/core/kernels/quantized_resize_bilinear_op.cc
-tensorflow/core/kernels/requantization_range_op.cc
-tensorflow/core/kernels/requantize.cc
+tensorflow/core/kernels/queue_base.cc
+tensorflow/core/kernels/queue_op.cc
+tensorflow/core/kernels/queue_ops.cc
+tensorflow/core/kernels/random_op.cc
+tensorflow/core/kernels/reduction_ops_all.cc
+tensorflow/core/kernels/reduction_ops_any.cc
+tensorflow/core/kernels/reduction_ops_common.cc
+tensorflow/core/kernels/reduction_ops_max.cc
+tensorflow/core/kernels/reduction_ops_mean.cc
+tensorflow/core/kernels/reduction_ops_min.cc
+tensorflow/core/kernels/reduction_ops_prod.cc
+tensorflow/core/kernels/reduction_ops_sum.cc
+tensorflow/core/kernels/relu_op.cc
tensorflow/core/kernels/remote_fused_graph_execute_op.cc
tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
-tensorflow/core/kernels/batch_matmul_op_real.cc
-tensorflow/core/kernels/random_op.cc
-tensorflow/core/ops/training_ops.cc
-tensorflow/core/ops/string_ops.cc
-tensorflow/core/ops/state_ops.cc
-tensorflow/core/ops/sparse_ops.cc
-tensorflow/core/ops/sendrecv_ops.cc
-tensorflow/core/ops/script_ops.cc
-tensorflow/core/ops/remote_fused_graph_ops.cc
-tensorflow/core/ops/random_ops.cc
-tensorflow/core/ops/random_grad.cc
-tensorflow/core/ops/parsing_ops.cc
-tensorflow/core/ops/no_op.cc
-tensorflow/core/ops/nn_ops.cc
-tensorflow/core/ops/nn_grad.cc
-tensorflow/core/ops/manip_ops.cc
-tensorflow/core/ops/math_ops.cc
-tensorflow/core/ops/math_grad.cc
-tensorflow/core/ops/logging_ops.cc
-tensorflow/core/ops/linalg_ops.cc
-tensorflow/core/ops/io_ops.cc
-tensorflow/core/ops/image_ops.cc
-tensorflow/core/ops/functional_ops.cc
-tensorflow/core/ops/functional_grad.cc
-tensorflow/core/ops/function_ops.cc
-tensorflow/core/ops/data_flow_ops.cc
-tensorflow/core/ops/ctc_ops.cc
-tensorflow/core/ops/control_flow_ops.cc
-tensorflow/core/ops/candidate_sampling_ops.cc
-tensorflow/core/ops/boosted_trees_ops.cc
-tensorflow/core/ops/array_ops.cc
-tensorflow/core/ops/array_grad.cc
+tensorflow/core/kernels/requantization_range_op.cc
+tensorflow/core/kernels/requantize.cc
+tensorflow/core/kernels/reshape_op.cc
+tensorflow/core/kernels/reshape_util.cc
+tensorflow/core/kernels/resize_bilinear_op.cc
+tensorflow/core/kernels/resize_nearest_neighbor_op.cc
+tensorflow/core/kernels/restore_op.cc
+tensorflow/core/kernels/reverse_op.cc
+tensorflow/core/kernels/reverse_sequence_op.cc
+tensorflow/core/kernels/roll_op.cc
+tensorflow/core/kernels/save_op.cc
+tensorflow/core/kernels/save_restore_tensor.cc
+tensorflow/core/kernels/save_restore_v2_ops.cc
+tensorflow/core/kernels/scatter_functor.cc
+tensorflow/core/kernels/scatter_nd_op.cc
+tensorflow/core/kernels/scatter_nd_op_cpu_impl_0.cc
+tensorflow/core/kernels/scatter_nd_op_cpu_impl_1.cc
+tensorflow/core/kernels/scatter_nd_op_cpu_impl_2.cc
+tensorflow/core/kernels/scatter_nd_op_cpu_impl_3.cc
+tensorflow/core/kernels/scatter_nd_op_cpu_impl_4.cc
+tensorflow/core/kernels/scatter_nd_op_cpu_impl_5.cc
+tensorflow/core/kernels/scatter_nd_op_cpu_impl_6.cc
+tensorflow/core/kernels/scatter_nd_op_cpu_impl_7.cc
+tensorflow/core/kernels/scatter_op.cc
+tensorflow/core/kernels/segment_reduction_ops.cc
+tensorflow/core/kernels/segment_reduction_ops.cc
+tensorflow/core/kernels/sendrecv_ops.cc
+tensorflow/core/kernels/sequence_ops.cc
+tensorflow/core/kernels/session_ops.cc
+tensorflow/core/kernels/shape_ops.cc
+tensorflow/core/kernels/slice_op.cc
+tensorflow/core/kernels/slice_op_cpu_impl_1.cc
+tensorflow/core/kernels/slice_op_cpu_impl_2.cc
+tensorflow/core/kernels/slice_op_cpu_impl_3.cc
+tensorflow/core/kernels/slice_op_cpu_impl_4.cc
+tensorflow/core/kernels/slice_op_cpu_impl_5.cc
+tensorflow/core/kernels/slice_op_cpu_impl_6.cc
+tensorflow/core/kernels/slice_op_cpu_impl_7.cc
+tensorflow/core/kernels/softmax_op.cc
+tensorflow/core/kernels/softplus_op.cc
+tensorflow/core/kernels/softsign_op.cc
tensorflow/core/kernels/spacetobatch_functor.cc
tensorflow/core/kernels/spacetobatch_op.cc
-tensorflow/core/kernels/batchtospace_op.cc
-tensorflow/core/kernels/segment_reduction_ops.cc
+tensorflow/core/kernels/spacetodepth_op.cc
+tensorflow/core/kernels/sparse_fill_empty_rows_op.cc
+tensorflow/core/kernels/sparse_matmul_op.cc
+tensorflow/core/kernels/sparse_reshape_op.c
+tensorflow/core/kernels/sparse_to_dense_op.cc
+tensorflow/core/kernels/spectrogram.cc
+tensorflow/core/kernels/spectrogram_op.cc
+tensorflow/core/kernels/split_lib_cpu.cc
+tensorflow/core/kernels/split_op.cc
+tensorflow/core/kernels/split_v_op.cc
+tensorflow/core/kernels/stack_ops.cc
+tensorflow/core/kernels/strided_slice_op.cc
+tensorflow/core/kernels/strided_slice_op_inst_0.cc
+tensorflow/core/kernels/strided_slice_op_inst_1.cc
+tensorflow/core/kernels/strided_slice_op_inst_2.cc
+tensorflow/core/kernels/strided_slice_op_inst_3.cc
+tensorflow/core/kernels/strided_slice_op_inst_4.cc
+tensorflow/core/kernels/strided_slice_op_inst_5.cc
+tensorflow/core/kernels/strided_slice_op_inst_6.cc
+tensorflow/core/kernels/strided_slice_op_inst_7.cc
+tensorflow/core/kernels/string_join_op.cc
+tensorflow/core/kernels/tensor_array.cc
+tensorflow/core/kernels/tensor_array_ops.cc
+tensorflow/core/kernels/tile_functor_cpu.cc
+tensorflow/core/kernels/tile_ops.cc
+tensorflow/core/kernels/tile_ops_cpu_impl_1.cc
+tensorflow/core/kernels/tile_ops_cpu_impl_2.cc
+tensorflow/core/kernels/tile_ops_cpu_impl_3.cc
+tensorflow/core/kernels/tile_ops_cpu_impl_4.cc
+tensorflow/core/kernels/tile_ops_cpu_impl_5.cc
+tensorflow/core/kernels/tile_ops_cpu_impl_6.cc
+tensorflow/core/kernels/tile_ops_cpu_impl_7.cc
+tensorflow/core/kernels/topk_op.cc
+tensorflow/core/kernels/training_op_helpers.cc
+tensorflow/core/kernels/training_ops.cc
+tensorflow/core/kernels/transpose_functor_cpu.cc
+tensorflow/core/kernels/transpose_op.cc
+tensorflow/core/kernels/unique_op.cc
+tensorflow/core/kernels/unpack_op.cc
+tensorflow/core/kernels/variable_ops.cc
+tensorflow/core/kernels/where_op.cc
+tensorflow/core/kernels/xent_op.cc
+tensorflow/core/kernels/xsmm_conv2d.cc
+tensorflow/core/ops/array_grad.cc
+tensorflow/core/ops/array_ops.cc
tensorflow/core/ops/audio_ops.cc
-tensorflow/core/kernels/decode_proto_op.cc
-tensorflow/core/kernels/encode_proto_op.cc
+tensorflow/core/ops/boosted_trees_ops.cc
+tensorflow/core/ops/candidate_sampling_ops.cc
+tensorflow/core/ops/control_flow_ops.cc
+tensorflow/core/ops/ctc_ops.cc
+tensorflow/core/ops/data_flow_ops.cc
+tensorflow/core/ops/function_ops.cc
+tensorflow/core/ops/functional_grad.cc
+tensorflow/core/ops/functional_ops.cc
+tensorflow/core/ops/image_ops.cc
+tensorflow/core/ops/io_ops.cc
+tensorflow/core/ops/linalg_ops.cc
+tensorflow/core/ops/logging_ops.cc
+tensorflow/core/ops/manip_ops.cc
+tensorflow/core/ops/math_grad.cc
+tensorflow/core/ops/math_ops.cc
+tensorflow/core/ops/nn_grad.cc
+tensorflow/core/ops/nn_ops.cc
+tensorflow/core/ops/no_op.cc
+tensorflow/core/ops/parsing_ops.cc
+tensorflow/core/ops/random_grad.cc
+tensorflow/core/ops/random_ops.cc
+tensorflow/core/ops/remote_fused_graph_ops.cc
+tensorflow/core/ops/script_ops.cc
+tensorflow/core/ops/sendrecv_ops.cc
+tensorflow/core/ops/sparse_ops.cc
+tensorflow/core/ops/state_ops.cc
+tensorflow/core/ops/string_ops.cc
+tensorflow/core/ops/training_ops.cc
diff --git a/tensorflow/contrib/makefile/tf_pb_text_files.txt b/tensorflow/contrib/makefile/tf_pb_text_files.txt
index b5431df2eb..e23f499214 100644
--- a/tensorflow/contrib/makefile/tf_pb_text_files.txt
+++ b/tensorflow/contrib/makefile/tf_pb_text_files.txt
@@ -1,33 +1,34 @@
-tensorflow/core/util/saved_tensor_slice.pb_text.cc
-tensorflow/core/util/memmapped_file_system.pb_text.cc
-tensorflow/core/protobuf/saver.pb_text.cc
+tensorflow/core/example/example.pb_text.cc
+tensorflow/core/example/feature.pb_text.cc
+tensorflow/core/framework/allocation_description.pb_text.cc
+tensorflow/core/framework/api_def.pb_text.cc
+tensorflow/core/framework/attr_value.pb_text.cc
+tensorflow/core/framework/cost_graph.pb_text.cc
+tensorflow/core/framework/device_attributes.pb_text.cc
+tensorflow/core/framework/function.pb_text.cc
+tensorflow/core/framework/graph.pb_text.cc
+tensorflow/core/framework/graph_transfer_info.pb_text.cc
+tensorflow/core/framework/kernel_def.pb_text.cc
+tensorflow/core/framework/log_memory.pb_text.cc
+tensorflow/core/framework/model.pb_text.cc
+tensorflow/core/framework/node_def.pb_text.cc
+tensorflow/core/framework/op_def.pb_text.cc
+tensorflow/core/framework/remote_fused_graph_execute_info.pb_text.cc
+tensorflow/core/framework/resource_handle.pb_text.cc
+tensorflow/core/framework/step_stats.pb_text.cc
+tensorflow/core/framework/summary.pb_text.cc
+tensorflow/core/framework/tensor.pb_text.cc
+tensorflow/core/framework/tensor_description.pb_text.cc
+tensorflow/core/framework/tensor_shape.pb_text.cc
+tensorflow/core/framework/tensor_slice.pb_text.cc
+tensorflow/core/framework/types.pb_text.cc
+tensorflow/core/framework/versions.pb_text.cc
+tensorflow/core/lib/core/error_codes.pb_text.cc
tensorflow/core/protobuf/cluster.pb_text.cc
tensorflow/core/protobuf/config.pb_text.cc
tensorflow/core/protobuf/debug.pb_text.cc
tensorflow/core/protobuf/rewriter_config.pb_text.cc
+tensorflow/core/protobuf/saver.pb_text.cc
tensorflow/core/protobuf/tensor_bundle.pb_text.cc
-tensorflow/core/lib/core/error_codes.pb_text.cc
-tensorflow/core/framework/versions.pb_text.cc
-tensorflow/core/framework/types.pb_text.cc
-tensorflow/core/framework/tensor_slice.pb_text.cc
-tensorflow/core/framework/tensor_shape.pb_text.cc
-tensorflow/core/framework/tensor_description.pb_text.cc
-tensorflow/core/framework/tensor.pb_text.cc
-tensorflow/core/framework/summary.pb_text.cc
-tensorflow/core/framework/step_stats.pb_text.cc
-tensorflow/core/framework/resource_handle.pb_text.cc
-tensorflow/core/framework/remote_fused_graph_execute_info.pb_text.cc
-tensorflow/core/framework/api_def.pb_text.cc
-tensorflow/core/framework/op_def.pb_text.cc
-tensorflow/core/framework/node_def.pb_text.cc
-tensorflow/core/framework/log_memory.pb_text.cc
-tensorflow/core/framework/kernel_def.pb_text.cc
-tensorflow/core/framework/graph_transfer_info.pb_text.cc
-tensorflow/core/framework/graph.pb_text.cc
-tensorflow/core/framework/function.pb_text.cc
-tensorflow/core/framework/device_attributes.pb_text.cc
-tensorflow/core/framework/cost_graph.pb_text.cc
-tensorflow/core/framework/attr_value.pb_text.cc
-tensorflow/core/framework/allocation_description.pb_text.cc
-tensorflow/core/example/feature.pb_text.cc
-tensorflow/core/example/example.pb_text.cc
+tensorflow/core/util/memmapped_file_system.pb_text.cc
+tensorflow/core/util/saved_tensor_slice.pb_text.cc
diff --git a/tensorflow/contrib/makefile/tf_proto_files.txt b/tensorflow/contrib/makefile/tf_proto_files.txt
index 1f254692d7..5eae845d9b 100644
--- a/tensorflow/contrib/makefile/tf_proto_files.txt
+++ b/tensorflow/contrib/makefile/tf_proto_files.txt
@@ -2,47 +2,48 @@ tensorflow/contrib/boosted_trees/proto/learner.proto
tensorflow/contrib/boosted_trees/proto/quantiles.proto
tensorflow/contrib/boosted_trees/proto/split_info.proto
tensorflow/contrib/boosted_trees/proto/tree_config.proto
-tensorflow/core/util/test_log.proto
-tensorflow/core/util/saved_tensor_slice.proto
-tensorflow/core/util/memmapped_file_system.proto
-tensorflow/core/util/event.proto
-tensorflow/core/protobuf/tensorflow_server.proto
-tensorflow/core/protobuf/saver.proto
-tensorflow/core/protobuf/queue_runner.proto
-tensorflow/core/protobuf/named_tensor.proto
-tensorflow/core/protobuf/meta_graph.proto
+tensorflow/core/example/example.proto
+tensorflow/core/example/feature.proto
+tensorflow/core/framework/allocation_description.proto
+tensorflow/core/framework/api_def.proto
+tensorflow/core/framework/attr_value.proto
+tensorflow/core/framework/cost_graph.proto
+tensorflow/core/framework/device_attributes.proto
+tensorflow/core/framework/function.proto
+tensorflow/core/framework/graph.proto
+tensorflow/core/framework/graph_transfer_info.proto
+tensorflow/core/framework/kernel_def.proto
+tensorflow/core/framework/log_memory.proto
+tensorflow/core/framework/model.proto
+tensorflow/core/framework/node_def.proto
+tensorflow/core/framework/op_def.proto
+tensorflow/core/framework/reader_base.proto
+tensorflow/core/framework/remote_fused_graph_execute_info.proto
+tensorflow/core/framework/resource_handle.proto
+tensorflow/core/framework/step_stats.proto
+tensorflow/core/framework/summary.proto
+tensorflow/core/framework/tensor.proto
+tensorflow/core/framework/tensor_description.proto
+tensorflow/core/framework/tensor_shape.proto
+tensorflow/core/framework/tensor_slice.proto
+tensorflow/core/framework/types.proto
+tensorflow/core/framework/variable.proto
+tensorflow/core/framework/versions.proto
+tensorflow/core/grappler/costs/op_performance_data.proto
+tensorflow/core/kernels/boosted_trees/boosted_trees.proto
+tensorflow/core/lib/core/error_codes.proto
tensorflow/core/protobuf/cluster.proto
tensorflow/core/protobuf/config.proto
tensorflow/core/protobuf/debug.proto
tensorflow/core/protobuf/device_properties.proto
+tensorflow/core/protobuf/meta_graph.proto
+tensorflow/core/protobuf/named_tensor.proto
+tensorflow/core/protobuf/queue_runner.proto
tensorflow/core/protobuf/rewriter_config.proto
+tensorflow/core/protobuf/saver.proto
tensorflow/core/protobuf/tensor_bundle.proto
-tensorflow/core/lib/core/error_codes.proto
-tensorflow/core/kernels/boosted_trees/boosted_trees.proto
-tensorflow/core/framework/versions.proto
-tensorflow/core/framework/variable.proto
-tensorflow/core/framework/types.proto
-tensorflow/core/framework/tensor_slice.proto
-tensorflow/core/framework/tensor_shape.proto
-tensorflow/core/framework/tensor_description.proto
-tensorflow/core/framework/tensor.proto
-tensorflow/core/framework/summary.proto
-tensorflow/core/framework/step_stats.proto
-tensorflow/core/framework/resource_handle.proto
-tensorflow/core/framework/remote_fused_graph_execute_info.proto
-tensorflow/core/framework/reader_base.proto
-tensorflow/core/framework/api_def.proto
-tensorflow/core/framework/op_def.proto
-tensorflow/core/framework/node_def.proto
-tensorflow/core/framework/log_memory.proto
-tensorflow/core/framework/kernel_def.proto
-tensorflow/core/framework/graph_transfer_info.proto
-tensorflow/core/framework/graph.proto
-tensorflow/core/framework/function.proto
-tensorflow/core/framework/device_attributes.proto
-tensorflow/core/framework/cost_graph.proto
-tensorflow/core/framework/attr_value.proto
-tensorflow/core/framework/allocation_description.proto
-tensorflow/core/example/feature.proto
-tensorflow/core/example/example.proto
-tensorflow/core/grappler/costs/op_performance_data.proto
+tensorflow/core/protobuf/tensorflow_server.proto
+tensorflow/core/util/event.proto
+tensorflow/core/util/memmapped_file_system.proto
+tensorflow/core/util/saved_tensor_slice.proto
+tensorflow/core/util/test_log.proto
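Most of the churn in the makefile list files above comes from re-emitting the entries in alphabetical order, with a handful of entries added or dropped along the way. A hypothetical helper (not part of this commit) that produces that kind of normalization:

def sort_file_list(path):
  # Read the non-empty entries, then rewrite the file in sorted order;
  # this reproduces the bulk of the +/- lines seen in the hunks above.
  with open(path) as f:
    entries = [line.rstrip('\n') for line in f if line.strip()]
  with open(path, 'w') as f:
    f.write('\n'.join(sorted(entries)) + '\n')


sort_file_list('tensorflow/contrib/makefile/tf_proto_files.txt')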
diff --git a/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py b/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py
index c35e60a554..b1c852c2c6 100644
--- a/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py
+++ b/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py
@@ -31,6 +31,7 @@ from tensorflow.python.client import session as _session
from tensorflow.python.framework import graph_util as _graph_util
from tensorflow.python.framework import importer as _importer
from tensorflow.python.framework import ops as _ops
+from tensorflow.python.platform import tf_logging as _logging
from tensorflow.python.saved_model import constants as _saved_model_constants
from tensorflow.python.training import saver as _saver_lib
from tensorflow.python.util import compat as _compat
@@ -476,6 +477,12 @@ def _add_pruned_collection(base_meta_graph_def, meta_graph_def,
collection.bytes_list.value[:] = [
s for s in base_collection.bytes_list.value
if not _is_removed_mentioned(s, removed_op_names)]
+ _logging.info(
+ 'In collection %s, nodes excluded are: %s', collection_name,
+ sorted([
+ s for s in base_collection.bytes_list.value
+ if _is_removed_mentioned(s, removed_op_names)
+ ]))
elif base_collection.HasField('node_list'):
collection.node_list.value[:] = [
s for s in base_collection.node_list.value
@@ -745,6 +752,9 @@ def meta_graph_transform(
retained_op_names = [_compat.as_str(node.name)
for node in meta_graph_def.graph_def.node]
removed_op_names = set(base_op_names) - set(retained_op_names)
+ _logging.info('Node names in base graph: %s', sorted(base_op_names))
+ _logging.info('Node names retained: %s', sorted(retained_op_names))
+ _logging.info('Node names removed: %s', sorted(removed_op_names))
# Copy saver, excluding any pruned nodes if graph was not frozen.
# TODO(b/63447631): Revisit this once the problem is addressed. Currently
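The meta_graph_transform.py change above adds informational logging of which nodes are excluded from pruned collections and which op names are retained or removed. A minimal sketch of the same filter-then-log pattern, with a made-up is_removed_mentioned stand-in (the real code consults MetaGraphDef collections and the removed_op_names computed from the transformed graph):

import logging


def prune_and_log(collection_name, values, removed_op_names):
  # Stand-in for _is_removed_mentioned: treat a value as removed if it
  # mentions any of the removed op names.
  def is_removed_mentioned(value, removed):
    return any(name in value for name in removed)

  kept = [v for v in values
          if not is_removed_mentioned(v, removed_op_names)]
  excluded = sorted(v for v in values
                    if is_removed_mentioned(v, removed_op_names))
  logging.info('In collection %s, nodes excluded are: %s',
               collection_name, excluded)
  return kept


kept = prune_and_log('trainable_variables',
                     ['a/read:0', 'pruned/b/read:0'],
                     {'pruned/b'})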
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py
index 7acfc383eb..5777e64c29 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py
@@ -47,7 +47,7 @@ class StreamingPrecisionRecallAtEqualThresholdsLargeTest(test.TestCase):
# code used float32 for accumulation.
num_updates = 71
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
for _ in xrange(num_updates):
sess.run(update_op)
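The streaming-metric tests in this file and the next share one structure: build the metric and its update op, initialize local variables, run the update op repeatedly, then read the metric value, now inside cached_session(). A compact sketch under assumed shapes and values (StreamingAccuracySketchTest is illustrative, not taken from this commit):

from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.contrib.metrics.python.ops import metric_ops as metrics


class StreamingAccuracySketchTest(test.TestCase):

  def testRepeatedUpdatesAreStable(self):
    predictions = array_ops.ones((10, 1))
    labels = array_ops.ones((10, 1))
    accuracy, update_op = metrics.streaming_accuracy(predictions, labels)
    with self.cached_session() as sess:
      sess.run(variables.local_variables_initializer())
      # Each update accumulates counts; with identical inputs the
      # accuracy stays at 1.0 no matter how many updates run.
      for _ in range(3):
        sess.run(update_op)
      self.assertEqual(1.0, accuracy.eval())


if __name__ == '__main__':
  test.main()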
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index 024bd54912..955b83b44d 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -178,7 +178,7 @@ class StreamingMeanTest(test.TestCase):
self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
def testBasic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
_enqueue_vector(sess, values_queue, [0, 1])
@@ -195,7 +195,7 @@ class StreamingMeanTest(test.TestCase):
self.assertAlmostEqual(1.65, sess.run(mean), 5)
def testUpdateOpsReturnsCurrentValue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
_enqueue_vector(sess, values_queue, [0, 1])
@@ -216,7 +216,7 @@ class StreamingMeanTest(test.TestCase):
self.assertAlmostEqual(1.65, sess.run(mean), 5)
def test1dWeightedValues(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -243,7 +243,7 @@ class StreamingMeanTest(test.TestCase):
self.assertAlmostEqual((0 + 1 - 3.2 + 4.0) / 4.0, mean.eval(), 5)
def test1dWeightedValues_placeholders(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0))
values = array_ops.placeholder(dtype=dtypes_lib.float32)
@@ -265,7 +265,7 @@ class StreamingMeanTest(test.TestCase):
self.assertAlmostEqual((0 + 1 - 3.2 + 4.0) / 4.0, mean.eval(), 5)
def test2dWeightedValues(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -292,7 +292,7 @@ class StreamingMeanTest(test.TestCase):
self.assertAlmostEqual((0 + 1 - 4.2 + 0) / 4.0, mean.eval(), 5)
def test2dWeightedValues_placeholders(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0))
values = array_ops.placeholder(dtype=dtypes_lib.float32)
@@ -337,7 +337,7 @@ class StreamingMeanTensorTest(test.TestCase):
self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
def testBasic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
_enqueue_vector(sess, values_queue, [0, 1])
@@ -354,7 +354,7 @@ class StreamingMeanTensorTest(test.TestCase):
self.assertAllClose([[-0.9 / 4., 3.525]], sess.run(mean))
def testMultiDimensional(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
2, dtypes=dtypes_lib.float32, shapes=(2, 2, 2))
_enqueue_vector(
@@ -375,7 +375,7 @@ class StreamingMeanTensorTest(test.TestCase):
self.assertAllClose([[[1, 2], [1, 2]], [[2, 3], [5, 6]]], sess.run(mean))
def testUpdateOpsReturnsCurrentValue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
_enqueue_vector(sess, values_queue, [0, 1])
@@ -396,7 +396,7 @@ class StreamingMeanTensorTest(test.TestCase):
self.assertAllClose([[-0.9 / 4., 3.525]], sess.run(mean), 5)
def testWeighted1d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -423,7 +423,7 @@ class StreamingMeanTensorTest(test.TestCase):
self.assertAllClose([[3.25, 0.5]], sess.run(mean), 5)
def testWeighted2d_1(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -450,7 +450,7 @@ class StreamingMeanTensorTest(test.TestCase):
self.assertAllClose([[-2.1, 0.5]], sess.run(mean), 5)
def testWeighted2d_2(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -526,7 +526,7 @@ class StreamingAccuracyTest(test.TestCase):
(10, 3), maxval=3, dtype=dtypes_lib.int64, seed=2)
accuracy, update_op = metrics.streaming_accuracy(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -539,7 +539,7 @@ class StreamingAccuracyTest(test.TestCase):
self.assertEqual(initial_accuracy, accuracy.eval())
def testMultipleUpdates(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 1))
@@ -569,7 +569,7 @@ class StreamingAccuracyTest(test.TestCase):
def testEffectivelyEquivalentSizes(self):
predictions = array_ops.ones((40, 1))
labels = array_ops.ones((40,))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accuracy, update_op = metrics.streaming_accuracy(predictions, labels)
sess.run(variables.local_variables_initializer())
@@ -583,7 +583,7 @@ class StreamingAccuracyTest(test.TestCase):
weights = array_ops.expand_dims(ops.convert_to_tensor([100, 1, 1]),
1) # shape 3, 1
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accuracy, update_op = metrics.streaming_accuracy(predictions, labels,
weights)
@@ -604,7 +604,7 @@ class StreamingAccuracyTest(test.TestCase):
dtype=dtypes_lib.int32, name='weights')
feed_dict = {weights_placeholder: weights}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accuracy, update_op = metrics.streaming_accuracy(predictions, labels,
weights_placeholder)
@@ -616,7 +616,7 @@ class StreamingAccuracyTest(test.TestCase):
self.assertGreater(accuracy.eval(feed_dict=feed_dict), .95)
def testMultipleUpdatesWithWeightedValues(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 1))
@@ -681,7 +681,7 @@ class StreamingTruePositivesTest(test.TestCase):
tp, tp_update_op = metrics.streaming_true_positives(
predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, tp.eval())
self.assertEqual(1, tp_update_op.eval())
@@ -698,7 +698,7 @@ class StreamingTruePositivesTest(test.TestCase):
tp, tp_update_op = metrics.streaming_true_positives(
predictions, labels, weights=37.0)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, tp.eval())
self.assertEqual(37.0, tp_update_op.eval())
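As the weighted variants above illustrate, weights simply scale the accumulated counts: a single true positive weighted by 37.0 yields an update of 37.0. A short sketch of that behavior under the same TF 1.x / tf.contrib.metrics assumptions; the tensors below are illustrative stand-ins for the test fixtures:

    import tensorflow as tf
    from tensorflow.contrib import metrics

    predictions = tf.constant([1.0, 0.0, 0.0, 0.0])
    labels = tf.constant([1.0, 1.0, 0.0, 0.0])
    # Exactly one true positive, scaled by the scalar weight.
    tp, tp_update_op = metrics.streaming_true_positives(
        predictions, labels, weights=37.0)

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      print(sess.run(tp))            # 0.0 before any update
      print(sess.run(tp_update_op))  # 37.0 after folding in this batch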
@@ -732,7 +732,7 @@ class StreamingFalseNegativesTest(test.TestCase):
fn, fn_update_op = metrics.streaming_false_negatives(
predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, fn.eval())
self.assertEqual(2, fn_update_op.eval())
@@ -749,7 +749,7 @@ class StreamingFalseNegativesTest(test.TestCase):
fn, fn_update_op = metrics.streaming_false_negatives(
predictions, labels, weights=((3.0,), (5.0,), (7.0,)))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, fn.eval())
self.assertEqual(8.0, fn_update_op.eval())
@@ -783,7 +783,7 @@ class StreamingFalsePositivesTest(test.TestCase):
fp, fp_update_op = metrics.streaming_false_positives(
predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, fp.eval())
self.assertEqual(4, fp_update_op.eval())
@@ -803,7 +803,7 @@ class StreamingFalsePositivesTest(test.TestCase):
weights=((1.0, 2.0, 3.0, 5.0), (7.0, 11.0, 13.0, 17.0), (19.0, 23.0,
29.0, 31.0)))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, fp.eval())
self.assertEqual(42.0, fp_update_op.eval())
@@ -837,7 +837,7 @@ class StreamingTrueNegativesTest(test.TestCase):
tn, tn_update_op = metrics.streaming_true_negatives(
predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, tn.eval())
self.assertEqual(5, tn_update_op.eval())
@@ -854,7 +854,7 @@ class StreamingTrueNegativesTest(test.TestCase):
tn, tn_update_op = metrics.streaming_true_negatives(
predictions, labels, weights=((0.0, 2.0, 3.0, 5.0),))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, tn.eval())
self.assertEqual(15.0, tn_update_op.eval())
@@ -879,7 +879,7 @@ class StreamingTruePositivesAtThresholdsTest(test.TestCase):
tp, tp_update_op = metrics.streaming_true_positives_at_thresholds(
predictions, labels, thresholds=(0.15, 0.5, 0.85))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0, 0, 0), tp.eval())
self.assertAllEqual((3, 1, 0), tp_update_op.eval())
@@ -892,7 +892,7 @@ class StreamingTruePositivesAtThresholdsTest(test.TestCase):
tp, tp_update_op = metrics.streaming_true_positives_at_thresholds(
predictions, labels, weights=37.0, thresholds=(0.15, 0.5, 0.85))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0.0, 0.0, 0.0), tp.eval())
self.assertAllEqual((111.0, 37.0, 0.0), tp_update_op.eval())
@@ -921,7 +921,7 @@ class StreamingFalseNegativesAtThresholdsTest(test.TestCase):
fn, fn_update_op = metrics.streaming_false_negatives_at_thresholds(
predictions, labels, thresholds=(0.15, 0.5, 0.85))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0, 0, 0), fn.eval())
self.assertAllEqual((0, 2, 3), fn_update_op.eval())
@@ -937,7 +937,7 @@ class StreamingFalseNegativesAtThresholdsTest(test.TestCase):
weights=((3.0,), (5.0,), (7.0,)),
thresholds=(0.15, 0.5, 0.85))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0.0, 0.0, 0.0), fn.eval())
self.assertAllEqual((0.0, 8.0, 11.0), fn_update_op.eval())
@@ -962,7 +962,7 @@ class StreamingFalsePositivesAtThresholdsTest(test.TestCase):
fp, fp_update_op = metrics.streaming_false_positives_at_thresholds(
predictions, labels, thresholds=(0.15, 0.5, 0.85))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0, 0, 0), fp.eval())
self.assertAllEqual((7, 4, 2), fp_update_op.eval())
@@ -979,7 +979,7 @@ class StreamingFalsePositivesAtThresholdsTest(test.TestCase):
29.0, 31.0)),
thresholds=(0.15, 0.5, 0.85))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0.0, 0.0, 0.0), fp.eval())
self.assertAllEqual((125.0, 42.0, 12.0), fp_update_op.eval())
@@ -1004,7 +1004,7 @@ class StreamingTrueNegativesAtThresholdsTest(test.TestCase):
tn, tn_update_op = metrics.streaming_true_negatives_at_thresholds(
predictions, labels, thresholds=(0.15, 0.5, 0.85))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0, 0, 0), tn.eval())
self.assertAllEqual((2, 5, 7), tn_update_op.eval())
@@ -1020,7 +1020,7 @@ class StreamingTrueNegativesAtThresholdsTest(test.TestCase):
weights=((0.0, 2.0, 3.0, 5.0),),
thresholds=(0.15, 0.5, 0.85))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0.0, 0.0, 0.0), tn.eval())
self.assertAllEqual((5.0, 15.0, 23.0), tn_update_op.eval())
@@ -1062,7 +1062,7 @@ class StreamingPrecisionTest(test.TestCase):
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
precision, update_op = metrics.streaming_precision(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -1081,7 +1081,7 @@ class StreamingPrecisionTest(test.TestCase):
labels = constant_op.constant(inputs)
precision, update_op = metrics.streaming_precision(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(1, sess.run(update_op))
self.assertAlmostEqual(1, precision.eval())
@@ -1091,7 +1091,7 @@ class StreamingPrecisionTest(test.TestCase):
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
precision, update_op = metrics.streaming_precision(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.5, update_op.eval())
self.assertAlmostEqual(0.5, precision.eval())
@@ -1102,7 +1102,7 @@ class StreamingPrecisionTest(test.TestCase):
precision, update_op = metrics.streaming_precision(
predictions, labels, weights=constant_op.constant([[2], [5]]))
- with self.test_session():
+ with self.cached_session():
variables.local_variables_initializer().run()
weighted_tp = 2.0 + 5.0
weighted_positives = (2.0 + 2.0) + (5.0 + 5.0)
@@ -1120,7 +1120,7 @@ class StreamingPrecisionTest(test.TestCase):
precision, update_op = metrics.streaming_precision(
predictions, labels, weights=constant_op.constant([[2], [5]]))
- with self.test_session():
+ with self.cached_session():
variables.local_variables_initializer().run()
weighted_tp = 2.0 + 5.0
weighted_positives = (2.0 + 2.0) + (5.0 + 5.0)
@@ -1138,7 +1138,7 @@ class StreamingPrecisionTest(test.TestCase):
labels,
weights=constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]]))
- with self.test_session():
+ with self.cached_session():
variables.local_variables_initializer().run()
weighted_tp = 3.0 + 4.0
weighted_positives = (1.0 + 3.0) + (4.0 + 2.0)
@@ -1158,7 +1158,7 @@ class StreamingPrecisionTest(test.TestCase):
labels,
weights=constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]]))
- with self.test_session():
+ with self.cached_session():
variables.local_variables_initializer().run()
weighted_tp = 3.0 + 4.0
weighted_positives = (1.0 + 3.0) + (4.0 + 2.0)
@@ -1175,7 +1175,7 @@ class StreamingPrecisionTest(test.TestCase):
labels = constant_op.constant(1 - inputs)
precision, update_op = metrics.streaming_precision(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertAlmostEqual(0, precision.eval())
@@ -1185,7 +1185,7 @@ class StreamingPrecisionTest(test.TestCase):
labels = constant_op.constant([0, 0, 0, 0])
precision, update_op = metrics.streaming_precision(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0.0, precision.eval())
@@ -1227,7 +1227,7 @@ class StreamingRecallTest(test.TestCase):
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
recall, update_op = metrics.streaming_recall(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -1246,7 +1246,7 @@ class StreamingRecallTest(test.TestCase):
labels = constant_op.constant(np_inputs)
recall, update_op = metrics.streaming_recall(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(1, recall.eval())
@@ -1256,7 +1256,7 @@ class StreamingRecallTest(test.TestCase):
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
recall, update_op = metrics.streaming_recall(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.5, update_op.eval())
self.assertAlmostEqual(0.5, recall.eval())
@@ -1268,7 +1268,7 @@ class StreamingRecallTest(test.TestCase):
recall, update_op = metrics.streaming_recall(
predictions, labels, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
weighted_tp = 2.0 + 5.0
weighted_t = (2.0 + 2.0) + (5.0 + 5.0)
@@ -1283,7 +1283,7 @@ class StreamingRecallTest(test.TestCase):
recall, update_op = metrics.streaming_recall(
predictions, labels, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
weighted_tp = 3.0 + 1.0
weighted_t = (2.0 + 3.0) + (4.0 + 1.0)
@@ -1298,7 +1298,7 @@ class StreamingRecallTest(test.TestCase):
labels = constant_op.constant(1 - np_inputs)
recall, update_op = metrics.streaming_recall(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0, recall.eval())
@@ -1308,7 +1308,7 @@ class StreamingRecallTest(test.TestCase):
labels = array_ops.zeros((1, 4))
recall, update_op = metrics.streaming_recall(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0, recall.eval())
@@ -1350,7 +1350,7 @@ class StreamingFPRTest(test.TestCase):
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -1369,7 +1369,7 @@ class StreamingFPRTest(test.TestCase):
labels = constant_op.constant(np_inputs)
fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0, fpr.eval())
@@ -1379,7 +1379,7 @@ class StreamingFPRTest(test.TestCase):
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.5, update_op.eval())
self.assertAlmostEqual(0.5, fpr.eval())
@@ -1391,7 +1391,7 @@ class StreamingFPRTest(test.TestCase):
fpr, update_op = metrics.streaming_false_positive_rate(
predictions, labels, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
weighted_fp = 2.0 + 5.0
weighted_f = (2.0 + 2.0) + (5.0 + 5.0)
@@ -1406,7 +1406,7 @@ class StreamingFPRTest(test.TestCase):
fpr, update_op = metrics.streaming_false_positive_rate(
predictions, labels, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
weighted_fp = 1.0 + 3.0
weighted_f = (1.0 + 4.0) + (2.0 + 3.0)
@@ -1421,7 +1421,7 @@ class StreamingFPRTest(test.TestCase):
labels = constant_op.constant(1 - np_inputs)
fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(1, fpr.eval())
@@ -1431,7 +1431,7 @@ class StreamingFPRTest(test.TestCase):
labels = array_ops.ones((1, 4))
fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0, fpr.eval())
@@ -1473,7 +1473,7 @@ class StreamingFNRTest(test.TestCase):
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -1492,7 +1492,7 @@ class StreamingFNRTest(test.TestCase):
labels = constant_op.constant(np_inputs)
fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0, fnr.eval())
@@ -1502,7 +1502,7 @@ class StreamingFNRTest(test.TestCase):
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.5, update_op.eval())
self.assertAlmostEqual(0.5, fnr.eval())
@@ -1514,7 +1514,7 @@ class StreamingFNRTest(test.TestCase):
fnr, update_op = metrics.streaming_false_negative_rate(
predictions, labels, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
weighted_fn = 2.0 + 5.0
weighted_t = (2.0 + 2.0) + (5.0 + 5.0)
@@ -1529,7 +1529,7 @@ class StreamingFNRTest(test.TestCase):
fnr, update_op = metrics.streaming_false_negative_rate(
predictions, labels, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
weighted_fn = 2.0 + 4.0
weighted_t = (2.0 + 3.0) + (1.0 + 4.0)
@@ -1544,7 +1544,7 @@ class StreamingFNRTest(test.TestCase):
labels = constant_op.constant(1 - np_inputs)
fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(1, fnr.eval())
@@ -1554,7 +1554,7 @@ class StreamingFNRTest(test.TestCase):
labels = array_ops.zeros((1, 4))
fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0, fnr.eval())
@@ -1599,7 +1599,7 @@ class StreamingCurvePointsTest(test.TestCase):
points, update_op = metric_ops.streaming_curve_points(
labels, predictions=predictions, curve=curve)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
@@ -1615,7 +1615,7 @@ class StreamingCurvePointsTest(test.TestCase):
self._testValueTensorIsIdempotent(curve='PR')
def _testCase(self, labels, predictions, curve, expected_points):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions_tensor = constant_op.constant(
predictions, dtype=dtypes_lib.float32)
labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.float32)
@@ -1717,7 +1717,7 @@ class StreamingAUCTest(test.TestCase):
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
auc, update_op = metrics.streaming_auc(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -1730,7 +1730,7 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(initial_auc, auc.eval(), 5)
def testPredictionsOutOfRange(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, -1, 1, -1], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -1744,7 +1744,7 @@ class StreamingAUCTest(test.TestCase):
def allCorrectAsExpected(self, curve):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(inputs)
auc, update_op = metrics.streaming_auc(predictions, labels, curve=curve)
@@ -1755,7 +1755,7 @@ class StreamingAUCTest(test.TestCase):
self.assertEqual(1, auc.eval())
def testSomeCorrect(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -1767,7 +1767,7 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(0.5, auc.eval())
def testWeighted1d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -1781,7 +1781,7 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(0.5, auc.eval(), 5)
def testWeighted2d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -1795,7 +1795,7 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(0.7, auc.eval(), 5)
def testAUCPRSpecialCase(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.1, 0.4, 0.35, 0.8], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 0, 1, 1], shape=(1, 4))
@@ -1807,7 +1807,7 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-3)
def testAnotherAUCPRSpecialCase(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81],
shape=(1, 7),
@@ -1821,7 +1821,7 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-3)
def testThirdAUCPRSpecialCase(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5],
shape=(1, 7),
@@ -1837,7 +1837,7 @@ class StreamingAUCTest(test.TestCase):
def testAllIncorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
auc, update_op = metrics.streaming_auc(predictions, labels)
@@ -1848,7 +1848,7 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(0, auc.eval())
def testZeroTruePositivesAndFalseNegativesGivesOneAUC(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
labels = array_ops.zeros([4])
auc, update_op = metrics.streaming_auc(predictions, labels)
@@ -1859,7 +1859,7 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(1, auc.eval(), 6)
def testRecallOneAndPrecisionOneGivesOnePRAUC(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = array_ops.ones([4], dtype=dtypes_lib.float32)
labels = array_ops.ones([4])
auc, update_op = metrics.streaming_auc(predictions, labels, curve='PR')
@@ -1893,7 +1893,7 @@ class StreamingAUCTest(test.TestCase):
np.random.exponential(scale=1.0, size=num_samples)):
expected_auc = _np_auc(predictions, labels, weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
enqueue_ops = [[] for i in range(num_batches)]
tf_predictions = _enqueue_as_batches(predictions, enqueue_ops)
tf_labels = _enqueue_as_batches(labels, enqueue_ops)
@@ -1966,7 +1966,7 @@ class StreamingDynamicAUCTest(test.TestCase):
labels = random_ops.random_uniform(
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
for _ in xrange(10):
@@ -1977,7 +1977,7 @@ class StreamingDynamicAUCTest(test.TestCase):
self.assertAlmostEqual(initial_auc, auc.eval(), 5)
def testAllLabelsOnes(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant([1., 1., 1.])
labels = constant_op.constant([1, 1, 1])
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
@@ -1986,7 +1986,7 @@ class StreamingDynamicAUCTest(test.TestCase):
self.assertEqual(0, auc.eval())
def testAllLabelsZeros(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant([1., 1., 1.])
labels = constant_op.constant([0, 0, 0])
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
@@ -1995,7 +1995,7 @@ class StreamingDynamicAUCTest(test.TestCase):
self.assertEqual(0, auc.eval())
def testNonZeroOnePredictions(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[2.5, -2.5, 2.5, -2.5], dtype=dtypes_lib.float32)
labels = constant_op.constant([1, 0, 1, 0])
@@ -2006,7 +2006,7 @@ class StreamingDynamicAUCTest(test.TestCase):
def testAllCorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs)
labels = constant_op.constant(inputs)
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
@@ -2015,7 +2015,7 @@ class StreamingDynamicAUCTest(test.TestCase):
self.assertEqual(1, auc.eval())
def testSomeCorrect(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant([1, 0, 1, 0])
labels = constant_op.constant([0, 1, 1, 0])
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
@@ -2025,7 +2025,7 @@ class StreamingDynamicAUCTest(test.TestCase):
def testAllIncorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
@@ -2034,7 +2034,7 @@ class StreamingDynamicAUCTest(test.TestCase):
self.assertAlmostEqual(0, auc.eval())
def testExceptionOnIncompatibleShapes(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = array_ops.ones([5])
labels = array_ops.zeros([6])
with self.assertRaisesRegexp(ValueError, 'Shapes .* are incompatible'):
@@ -2043,7 +2043,7 @@ class StreamingDynamicAUCTest(test.TestCase):
sess.run(update_op)
def testExceptionOnGreaterThanOneLabel(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant([1, 0.5, 0], dtypes_lib.float32)
labels = constant_op.constant([2, 1, 0])
_, update_op = metrics.streaming_dynamic_auc(labels, predictions)
@@ -2054,7 +2054,7 @@ class StreamingDynamicAUCTest(test.TestCase):
sess.run(update_op)
def testExceptionOnNegativeLabel(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant([1, 0.5, 0], dtypes_lib.float32)
labels = constant_op.constant([1, 0, -1])
_, update_op = metrics.streaming_dynamic_auc(labels, predictions)
@@ -2078,7 +2078,7 @@ class StreamingDynamicAUCTest(test.TestCase):
collections=[ops.GraphKeys.LOCAL_VARIABLES],
dtype=dtypes_lib.float32)
auc, update_op = metrics.streaming_dynamic_auc(tf_labels, tf_predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
for _ in xrange(num_batches):
new_labels = np.random.randint(0, 2, size=batch_size)
@@ -2093,7 +2093,7 @@ class StreamingDynamicAUCTest(test.TestCase):
self.assertAlmostEqual(expected_auc, auc.eval())
def testAUCPRReverseIncreasingPredictions(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.1, 0.4, 0.35, 0.8], dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 0, 1, 1])
@@ -2104,7 +2104,7 @@ class StreamingDynamicAUCTest(test.TestCase):
self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-5)
def testAUCPRJumbledPredictions(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81], dtypes_lib.float32)
labels = constant_op.constant([0, 0, 1, 0, 1, 0, 1])
@@ -2115,7 +2115,7 @@ class StreamingDynamicAUCTest(test.TestCase):
self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-6)
def testAUCPRPredictionsLessThanHalf(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5],
shape=(1, 7),
@@ -2148,7 +2148,7 @@ class StreamingDynamicAUCTest(test.TestCase):
auc, update_op = metrics.streaming_dynamic_auc(tf_labels,
tf_predictions,
weights=tf_weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
for _ in xrange(num_batches):
new_labels = np.random.randint(0, 2, size=batch_size)
@@ -2196,7 +2196,7 @@ class AucWithConfidenceIntervalsTest(test.TestCase):
expected_result: The expected result (dict) that maps to tensors.
weights: Optional weights tensor.
"""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions_tensor = constant_op.constant(
predictions, dtype=dtypes_lib.float32)
labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.int64)
@@ -2320,7 +2320,7 @@ class AucWithConfidenceIntervalsTest(test.TestCase):
dtype=dtypes_lib.float32)
auc, update_op = metrics.auc_with_confidence_intervals(tf_labels,
tf_predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
for _ in xrange(num_batches):
new_labels = np.random.randint(0, 2, size=batch_size)
@@ -2335,7 +2335,7 @@ class AucWithConfidenceIntervalsTest(test.TestCase):
self.assertAllClose(expected_auc, auc.auc.eval())
def testExceptionOnFloatLabels(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant([1, 0.5, 0, 1, 0], dtypes_lib.float32)
labels = constant_op.constant([0.7, 0, 1, 0, 1])
_, update_op = metrics.auc_with_confidence_intervals(labels, predictions)
@@ -2343,7 +2343,7 @@ class AucWithConfidenceIntervalsTest(test.TestCase):
self.assertRaises(TypeError, sess.run(update_op))
def testExceptionOnGreaterThanOneLabel(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant([1, 0.5, 0, 1, 0], dtypes_lib.float32)
labels = constant_op.constant([2, 1, 0, 1, 0])
_, update_op = metrics.auc_with_confidence_intervals(labels, predictions)
@@ -2354,7 +2354,7 @@ class AucWithConfidenceIntervalsTest(test.TestCase):
sess.run(update_op)
def testExceptionOnNegativeLabel(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant([1, 0.5, 0, 1, 0], dtypes_lib.float32)
labels = constant_op.constant([1, 0, -1, 1, 0])
_, update_op = metrics.auc_with_confidence_intervals(labels, predictions)
@@ -2415,7 +2415,7 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase):
result, update_op = metric_ops.precision_recall_at_equal_thresholds(
labels=labels, predictions=predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Run several updates.
sess.run(variables.local_variables_initializer())
for _ in range(3):
@@ -2448,7 +2448,7 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase):
default from assertAllClose.
weights: Optional weights tensor.
"""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions_tensor = constant_op.constant(predictions, dtype=dtype)
labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.bool)
weights_tensor = None
@@ -2621,7 +2621,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.streaming_specificity_at_sensitivity(
predictions, labels, sensitivity=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -2641,7 +2641,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.streaming_specificity_at_sensitivity(
predictions, labels, sensitivity=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1, sess.run(update_op))
self.assertEqual(1, specificity.eval())
@@ -2656,7 +2656,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.streaming_specificity_at_sensitivity(
predictions, labels, sensitivity=0.8)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(1.0, sess.run(update_op))
self.assertAlmostEqual(1.0, specificity.eval())
@@ -2671,7 +2671,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.streaming_specificity_at_sensitivity(
predictions, labels, sensitivity=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.6, sess.run(update_op))
@@ -2689,7 +2689,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.streaming_specificity_at_sensitivity(
predictions, labels, weights=weights, sensitivity=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.6, sess.run(update_op))
@@ -2707,7 +2707,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.streaming_specificity_at_sensitivity(
predictions, labels, weights=weights, sensitivity=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(8.0 / 15.0, sess.run(update_op))
@@ -2757,7 +2757,7 @@ class StreamingSensitivityAtSpecificityTest(test.TestCase):
sensitivity, update_op = metrics.streaming_sensitivity_at_specificity(
predictions, labels, specificity=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -2777,7 +2777,7 @@ class StreamingSensitivityAtSpecificityTest(test.TestCase):
specificity, update_op = metrics.streaming_sensitivity_at_specificity(
predictions, labels, specificity=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1, sess.run(update_op))
self.assertEqual(1, specificity.eval())
@@ -2792,7 +2792,7 @@ class StreamingSensitivityAtSpecificityTest(test.TestCase):
specificity, update_op = metrics.streaming_sensitivity_at_specificity(
predictions, labels, specificity=0.8)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.8, sess.run(update_op))
self.assertAlmostEqual(0.8, specificity.eval())
@@ -2807,7 +2807,7 @@ class StreamingSensitivityAtSpecificityTest(test.TestCase):
specificity, update_op = metrics.streaming_sensitivity_at_specificity(
predictions, labels, specificity=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.6, sess.run(update_op))
self.assertAlmostEqual(0.6, specificity.eval())
@@ -2824,7 +2824,7 @@ class StreamingSensitivityAtSpecificityTest(test.TestCase):
specificity, update_op = metrics.streaming_sensitivity_at_specificity(
predictions, labels, weights=weights, specificity=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.675, sess.run(update_op))
self.assertAlmostEqual(0.675, specificity.eval())
@@ -2887,7 +2887,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
rec, rec_op = metrics.streaming_recall_at_thresholds(
predictions, labels, thresholds)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -2905,7 +2905,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
def testAllCorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(inputs)
thresholds = [0.5]
@@ -2921,7 +2921,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
self.assertEqual(1, rec.eval())
def testSomeCorrect(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -2940,7 +2940,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
def testAllIncorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
thresholds = [0.5]
@@ -2956,7 +2956,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
self.assertAlmostEqual(0, rec.eval())
def testWeights1d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -2982,7 +2982,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
self.assertAlmostEqual(0.0, rec_high.eval(), places=5)
def testWeights2d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -3008,7 +3008,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
self.assertAlmostEqual(0.0, rec_high.eval(), places=5)
def testExtremeThresholds(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4))
@@ -3032,7 +3032,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
self.assertAlmostEqual(0.0, rec_high.eval())
def testZeroLabelsPredictions(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
labels = array_ops.zeros([4])
thresholds = [0.5]
@@ -3082,7 +3082,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
labels = labels.astype(np.float32)
predictions = predictions.astype(np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Reshape the data so it's easy to queue up:
predictions_batches = predictions.reshape((batch_size, num_batches))
labels_batches = labels.reshape((batch_size, num_batches))
@@ -3162,7 +3162,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds(
predictions, labels, thresholds)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -3177,7 +3177,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
def testAllCorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(inputs)
thresholds = [0.5]
@@ -3190,7 +3190,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
self.assertEqual(0, fpr.eval())
def testSomeCorrect(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -3206,7 +3206,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
def testAllIncorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
thresholds = [0.5]
@@ -3219,7 +3219,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
self.assertAlmostEqual(1, fpr.eval())
def testWeights1d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -3239,7 +3239,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
self.assertAlmostEqual(0.0, fpr_high.eval(), places=5)
def testWeights2d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -3259,7 +3259,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
self.assertAlmostEqual(0.0, fpr_high.eval(), places=5)
def testExtremeThresholds(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4))
@@ -3277,7 +3277,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
self.assertAlmostEqual(0.0, fpr_high.eval(), places=5)
def testZeroLabelsPredictions(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
labels = array_ops.zeros([4])
thresholds = [0.5]
@@ -3317,7 +3317,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
labels = labels.astype(np.float32)
predictions = predictions.astype(np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Reshape the data so it's easy to queue up:
predictions_batches = predictions.reshape((batch_size, num_batches))
labels_batches = labels.reshape((batch_size, num_batches))
@@ -3393,7 +3393,7 @@ class RecallAtPrecisionTest(test.TestCase):
recall, update_op = metrics.recall_at_precision(
labels, predictions, precision=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -3413,7 +3413,7 @@ class RecallAtPrecisionTest(test.TestCase):
recall, update_op = metrics.recall_at_precision(
labels, predictions, precision=1.0)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1, sess.run(update_op))
self.assertEqual(1, recall.eval())
@@ -3428,7 +3428,7 @@ class RecallAtPrecisionTest(test.TestCase):
recall, update_op = metrics.recall_at_precision(
labels, predictions, precision=0.8)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.8, sess.run(update_op))
self.assertAlmostEqual(0.8, recall.eval())
@@ -3443,7 +3443,7 @@ class RecallAtPrecisionTest(test.TestCase):
recall, update_op = metrics.recall_at_precision(
labels, predictions, precision=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
target_recall = 2.0 / 3.0
self.assertAlmostEqual(target_recall, sess.run(update_op))
@@ -3461,7 +3461,7 @@ class RecallAtPrecisionTest(test.TestCase):
recall, update_op = metrics.recall_at_precision(
labels, predictions, weights=weights, precision=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
target_recall = 2.0 / 3.0
self.assertAlmostEqual(target_recall, sess.run(update_op))
@@ -3486,7 +3486,7 @@ class RecallAtPrecisionTest(test.TestCase):
precision=target_precision,
strict_mode=strict_mode)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(expected_recall, sess.run(update_op))
self.assertAlmostEqual(expected_recall, recall.eval())
@@ -3565,7 +3565,7 @@ class PrecisionAtRecallTest(test.TestCase):
precision, update_op = metrics.precision_at_recall(
labels, predictions, target_recall=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -3585,7 +3585,7 @@ class PrecisionAtRecallTest(test.TestCase):
precision, update_op = metrics.precision_at_recall(
labels, predictions, target_recall=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1, sess.run(update_op))
self.assertEqual(1, precision.eval())
@@ -3599,7 +3599,7 @@ class PrecisionAtRecallTest(test.TestCase):
precision, update_op = metrics.precision_at_recall(
labels, predictions, target_recall=0.2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(sess.run(label_prior), sess.run(update_op))
self.assertEqual(sess.run(label_prior), precision.eval())
@@ -3614,7 +3614,7 @@ class PrecisionAtRecallTest(test.TestCase):
precision, update_op = metrics.precision_at_recall(
labels, predictions, target_recall=0.8)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.8, sess.run(update_op))
self.assertAlmostEqual(0.8, precision.eval())
@@ -3629,7 +3629,7 @@ class PrecisionAtRecallTest(test.TestCase):
precision, update_op = metrics.precision_at_recall(
labels, predictions, target_recall=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(2.0/3, sess.run(update_op))
self.assertAlmostEqual(2.0/3, precision.eval())
@@ -3648,7 +3648,7 @@ class PrecisionAtRecallTest(test.TestCase):
precision, update_op = metrics.precision_at_recall(
labels, predictions, target_recall=0.8, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(34.0/43, sess.run(update_op))
self.assertAlmostEqual(34.0/43, precision.eval())
@@ -3697,7 +3697,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds(
predictions, labels, thresholds)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -3712,7 +3712,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
def testAllCorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(inputs)
thresholds = [0.5]
@@ -3725,7 +3725,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
self.assertEqual(0, fnr.eval())
def testSomeCorrect(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -3741,7 +3741,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
def testAllIncorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
thresholds = [0.5]
@@ -3754,7 +3754,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
self.assertAlmostEqual(1, fnr.eval())
def testWeights1d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -3774,7 +3774,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
self.assertAlmostEqual(1.0, fnr_high.eval(), places=5)
def testWeights2d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -3794,7 +3794,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
self.assertAlmostEqual(1.0, fnr_high.eval(), places=5)
def testExtremeThresholds(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4))
@@ -3812,7 +3812,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
self.assertAlmostEqual(1.0, fnr_high.eval())
def testZeroLabelsPredictions(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
labels = array_ops.zeros([4])
thresholds = [0.5]
@@ -3852,7 +3852,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
labels = labels.astype(np.float32)
predictions = predictions.astype(np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Reshape the data so it's easy to queue up:
predictions_batches = predictions.reshape((batch_size, num_batches))
labels_batches = labels.reshape((batch_size, num_batches))
@@ -3940,7 +3940,7 @@ class StreamingRecallAtKTest(test.TestCase):
sp_recall, sp_update_op = metrics.streaming_sparse_recall_at_k(
predictions, array_ops.reshape(labels, (self._batch_size, 1)), k=1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0.25, sess.run(update_op))
self.assertEqual(0.25, recall.eval())
@@ -3958,7 +3958,7 @@ class StreamingRecallAtKTest(test.TestCase):
sp_recall, sp_update_op = metrics.streaming_sparse_recall_at_k(
predictions, array_ops.reshape(labels, (self._batch_size, 1)), k=2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0.5, sess.run(update_op))
self.assertEqual(0.5, recall.eval())
@@ -3976,7 +3976,7 @@ class StreamingRecallAtKTest(test.TestCase):
sp_recall, sp_update_op = metrics.streaming_sparse_recall_at_k(
predictions, array_ops.reshape(labels, (self._batch_size, 1)), k=3)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1.0, sess.run(update_op))
self.assertEqual(1.0, recall.eval())
@@ -4000,7 +4000,7 @@ class StreamingRecallAtKTest(test.TestCase):
k=2,
weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1.0, sess.run(update_op))
self.assertEqual(1.0, recall.eval())
@@ -4122,7 +4122,7 @@ class StreamingSparsePrecisionTest(test.TestCase):
self.assertAlmostEqual(expected, metric.eval())
def test_top_k_rank_invalid(self):
- with self.test_session():
+ with self.cached_session():
# top_k_predictions has rank < 2.
top_k_predictions = [9, 4, 6, 2, 0]
sp_labels = sparse_tensor.SparseTensorValue(
@@ -4669,7 +4669,7 @@ class StreamingSparsePrecisionTest(test.TestCase):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
labels = [[0, 0, 0, 1], [0, 0, 1, 0]]
expected_precision = 0.5
- with self.test_session():
+ with self.cached_session():
_, precision = metrics.streaming_sparse_precision_at_k(
predictions=constant_op.constant(predictions, dtypes_lib.float32),
labels=_binary_2d_label_to_sparse_value(labels),
@@ -5374,7 +5374,7 @@ class StreamingSparseRecallTest(test.TestCase):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
labels = [[0, 0, 1, 0], [0, 0, 0, 1]]
expected_recall = 0.5
- with self.test_session():
+ with self.cached_session():
_, recall = metrics.streaming_sparse_recall_at_k(
predictions=constant_op.constant(predictions, dtypes_lib.float32),
labels=_binary_2d_label_to_sparse_value(labels),
@@ -5418,7 +5418,7 @@ class StreamingMeanAbsoluteErrorTest(test.TestCase):
error, update_op = metrics.streaming_mean_absolute_error(
predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -5440,7 +5440,7 @@ class StreamingMeanAbsoluteErrorTest(test.TestCase):
error, update_op = metrics.streaming_mean_absolute_error(
predictions, labels, weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(3, sess.run(update_op))
self.assertEqual(3, error.eval())
@@ -5484,7 +5484,7 @@ class StreamingMeanRelativeErrorTest(test.TestCase):
error, update_op = metrics.streaming_mean_relative_error(
predictions, labels, normalizer)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -5509,7 +5509,7 @@ class StreamingMeanRelativeErrorTest(test.TestCase):
error, update_op = metrics.streaming_mean_relative_error(
predictions, labels, normalizer=labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(expected_error, sess.run(update_op))
self.assertEqual(expected_error, error.eval())
@@ -5525,7 +5525,7 @@ class StreamingMeanRelativeErrorTest(test.TestCase):
error, update_op = metrics.streaming_mean_relative_error(
predictions, labels, normalizer=array_ops.zeros_like(labels))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0.0, sess.run(update_op))
self.assertEqual(0.0, error.eval())
@@ -5563,7 +5563,7 @@ class StreamingMeanSquaredErrorTest(test.TestCase):
labels = random_ops.random_normal((10, 3), seed=2)
error, update_op = metrics.streaming_mean_squared_error(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -5581,7 +5581,7 @@ class StreamingMeanSquaredErrorTest(test.TestCase):
error, update_op = metrics.streaming_mean_squared_error(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, sess.run(update_op))
self.assertEqual(0, error.eval())
@@ -5594,7 +5594,7 @@ class StreamingMeanSquaredErrorTest(test.TestCase):
error, update_op = metrics.streaming_mean_squared_error(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(6, sess.run(update_op))
self.assertEqual(6, error.eval())
@@ -5609,13 +5609,13 @@ class StreamingMeanSquaredErrorTest(test.TestCase):
error, update_op = metrics.streaming_mean_squared_error(
predictions, labels, weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(13, sess.run(update_op))
self.assertEqual(13, error.eval())
def testMultipleBatchesOfSizeOne(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
2, dtypes=dtypes_lib.float32, shapes=(1, 3))
@@ -5640,7 +5640,7 @@ class StreamingMeanSquaredErrorTest(test.TestCase):
self.assertAlmostEqual(208.0 / 6, error.eval(), 5)
def testMetricsComputedConcurrently(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates one set of predictions.
preds_queue0 = data_flow_ops.FIFOQueue(
2, dtypes=dtypes_lib.float32, shapes=(1, 3))
@@ -5683,7 +5683,7 @@ class StreamingMeanSquaredErrorTest(test.TestCase):
self.assertAlmostEqual(79.0 / 6, mse1, 5)
def testMultipleMetricsOnMultipleBatchesOfSizeOne(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
2, dtypes=dtypes_lib.float32, shapes=(1, 3))
@@ -5745,7 +5745,7 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase):
error, update_op = metrics.streaming_root_mean_squared_error(
predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -5758,7 +5758,7 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase):
self.assertEqual(initial_error, error.eval())
def testSingleUpdateZeroError(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
0.0, shape=(1, 3), dtype=dtypes_lib.float32)
labels = constant_op.constant(0.0, shape=(1, 3), dtype=dtypes_lib.float32)
@@ -5772,7 +5772,7 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase):
self.assertEqual(0, rmse.eval())
def testSingleUpdateWithError(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[2, 4, 6], shape=(1, 3), dtype=dtypes_lib.float32)
labels = constant_op.constant(
@@ -5786,7 +5786,7 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase):
self.assertAlmostEqual(math.sqrt(6), rmse.eval(), 5)
def testSingleUpdateWithErrorAndWeights(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant(
@@ -5842,7 +5842,7 @@ class StreamingCovarianceTest(test.TestCase):
predictions = labels * 0.5 + random_ops.random_normal((10, 3), seed=1) * 0.5
cov, update_op = metrics.streaming_covariance(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -5855,7 +5855,7 @@ class StreamingCovarianceTest(test.TestCase):
self.assertEqual(initial_cov, cov.eval())
def testSingleUpdateIdentical(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = math_ops.to_float(math_ops.range(10))
labels = math_ops.to_float(math_ops.range(10))
@@ -5867,7 +5867,7 @@ class StreamingCovarianceTest(test.TestCase):
self.assertAlmostEqual(expected_cov, cov.eval(), 5)
def testSingleUpdateNonIdentical(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[2, 4, 6], shape=(1, 3), dtype=dtypes_lib.float32)
labels = constant_op.constant(
@@ -5881,7 +5881,7 @@ class StreamingCovarianceTest(test.TestCase):
self.assertAlmostEqual(expected_cov, cov.eval())
def testSingleUpdateWithErrorAndWeights(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant(
@@ -5899,7 +5899,7 @@ class StreamingCovarianceTest(test.TestCase):
self.assertAlmostEqual(expected_cov, cov.eval())
def testMultiUpdateWithErrorNoWeights(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
np.random.seed(123)
n = 100
predictions = np.random.randn(n)
@@ -5933,7 +5933,7 @@ class StreamingCovarianceTest(test.TestCase):
prev_expected_cov = expected_cov
def testMultiUpdateWithErrorAndWeights(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
np.random.seed(123)
n = 100
predictions = np.random.randn(n)
@@ -6023,7 +6023,7 @@ class StreamingPearsonRTest(test.TestCase):
pearson_r, update_op = metrics.streaming_pearson_correlation(
predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -6036,7 +6036,7 @@ class StreamingPearsonRTest(test.TestCase):
self.assertEqual(initial_r, pearson_r.eval())
def testSingleUpdateIdentical(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = math_ops.to_float(math_ops.range(10))
labels = math_ops.to_float(math_ops.range(10))
@@ -6049,7 +6049,7 @@ class StreamingPearsonRTest(test.TestCase):
self.assertAlmostEqual(expected_r, pearson_r.eval(), 5)
def testSingleUpdateNonIdentical(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[2, 4, 6], shape=(1, 3), dtype=dtypes_lib.float32)
labels = constant_op.constant(
@@ -6064,7 +6064,7 @@ class StreamingPearsonRTest(test.TestCase):
self.assertAlmostEqual(expected_r, pearson_r.eval())
def testSingleUpdateWithErrorAndWeights(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = np.array([2, 4, 6, 8])
labels = np.array([1, 3, 2, 7])
weights = np.array([0, 1, 3, 1])
@@ -6085,7 +6085,7 @@ class StreamingPearsonRTest(test.TestCase):
self.assertAlmostEqual(expected_r, pearson_r.eval())
def testMultiUpdateWithErrorNoWeights(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
np.random.seed(123)
n = 100
predictions = np.random.randn(n)
@@ -6120,7 +6120,7 @@ class StreamingPearsonRTest(test.TestCase):
prev_expected_r = expected_r
def testMultiUpdateWithErrorAndWeights(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
np.random.seed(123)
n = 100
predictions = np.random.randn(n)
@@ -6162,7 +6162,7 @@ class StreamingPearsonRTest(test.TestCase):
prev_expected_r = expected_r
def testMultiUpdateWithErrorAndSingletonBatches(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
np.random.seed(123)
n = 100
predictions = np.random.randn(n)
@@ -6243,7 +6243,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase):
error, update_op = metrics.streaming_mean_cosine_distance(
predictions, labels, dim=1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -6266,7 +6266,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase):
error, update_op = metrics.streaming_mean_cosine_distance(
predictions, labels, dim=2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, sess.run(update_op))
self.assertEqual(0, error.eval())
@@ -6283,7 +6283,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase):
error, update_op = metrics.streaming_mean_cosine_distance(
predictions, labels, dim=2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(1, sess.run(update_op), 5)
self.assertAlmostEqual(1, error.eval(), 5)
@@ -6305,7 +6305,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase):
error, update_op = metrics.streaming_mean_cosine_distance(
predictions, labels, dim=2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(1.0, sess.run(update_op), 5)
self.assertAlmostEqual(1.0, error.eval(), 5)
@@ -6324,7 +6324,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase):
error, update_op = metrics.streaming_mean_cosine_distance(
predictions, labels, dim=2, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, sess.run(update_op))
self.assertEqual(0, error.eval())
@@ -6343,7 +6343,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase):
error, update_op = metrics.streaming_mean_cosine_distance(
predictions, labels, dim=2, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1.5, update_op.eval())
self.assertEqual(1.5, error.eval())
@@ -6378,7 +6378,7 @@ class PcntBelowThreshTest(test.TestCase):
self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
def testOneUpdate(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = constant_op.constant(
[2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
@@ -6398,7 +6398,7 @@ class PcntBelowThreshTest(test.TestCase):
self.assertAlmostEqual(0.0, pcnt2, 5)
def testSomePresentOneUpdate(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = constant_op.constant(
[2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
weights = constant_op.constant(
@@ -6475,7 +6475,7 @@ class StreamingMeanIOUTest(test.TestCase):
miou, update_op = metrics.streaming_mean_iou(
predictions, labels, num_classes=num_classes)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -6489,7 +6489,7 @@ class StreamingMeanIOUTest(test.TestCase):
def testMultipleUpdates(self):
num_classes = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
5, dtypes=dtypes_lib.int32, shapes=(1, 1))
@@ -6521,7 +6521,7 @@ class StreamingMeanIOUTest(test.TestCase):
def testMultipleUpdatesWithWeights(self):
num_classes = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
6, dtypes=dtypes_lib.int32, shapes=(1, 1))
@@ -6569,7 +6569,7 @@ class StreamingMeanIOUTest(test.TestCase):
# one class, and thus there is one row and one column with
# zero entries in the confusion matrix.
num_classes = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
# There is no prediction for class 2.
preds_queue = data_flow_ops.FIFOQueue(
@@ -6611,7 +6611,7 @@ class StreamingMeanIOUTest(test.TestCase):
constant_op.constant(1, shape=[7])
], 0)
num_classes = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.streaming_mean_iou(predictions, labels,
num_classes)
sess.run(variables.local_variables_initializer())
@@ -6624,7 +6624,7 @@ class StreamingMeanIOUTest(test.TestCase):
predictions = array_ops.zeros([40])
labels = array_ops.zeros([40])
num_classes = 1
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.streaming_mean_iou(predictions, labels,
num_classes)
sess.run(variables.local_variables_initializer())
@@ -6635,7 +6635,7 @@ class StreamingMeanIOUTest(test.TestCase):
predictions = array_ops.zeros([40])
labels = array_ops.ones([40])
num_classes = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.streaming_mean_iou(predictions, labels,
num_classes)
sess.run(variables.local_variables_initializer())
@@ -6657,7 +6657,7 @@ class StreamingMeanIOUTest(test.TestCase):
constant_op.constant(1, shape=[8]),
constant_op.constant(0, shape=[1])
], 0)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.streaming_mean_iou(
predictions, labels, num_classes, weights=weights)
sess.run(variables.local_variables_initializer())
@@ -6672,7 +6672,7 @@ class StreamingMeanIOUTest(test.TestCase):
[[[0, 0, 2, 1, 1, 0], [0, 1, 2, 2, 0, 1]], [[0, 0, 2, 1, 1, 1],
[1, 1, 2, 0, 0, 0]]])
num_classes = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.streaming_mean_iou(predictions, labels,
num_classes)
sess.run(variables.local_variables_initializer())
@@ -6684,7 +6684,7 @@ class StreamingMeanIOUTest(test.TestCase):
labels = constant_op.constant([0])
predictions = constant_op.constant([0])
num_classes = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.streaming_mean_iou(predictions, labels,
num_classes)
sess.run(variables.local_variables_initializer())
@@ -6698,7 +6698,7 @@ class StreamingMeanIOUTest(test.TestCase):
[[[0, 0, 1, 1, 0, 0], [1, 1, 0, 0, 1, 1]], [[0, 0, 0, 1, 1, 1],
[1, 1, 1, 0, 0, 0]]])
num_classes = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.streaming_mean_iou(predictions, labels,
num_classes)
sess.run(variables.local_variables_initializer())
@@ -6733,7 +6733,7 @@ class StreamingConcatTest(test.TestCase):
def testNextArraySize(self):
next_array_size = metric_ops._next_array_size # pylint: disable=protected-access
- with self.test_session():
+ with self.cached_session():
self.assertEqual(next_array_size(2, growth_factor=2).eval(), 2)
self.assertEqual(next_array_size(3, growth_factor=2).eval(), 4)
self.assertEqual(next_array_size(4, growth_factor=2).eval(), 4)
@@ -6741,7 +6741,7 @@ class StreamingConcatTest(test.TestCase):
self.assertEqual(next_array_size(6, growth_factor=2).eval(), 8)
def testStreamingConcat(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = array_ops.placeholder(dtypes_lib.int32, [None])
concatenated, update_op = metrics.streaming_concat(values)
sess.run(variables.local_variables_initializer())
@@ -6758,7 +6758,7 @@ class StreamingConcatTest(test.TestCase):
self.assertAllEqual(np.arange(10), concatenated.eval())
def testStreamingConcatStringValues(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = array_ops.placeholder(dtypes_lib.string, [None])
concatenated, update_op = metrics.streaming_concat(values)
sess.run(variables.local_variables_initializer())
@@ -6777,7 +6777,7 @@ class StreamingConcatTest(test.TestCase):
concatenated.eval())
def testStreamingConcatMaxSize(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = math_ops.range(3)
concatenated, update_op = metrics.streaming_concat(values, max_size=5)
sess.run(variables.local_variables_initializer())
@@ -6794,7 +6794,7 @@ class StreamingConcatTest(test.TestCase):
self.assertAllEqual([0, 1, 2, 0, 1], concatenated.eval())
def testStreamingConcat2D(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = array_ops.reshape(math_ops.range(3), (3, 1))
concatenated, update_op = metrics.streaming_concat(values, axis=-1)
sess.run(variables.local_variables_initializer())
@@ -6817,7 +6817,7 @@ class StreamingConcatTest(test.TestCase):
array_ops.placeholder(dtypes_lib.float32, [None, None]))
def testStreamingConcatReset(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = array_ops.placeholder(dtypes_lib.int32, [None])
concatenated, update_op = metrics.streaming_concat(values)
sess.run(variables.local_variables_initializer())
@@ -6845,7 +6845,7 @@ class AggregateMetricsTest(test.TestCase):
metrics.streaming_mean(values))
self.assertEqual(len(value_tensors), 1)
self.assertEqual(len(update_ops), 1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1, update_ops[0].eval())
self.assertEqual(1, value_tensors[0].eval())
@@ -6858,7 +6858,7 @@ class AggregateMetricsTest(test.TestCase):
metrics.streaming_mean_squared_error(predictions, labels))
self.assertEqual(len(value_tensors), 2)
self.assertEqual(len(update_ops), 2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(2, update_ops[0].eval())
self.assertEqual(4, update_ops[1].eval())
@@ -6879,7 +6879,7 @@ class AggregateMetricMapTest(test.TestCase):
self.assertEqual(2, len(names_to_values))
self.assertEqual(2, len(names_to_updates))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(2, names_to_updates['m1'].eval())
self.assertEqual(4, names_to_updates['m2'].eval())
@@ -6914,7 +6914,7 @@ class CountTest(test.TestCase):
self.assertTrue(isinstance(op, ops.Operation) or isinstance(op, ops.Tensor))
def testBasic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
_enqueue_vector(sess, values_queue, [0, 1])
@@ -6931,7 +6931,7 @@ class CountTest(test.TestCase):
self.assertAlmostEqual(8.0, sess.run(result), 5)
def testUpdateOpsReturnsCurrentValue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
_enqueue_vector(sess, values_queue, [0, 1])
@@ -6952,7 +6952,7 @@ class CountTest(test.TestCase):
self.assertAlmostEqual(8.0, sess.run(result), 5)
def test1dWeightedValues(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -6979,7 +6979,7 @@ class CountTest(test.TestCase):
self.assertAlmostEqual(3.4, result.eval(), 5)
def test1dWeightedValues_placeholders(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0))
values = array_ops.placeholder(dtype=dtypes_lib.float32)
@@ -7001,7 +7001,7 @@ class CountTest(test.TestCase):
self.assertAlmostEqual(3.4, result.eval(), 5)
def test2dWeightedValues(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -7028,7 +7028,7 @@ class CountTest(test.TestCase):
self.assertAlmostEqual(4.1, result.eval(), 5)
def test2dWeightedValues_placeholders(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0))
values = array_ops.placeholder(dtype=dtypes_lib.float32)
@@ -7101,7 +7101,7 @@ class CohenKappaTest(test.TestCase):
(10, 1), maxval=3, dtype=dtypes_lib.int64, seed=2)
kappa, update_op = metrics.cohen_kappa(labels, predictions, 3)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -7135,7 +7135,7 @@ class CohenKappaTest(test.TestCase):
for dtype in dtypes:
for shape in shapes:
for weight in weights:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions_tensor = constant_op.constant(
np.reshape(predictions, shape), dtype=dtype)
labels_tensor = constant_op.constant(
@@ -7156,7 +7156,7 @@ class CohenKappaTest(test.TestCase):
# Calculated by v0.19: sklearn.metrics.cohen_kappa_score(inputs, inputs)
expect = 1.0
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(inputs)
kappa, update_op = metrics.cohen_kappa(labels, predictions, 4)
@@ -7175,7 +7175,7 @@ class CohenKappaTest(test.TestCase):
# Calculated by v0.19: sklearn.metrics.cohen_kappa_score(labels, predictions)
expect = -0.333333333333
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(predictions, dtype=dtypes_lib.float32)
labels = constant_op.constant(labels)
kappa, update_op = metrics.cohen_kappa(labels, predictions, 4)
@@ -7193,7 +7193,7 @@ class CohenKappaTest(test.TestCase):
# labels, predictions, sample_weight=weights)
expect = 0.453466583385
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(predictions, dtype=dtypes_lib.float32)
labels = constant_op.constant(labels)
kappa, update_op = metrics.cohen_kappa(
@@ -7218,7 +7218,7 @@ class CohenKappaTest(test.TestCase):
weights_t = array_ops.placeholder(dtypes_lib.float32, shape=(batch_size,))
kappa, update_op = metrics.cohen_kappa(
labels_t, predictions_t, num_classes, weights=weights_t)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
for idx in range(0, num_samples, batch_size):
@@ -7256,7 +7256,7 @@ class CohenKappaTest(test.TestCase):
def testConditionalPackingOptimization(self):
placeholder = array_ops.placeholder(dtypes_lib.float32, [None])
values, update_op = metric_ops.streaming_concat(placeholder)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
for feed in range(10):
sess.run(update_op, feed_dict={placeholder: [feed]})
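A minimal sketch, assuming TensorFlow 1.x graph mode and the tf.metrics API, of the pattern these metric-test hunks migrate to: cached_session() hands back one session that is reused across calls within a test case, replacing the deprecated test_session(). The test class and values below are hypothetical.

import tensorflow as tf


class ExampleStreamingMeanTest(tf.test.TestCase):

  def testSingleUpdate(self):
    values = tf.constant([1.0, 2.0, 3.0, 4.0])
    mean, update_op = tf.metrics.mean(values)
    with self.cached_session() as sess:
      # Streaming metrics keep their running totals in local variables.
      sess.run(tf.local_variables_initializer())
      self.assertAlmostEqual(2.5, sess.run(update_op))
      self.assertAlmostEqual(2.5, sess.run(mean))


if __name__ == "__main__":
  tf.test.main()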
diff --git a/tensorflow/contrib/model_pruning/python/layers/rnn_cells_test.py b/tensorflow/contrib/model_pruning/python/layers/rnn_cells_test.py
index e85ae7b22a..586c6c7bfc 100644
--- a/tensorflow/contrib/model_pruning/python/layers/rnn_cells_test.py
+++ b/tensorflow/contrib/model_pruning/python/layers/rnn_cells_test.py
@@ -37,7 +37,7 @@ class RnnCellsTest(test.TestCase):
expected_num_masks = 1
expected_num_rows = 2 * self.dim
expected_num_cols = 4 * self.dim
- with self.test_session():
+ with self.cached_session():
inputs = variables.Variable(
random_ops.random_normal([self.batch_size, self.dim]))
c = variables.Variable(
@@ -61,7 +61,7 @@ class RnnCellsTest(test.TestCase):
expected_num_masks = 1
expected_num_rows = 2 * self.dim
expected_num_cols = 4 * self.dim
- with self.test_session():
+ with self.cached_session():
inputs = variables.Variable(
random_ops.random_normal([self.batch_size, self.dim]))
c = variables.Variable(
diff --git a/tensorflow/contrib/nccl/BUILD b/tensorflow/contrib/nccl/BUILD
index 62996d1fd8..9a9d480260 100644
--- a/tensorflow/contrib/nccl/BUILD
+++ b/tensorflow/contrib/nccl/BUILD
@@ -31,9 +31,11 @@ tf_custom_op_library(
"kernels/nccl_manager.h",
"kernels/nccl_ops.cc",
]),
- deps = if_cuda([
+ deps = [] + if_cuda([
"@local_config_nccl//:nccl",
"//tensorflow/core:gpu_headers_lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:protos_all_proto_text",
]),
)
@@ -57,32 +59,31 @@ tf_cuda_cc_test(
"notap",
],
deps =
- [
+ if_cuda([
+ "@local_config_nccl//:nccl",
"//tensorflow/core:cuda",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
- "@local_config_nccl//:nccl",
- ],
+ ]),
)
tf_kernel_library(
name = "nccl_kernels",
- srcs = [
+ srcs = if_cuda([
"kernels/nccl_manager.cc",
"kernels/nccl_manager.h",
"kernels/nccl_ops.cc",
"kernels/nccl_rewrite.cc",
- ],
- deps = [
+ ]),
+ deps = if_cuda([
+ "@local_config_nccl//:nccl",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:gpu_headers_lib",
"//tensorflow/core:lib",
- "//tensorflow/core:proto_text",
"//tensorflow/core:stream_executor",
- "@local_config_nccl//:nccl",
- ],
+ ]),
alwayslink = 1,
)
diff --git a/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc b/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc
index 4676e937e5..06ff86e6d8 100644
--- a/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc
+++ b/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/node_builder.h"
namespace tensorflow {
diff --git a/tensorflow/contrib/nearest_neighbor/python/kernel_tests/hyperplane_lsh_probes_test.py b/tensorflow/contrib/nearest_neighbor/python/kernel_tests/hyperplane_lsh_probes_test.py
index cb69c72970..d0955cbe11 100644
--- a/tensorflow/contrib/nearest_neighbor/python/kernel_tests/hyperplane_lsh_probes_test.py
+++ b/tensorflow/contrib/nearest_neighbor/python/kernel_tests/hyperplane_lsh_probes_test.py
@@ -31,7 +31,7 @@ class HyperplaneLshProbesTest(test.TestCase):
# tests in hyperplane_lsh_probes_test.cc already cover most of the LSH
# functionality.
def simple_batch_test(self):
- with self.test_session():
+ with self.cached_session():
hyperplanes = np.eye(4)
points = np.array([[1.2, 0.5, -0.9, -1.0], [2.0, -3.0, 1.0, -1.5]])
product = np.dot(points, hyperplanes)
diff --git a/tensorflow/contrib/optimizer_v2/adagrad.py b/tensorflow/contrib/optimizer_v2/adagrad.py
index c333d1e089..25ec475499 100644
--- a/tensorflow/contrib/optimizer_v2/adagrad.py
+++ b/tensorflow/contrib/optimizer_v2/adagrad.py
@@ -64,18 +64,17 @@ class AdagradOptimizer(optimizer_v2.OptimizerV2):
def _create_vars(self, var_list, state):
for v in var_list:
- # TODO(isaprykin): Delete colocate_with(v) from other optimizers and
- # confirm that colocation will happen anyway.
dtype = v.dtype.base_dtype
if v.get_shape().is_fully_defined():
init = init_ops.constant_initializer(self._initial_accumulator_value,
dtype=dtype)
else:
- # Use a Tensor instead of initializer if variable does not have static
- # shape.
- init_constant = gen_array_ops.fill(
- array_ops.shape(v), self._initial_accumulator_value)
- init = math_ops.cast(init_constant, dtype)
+ def init(v=v, dtype=dtype):
+ # Use a Tensor instead of initializer if variable does not have
+ # static shape.
+ init_constant = gen_array_ops.fill(array_ops.shape(v),
+ self._initial_accumulator_value)
+ return math_ops.cast(init_constant, dtype)
state.create_slot_with_initializer(v, init, v.get_shape(), dtype,
"accumulator")
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
index f6ecaba834..6af59dcfbf 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
@@ -214,7 +214,8 @@ class _OptimizerV2State(object):
# with that Tensor cast to that dtype.
with ops.init_scope():
self._hyper = {name: {None: ops.convert_to_tensor(value, name=name)}
- for name, (dynamic, value) in hyper.items() if not dynamic}
+ for name, (dynamic, value) in sorted(hyper.items())
+ if not dynamic}
self._slots = {}
self._non_slot_dict = {}
# Extra state to help Optimizers implement Checkpointable. Holds information
@@ -231,7 +232,8 @@ class _OptimizerV2State(object):
ret._deferred_dependencies = self._deferred_dependencies
ret._deferred_slot_restorations = self._deferred_slot_restorations
ret._hyper = {name: {None: _resolve(value, name)}
- for name, (dynamic, value) in hyper.items() if dynamic}
+ for name, (dynamic, value) in sorted(hyper.items())
+ if dynamic}
ret._hyper.update(self._hyper)
ret._non_slot_devices = non_slot_devices
ret._distribution = distribution
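The optimizer_v2.py hunks only add sorted() around hyper.items(); a small illustrative example (hypothetical hyperparameter values) of the effect:

# Iterating in sorted-key order fixes the order in which the static
# hyperparameters are converted to tensors, so graph construction is
# deterministic regardless of how the dict was populated.
hyper = {"learning_rate": (False, 0.01), "epsilon": (False, 1e-8),
         "beta1": (True, 0.9)}
static_hyper = {name: value
                for name, (dynamic, value) in sorted(hyper.items())
                if not dynamic}
# On Python 3.7+ the comprehension preserves this insertion order:
assert list(static_hyper) == ["epsilon", "learning_rate"]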
diff --git a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
index 31a6fe1d94..9a19502276 100644
--- a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
+++ b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
@@ -38,7 +38,7 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
desired_shape = numpy.array([6, None])
output_tensor = input_tensor.reshape((6, 2))
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
result = periodic_resample(input_tensor, desired_shape).eval()
self.assertAllEqual(result, output_tensor)
@@ -49,7 +49,7 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
desired_shape = numpy.array([5, None])
output_tensor = input_tensor.reshape((6, 2))[:-1]
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
result = periodic_resample(input_tensor, desired_shape).eval()
self.assertAllEqual(result, output_tensor)
@@ -63,7 +63,7 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
[15]]])
# NOTE: output_tensor != input_tensor.reshape((4, 4, -1))
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
result = periodic_resample(input_tensor, desired_shape).eval()
# input_tensor[0, 0, 0] == result[0, 0, 0]
@@ -88,14 +88,14 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
[[49], [53], [57], [61]], [[51], [55], [59], [63]]]])
# NOTE: output_tensor != input_tensor.reshape((4, 4, 4, -1))
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
result = periodic_resample(input_tensor, desired_shape).eval()
self.assertAllEqual(result, output_tensor)
def testPeriodicResampleErrors(self):
input_tensor = numpy.zeros(shape=[1, 2, 2, 4])
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesWithPredicateMatch(
errors_impl.InvalidArgumentError,
'Dimension 3 input tensor has size 4, desired shape has size 1'):
@@ -109,7 +109,7 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
desired_shape = numpy.array([4, 4, None])
result_shape = (4, 4, 1)
input_shape = (2, 2, 4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.placeholder(dtypes.float32, shape=input_shape)
output = periodic_resample(x, desired_shape)
error = gradient_checker.compute_gradient_error(
@@ -117,7 +117,7 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
self.assertLess(error, 1e-4)
def testPeriodicResampleShapeInference(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Case 1: output shape can be fully inferred.
x = array_ops.placeholder(dtypes.float32, shape=(2, 2, 4))
output = periodic_resample(x, [4, 4, None])
diff --git a/tensorflow/contrib/predictor/saved_model_predictor.py b/tensorflow/contrib/predictor/saved_model_predictor.py
index 95da6d04ed..03399396df 100644
--- a/tensorflow/contrib/predictor/saved_model_predictor.py
+++ b/tensorflow/contrib/predictor/saved_model_predictor.py
@@ -23,7 +23,6 @@ import logging
from tensorflow.contrib.predictor import predictor
from tensorflow.contrib.saved_model.python.saved_model import reader
-from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.saved_model import loader
@@ -68,23 +67,19 @@ def _get_signature_def(signature_def_key, export_dir, tags):
metagraph_def = get_meta_graph_def(export_dir, tags)
try:
- signature_def = signature_def_utils.get_signature_def_by_key(
- metagraph_def,
+ signature_def = metagraph_def.signature_def[signature_def_key]
+ except KeyError as e:
+ formatted_key = _DEFAULT_INPUT_ALTERNATIVE_FORMAT.format(
signature_def_key)
- except ValueError as e:
try:
- formatted_key = _DEFAULT_INPUT_ALTERNATIVE_FORMAT.format(
- signature_def_key)
- signature_def = signature_def_utils.get_signature_def_by_key(
- metagraph_def, formatted_key)
-
- logging.warning('Could not find signature def "%s". '
- 'Using "%s" instead', signature_def_key, formatted_key)
- except ValueError:
+ signature_def = metagraph_def.signature_def[formatted_key]
+ except KeyError:
raise ValueError(
'Got signature_def_key "{}". Available signatures are {}. '
'Original error:\n{}'.format(
signature_def_key, list(metagraph_def.signature_def), e))
+ logging.warning('Could not find signature def "%s". '
+ 'Using "%s" instead', signature_def_key, formatted_key)
return signature_def
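The saved_model_predictor hunk above drops the contrib get_signature_def_by_key helper in favor of reading the MetaGraphDef.signature_def map directly. A minimal sketch of that lookup (lookup_signature_def and the "serving_default" key are illustrative):

from tensorflow.core.protobuf import meta_graph_pb2


def lookup_signature_def(meta_graph_def, key):
  # Direct protobuf map access; mirrors the ValueError the removed helper raised.
  if key not in meta_graph_def.signature_def:
    raise ValueError("No SignatureDef with key '%s' found in MetaGraphDef." %
                     key)
  return meta_graph_def.signature_def[key]


# Hypothetical usage: accessing a message-valued map entry creates it.
mgd = meta_graph_pb2.MetaGraphDef()
mgd.signature_def["serving_default"].method_name = "tensorflow/serving/predict"
print(lookup_signature_def(mgd, "serving_default").method_name)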
diff --git a/tensorflow/contrib/recurrent/python/kernel_tests/recurrent_test.py b/tensorflow/contrib/recurrent/python/kernel_tests/recurrent_test.py
index 00fbd4fbb8..aea80a5256 100644
--- a/tensorflow/contrib/recurrent/python/kernel_tests/recurrent_test.py
+++ b/tensorflow/contrib/recurrent/python/kernel_tests/recurrent_test.py
@@ -56,7 +56,7 @@ class RecurrentTest(test_util.TensorFlowTestCase):
x_power=state.x_power * theta.x)
return next_state, []
- with self.test_session() as sess:
+ with self.cached_session() as sess:
theta = _PolyTheta(x=array_ops.constant(2.0))
state = _PolyState(
value=array_ops.constant(0.0),
@@ -142,7 +142,7 @@ class RecurrentTest(test_util.TensorFlowTestCase):
def _ParameterizedTestElman(self, seqlen, use_grad):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
random_seed.set_random_seed(342462)
batch = 3
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
index aa4562be7c..bf699db3ed 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
@@ -1906,7 +1906,7 @@ class StateSaverRNNTest(test.TestCase):
state_saver = TestStateSaverWithCounters(batch_size, 2 * num_units)
out, state, state_saver = self._factory(scope=None, state_saver=state_saver)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
sess.run(variables_lib.local_variables_initializer())
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py
index f2a032e41e..8d34b9e852 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py
@@ -38,7 +38,7 @@ class FusedRnnCellTest(test.TestCase):
def testBasicRNNFusedWrapper(self):
"""This test checks that using a wrapper for BasicRNN works as expected."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=19890212)
cell = rnn_cell.BasicRNNCell(10)
@@ -106,7 +106,7 @@ class FusedRnnCellTest(test.TestCase):
self.assertAllClose(basic, fused, rtol=1e-2, atol=1e-2)
def testTimeReversedFusedRNN(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=19890213)
fw_cell = rnn_cell.BasicRNNCell(10)
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index 2df8f0ec05..6689664fb9 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -47,7 +47,7 @@ from tensorflow.python.util import nest
class RNNCellTest(test.TestCase):
def testCoupledInputForgetGateLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 2
state_size = num_units * 2
batch_size = 3
@@ -81,7 +81,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[1], expected_state)
def testTimeFreqLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 8
state_size = num_units * 2
batch_size = 3
@@ -120,7 +120,7 @@ class RNNCellTest(test.TestCase):
float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) > 1e-6)
def testGridLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 8
batch_size = 3
input_size = 4
@@ -166,7 +166,7 @@ class RNNCellTest(test.TestCase):
.state_f00_b00_c[i, :]))) > 1e-6)
def testGridLSTMCellWithFrequencyBlocks(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 8
batch_size = 3
feature_size = 2
@@ -248,7 +248,7 @@ class RNNCellTest(test.TestCase):
]],
dtype=np.float32)
for state_is_tuple in [False, True]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"state_is_tuple" + str(state_is_tuple),
initializer=init_ops.constant_initializer(0.5)):
@@ -294,7 +294,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(np.concatenate(res[1], axis=1), expected_state)
def testBidirectionGridLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 2
batch_size = 3
input_size = 4
@@ -374,7 +374,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(np.concatenate(res[1], axis=1), expected_state)
def testBidirectionGridLSTMCellWithSliceOffset(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 2
batch_size = 3
input_size = 4
@@ -487,7 +487,7 @@ class RNNCellTest(test.TestCase):
input_size = 4
for state_is_tuple in [False, True]:
with ops.Graph().as_default():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"state_is_tuple_" + str(state_is_tuple)):
lstm_cell = rnn_cell.BasicLSTMCell(
@@ -538,7 +538,7 @@ class RNNCellTest(test.TestCase):
batch_size = 3
for state_is_tuple in [False, True]:
with ops.Graph().as_default():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"state_is_tuple_" + str(state_is_tuple)):
lstm_cell = rnn_cell.BasicLSTMCell(
@@ -677,7 +677,7 @@ class RNNCellTest(test.TestCase):
0.79457647, 0.79457647, 0.79457647, 0.79457647, 0.79457653, 0.79457653,
0.62739348, 0.62739348, 0.62739348, 0.62739348, 0.62739348, 0.62739348
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"nas_test", initializer=init_ops.constant_initializer(0.5)):
cell = contrib_rnn_cell.NASCell(num_units=num_units)
@@ -725,7 +725,7 @@ class RNNCellTest(test.TestCase):
0.78973997, 0.78973997, 0.78973997, 0.78973997, 0.78973997, 0.78973997,
1.87398517, 1.87398517, 1.87398517, 1.87398517, 1.87398517
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"nas_proj_test", initializer=init_ops.constant_initializer(0.5)):
cell = contrib_rnn_cell.NASCell(num_units=num_units, num_proj=num_proj)
@@ -765,7 +765,7 @@ class RNNCellTest(test.TestCase):
[[0.13752282, 0.13752282], [0.10545051, 0.10545051],
[0.10074195, 0.10074195]],
dtype=np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"ugrnn_cell_test", initializer=init_ops.constant_initializer(0.5)):
cell = contrib_rnn_cell.UGRNNCell(num_units=num_units)
@@ -796,7 +796,7 @@ class RNNCellTest(test.TestCase):
[[2.00431061, 2.00431061], [4.00060606, 4.00060606],
[6.00008249, 6.00008249]],
dtype=np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"intersection_rnn_cell_test",
initializer=init_ops.constant_initializer(0.5)):
@@ -837,7 +837,7 @@ class RNNCellTest(test.TestCase):
cell(inputs, init_state)
def testPhasedLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 2
batch_size = 3
input_size = 4
@@ -874,7 +874,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[1].h, expected_state_h)
def testConv1DLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape = [2, 1]
filter_size = [3]
num_features = 1
@@ -907,7 +907,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[1].h, expected_state_h)
def testConv2DLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape = [2, 2, 1]
filter_size = [3, 3]
num_features = 1
@@ -948,7 +948,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[1].h, expected_state_h)
def testConv3DLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape = [2, 2, 2, 1]
filter_size = [3, 3, 3]
num_features = 1
@@ -999,7 +999,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[1].h, expected_state_h)
def testHighwayWrapper(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"base_cell", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
@@ -1030,7 +1030,7 @@ class RNNCellTest(test.TestCase):
# Try with input dimension equal to num_units or not.
for num_inputs in [num_units, num_units + number_of_groups]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root1_%d" % num_inputs,
initializer=init_ops.constant_initializer(0.5)):
@@ -1059,7 +1059,7 @@ class RNNCellTest(test.TestCase):
# Try with num_inputs equal to or not equal to num_units.
for num_inputs in [num_units, num_units + number_of_groups]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root2_%d" % num_inputs,
initializer=init_ops.constant_initializer(0.5)):
@@ -1092,7 +1092,7 @@ class RNNCellTest(test.TestCase):
batch_size = 2
num_units = 4
number_of_groups = 2
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope(
"glstm_failure", initializer=init_ops.constant_initializer(0.5)):
gcell = contrib_rnn_cell.GLSTMCell(
@@ -1121,7 +1121,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
# NOTE: all the values in the current test case have been calculated.
def testBasicLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -1189,7 +1189,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
def testBasicLSTMCellWithoutNorm(self):
"""Tests that BasicLSTMCell with layer_norm=False."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -1256,7 +1256,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
self.assertAllClose(res[1].h, expected_h, 1e-5)
def testBasicLSTMCellWithStateTuple(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -1294,7 +1294,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
def testBasicLSTMCellWithStateTupleLayerNorm(self):
"""The results of LSTMCell and LayerNormBasicLSTMCell should be the same."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -1353,7 +1353,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
num_units = 5
allowed_low = [1, 2, 3]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"other", initializer=init_ops.constant_initializer(1)):
x = array_ops.zeros([1, 5])
@@ -1479,7 +1479,7 @@ class CompiledWrapperTest(test.TestCase):
self.assertAllClose(xla_g, non_xla_g, atol=atol)
def testMultiRNNCellWithStateTuple(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -1583,7 +1583,7 @@ class WeightNormLSTMCellTest(test.TestCase):
def _cell_output(self, cell):
"""Calculates cell output."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
init = init_ops.constant_initializer(0.5)
with variable_scope.variable_scope("root",
initializer=init):
diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD
index b897224c6d..4ca5274b2e 100644
--- a/tensorflow/contrib/saved_model/BUILD
+++ b/tensorflow/contrib/saved_model/BUILD
@@ -78,23 +78,6 @@ py_test(
],
)
-py_test(
- name = "signature_def_utils_test",
- size = "small",
- srcs = ["python/saved_model/signature_def_utils_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":saved_model_py",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python/saved_model:signature_constants",
- "//tensorflow/python/saved_model:signature_def_utils",
- "//tensorflow/python/saved_model:utils",
- ],
-)
-
py_library(
name = "keras_saved_model",
srcs = ["python/saved_model/keras_saved_model.py"],
@@ -123,6 +106,7 @@ py_test(
size = "medium",
srcs = ["python/saved_model/keras_saved_model_test.py"],
srcs_version = "PY2AND3",
+ tags = ["notsan"],
deps = [
":keras_saved_model",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/contrib/saved_model/__init__.py b/tensorflow/contrib/saved_model/__init__.py
index 074dc655ac..ac95e38011 100644
--- a/tensorflow/contrib/saved_model/__init__.py
+++ b/tensorflow/contrib/saved_model/__init__.py
@@ -25,13 +25,11 @@ from __future__ import print_function
# pylint: disable=unused-import,wildcard-import,line-too-long
from tensorflow.contrib.saved_model.python.saved_model.keras_saved_model import *
-from tensorflow.contrib.saved_model.python.saved_model.signature_def_utils import *
# pylint: enable=unused-import,wildcard-import,line-too-long
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
- "get_signature_def_by_key",
"load_keras_model",
"save_keras_model"]
diff --git a/tensorflow/contrib/saved_model/python/saved_model/__init__.py b/tensorflow/contrib/saved_model/python/saved_model/__init__.py
index e3b76bb6f3..fd3dc1d7aa 100644
--- a/tensorflow/contrib/saved_model/python/saved_model/__init__.py
+++ b/tensorflow/contrib/saved_model/python/saved_model/__init__.py
@@ -25,5 +25,4 @@ from __future__ import print_function
# pylint: disable=wildcard-import
from tensorflow.contrib.saved_model.python.saved_model import keras_saved_model
-from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
# pylint: enable=wildcard-import
diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py
index 8a0dbef788..12dd72a95b 100644
--- a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py
+++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py
@@ -50,7 +50,7 @@ class TestModelSavingandLoading(test.TestCase):
return os.path.join(temp_dir, dirname)
def test_saving_sequential_model(self):
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(2, input_shape=(3,)))
model.add(keras.layers.RepeatVector(3))
@@ -75,7 +75,7 @@ class TestModelSavingandLoading(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def test_saving_sequential_model_without_compile(self):
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(2, input_shape=(3,)))
model.add(keras.layers.RepeatVector(3))
@@ -92,7 +92,7 @@ class TestModelSavingandLoading(test.TestCase):
self.assertAllClose(ref_y, y, atol=1e-05)
def test_saving_functional_model(self):
- with self.test_session():
+ with self.cached_session():
inputs = keras.layers.Input(shape=(3,))
x = keras.layers.Dense(2)(inputs)
output = keras.layers.Dense(3)(x)
@@ -117,7 +117,7 @@ class TestModelSavingandLoading(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def test_saving_functional_model_without_compile(self):
- with self.test_session():
+ with self.cached_session():
inputs = keras.layers.Input(shape=(3,))
x = keras.layers.Dense(2)(inputs)
output = keras.layers.Dense(3)(x)
@@ -138,7 +138,7 @@ class TestModelSavingandLoading(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def test_saving_with_tf_optimizer(self):
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(2, input_shape=(3,)))
model.add(keras.layers.Dense(3))
diff --git a/tensorflow/contrib/saved_model/python/saved_model/signature_def_utils.py b/tensorflow/contrib/saved_model/python/saved_model/signature_def_utils.py
deleted file mode 100644
index f521647999..0000000000
--- a/tensorflow/contrib/saved_model/python/saved_model/signature_def_utils.py
+++ /dev/null
@@ -1,42 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""SignatureDef utility functions implementation."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-
-def get_signature_def_by_key(meta_graph_def, signature_def_key):
- """Utility function to get a SignatureDef protocol buffer by its key.
-
- Args:
- meta_graph_def: MetaGraphDef protocol buffer with the SignatureDefMap to
- look up.
- signature_def_key: Key of the SignatureDef protocol buffer to find in the
- SignatureDefMap.
-
- Returns:
- A SignatureDef protocol buffer corresponding to the supplied key, if it
- exists.
-
- Raises:
- ValueError: If no entry corresponding to the supplied key is found in the
- SignatureDefMap of the MetaGraphDef.
- """
- if signature_def_key not in meta_graph_def.signature_def:
- raise ValueError("No SignatureDef with key '%s' found in MetaGraphDef." %
- signature_def_key)
- return meta_graph_def.signature_def[signature_def_key]
diff --git a/tensorflow/contrib/saved_model/python/saved_model/signature_def_utils_test.py b/tensorflow/contrib/saved_model/python/saved_model/signature_def_utils_test.py
deleted file mode 100644
index d2e14f73e4..0000000000
--- a/tensorflow/contrib/saved_model/python/saved_model/signature_def_utils_test.py
+++ /dev/null
@@ -1,191 +0,0 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for SignatureDef utils."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils as signature_def_contrib_utils
-from tensorflow.core.protobuf import meta_graph_pb2
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
-from tensorflow.python.saved_model import signature_constants
-from tensorflow.python.saved_model import signature_def_utils
-from tensorflow.python.saved_model import utils
-
-
-class SignatureDefUtilsTest(test.TestCase):
-
- def _add_to_signature_def_map(self, meta_graph_def, signature_def_map=None):
- if signature_def_map is not None:
- for key in signature_def_map:
- meta_graph_def.signature_def[key].CopyFrom(signature_def_map[key])
-
- def _check_tensor_info(self, tensor_info_map, map_key, expected_tensor_name):
- actual_tensor_info = tensor_info_map[map_key]
- self.assertEqual(expected_tensor_name, actual_tensor_info.name)
-
- def testGetSignatureDefByKey(self):
- x = array_ops.placeholder(dtypes.float32, 1, name="x")
- x_tensor_info = utils.build_tensor_info(x)
-
- y = array_ops.placeholder(dtypes.float32, name="y")
- y_tensor_info = utils.build_tensor_info(y)
-
- foo_signature_def = signature_def_utils.build_signature_def({
- "foo-input": x_tensor_info
- }, {"foo-output": y_tensor_info}, "foo-method-name")
- bar_signature_def = signature_def_utils.build_signature_def({
- "bar-input": x_tensor_info
- }, {"bar-output": y_tensor_info}, "bar-method-name")
- meta_graph_def = meta_graph_pb2.MetaGraphDef()
- self._add_to_signature_def_map(
- meta_graph_def, {"foo": foo_signature_def,
- "bar": bar_signature_def})
-
- # Look up a key that does not exist in the SignatureDefMap.
- missing_key = "missing-key"
- with self.assertRaisesRegexp(
- ValueError,
- "No SignatureDef with key '%s' found in MetaGraphDef" % missing_key):
- signature_def_contrib_utils.get_signature_def_by_key(
- meta_graph_def, missing_key)
-
- # Look up the key, `foo` which exists in the SignatureDefMap.
- foo_signature_def = signature_def_contrib_utils.get_signature_def_by_key(
- meta_graph_def, "foo")
- self.assertTrue("foo-method-name", foo_signature_def.method_name)
-
- # Check inputs in signature def.
- self.assertEqual(1, len(foo_signature_def.inputs))
- self._check_tensor_info(foo_signature_def.inputs, "foo-input", "x:0")
-
- # Check outputs in signature def.
- self.assertEqual(1, len(foo_signature_def.outputs))
- self._check_tensor_info(foo_signature_def.outputs, "foo-output", "y:0")
-
- # Look up the key, `bar` which exists in the SignatureDefMap.
- bar_signature_def = signature_def_contrib_utils.get_signature_def_by_key(
- meta_graph_def, "bar")
- self.assertTrue("bar-method-name", bar_signature_def.method_name)
-
- # Check inputs in signature def.
- self.assertEqual(1, len(bar_signature_def.inputs))
- self._check_tensor_info(bar_signature_def.inputs, "bar-input", "x:0")
-
- # Check outputs in signature def.
- self.assertEqual(1, len(bar_signature_def.outputs))
- self._check_tensor_info(bar_signature_def.outputs, "bar-output", "y:0")
-
- def testGetSignatureDefByKeyRegression(self):
- input1 = constant_op.constant("a", name="input-1")
- output1 = constant_op.constant(7.2, name="output-1")
-
- meta_graph_def = meta_graph_pb2.MetaGraphDef()
- self._add_to_signature_def_map(meta_graph_def, {
- "my_regression":
- signature_def_utils.regression_signature_def(input1, output1)
- })
-
- # Look up the regression signature with the key used while saving.
- signature_def = signature_def_contrib_utils.get_signature_def_by_key(
- meta_graph_def, "my_regression")
-
- # Check the method name to match the constants regression method name.
- self.assertEqual(signature_constants.REGRESS_METHOD_NAME,
- signature_def.method_name)
-
- # Check inputs in signature def.
- self.assertEqual(1, len(signature_def.inputs))
- self._check_tensor_info(signature_def.inputs,
- signature_constants.REGRESS_INPUTS, "input-1:0")
-
- # Check outputs in signature def.
- self.assertEqual(1, len(signature_def.outputs))
- self._check_tensor_info(signature_def.outputs,
- signature_constants.REGRESS_OUTPUTS, "output-1:0")
-
- def testGetSignatureDefByKeyClassification(self):
- input1 = constant_op.constant("a", name="input-1")
- output1 = constant_op.constant("b", name="output-1")
- output2 = constant_op.constant(3.0, name="output-2")
-
- meta_graph_def = meta_graph_pb2.MetaGraphDef()
- self._add_to_signature_def_map(meta_graph_def, {
- "my_classification":
- signature_def_utils.classification_signature_def(
- input1, output1, output2)
- })
-
- # Look up the classification signature def with the key used while saving.
- signature_def = signature_def_contrib_utils.get_signature_def_by_key(
- meta_graph_def, "my_classification")
-
- # Check the method name to match the constants classification method name.
- self.assertEqual(signature_constants.CLASSIFY_METHOD_NAME,
- signature_def.method_name)
-
- # Check inputs in signature def.
- self.assertEqual(1, len(signature_def.inputs))
- self._check_tensor_info(signature_def.inputs,
- signature_constants.CLASSIFY_INPUTS, "input-1:0")
-
- # Check outputs in signature def.
- self.assertEqual(2, len(signature_def.outputs))
- self._check_tensor_info(signature_def.outputs,
- signature_constants.CLASSIFY_OUTPUT_CLASSES,
- "output-1:0")
- self._check_tensor_info(signature_def.outputs,
- signature_constants.CLASSIFY_OUTPUT_SCORES,
- "output-2:0")
-
- def testPredictionSignatureDef(self):
- input1 = constant_op.constant("a", name="input-1")
- input2 = constant_op.constant("b", name="input-2")
- output1 = constant_op.constant("c", name="output-1")
- output2 = constant_op.constant("d", name="output-2")
-
- meta_graph_def = meta_graph_pb2.MetaGraphDef()
- self._add_to_signature_def_map(meta_graph_def, {
- "my_prediction":
- signature_def_utils.predict_signature_def({
- "input-1": input1,
- "input-2": input2
- }, {"output-1": output1,
- "output-2": output2})
- })
-
- # Look up the prediction signature def with the key used while saving.
- signature_def = signature_def_contrib_utils.get_signature_def_by_key(
- meta_graph_def, "my_prediction")
- self.assertEqual(signature_constants.PREDICT_METHOD_NAME,
- signature_def.method_name)
-
- # Check inputs in signature def.
- self.assertEqual(2, len(signature_def.inputs))
- self._check_tensor_info(signature_def.inputs, "input-1", "input-1:0")
- self._check_tensor_info(signature_def.inputs, "input-2", "input-2:0")
-
- # Check outputs in signature def.
- self.assertEqual(2, len(signature_def.outputs))
- self._check_tensor_info(signature_def.outputs, "output-1", "output-1:0")
- self._check_tensor_info(signature_def.outputs, "output-2", "output-2:0")
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
index f5b6b1bde9..5e28e651c6 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
@@ -248,6 +248,7 @@ class TestBeamStep(test.TestCase):
self.vocab_size = 5
self.end_token = 0
self.length_penalty_weight = 0.6
+ self.coverage_penalty_weight = 0.0
def test_step(self):
dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width])
@@ -258,7 +259,8 @@ class TestBeamStep(test.TestCase):
lengths=constant_op.constant(
2, shape=[self.batch_size, self.beam_width], dtype=dtypes.int64),
finished=array_ops.zeros(
- [self.batch_size, self.beam_width], dtype=dtypes.bool))
+ [self.batch_size, self.beam_width], dtype=dtypes.bool),
+ accumulated_attention_probs=())
logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size],
0.0001)
@@ -281,7 +283,8 @@ class TestBeamStep(test.TestCase):
batch_size=ops.convert_to_tensor(self.batch_size),
beam_width=self.beam_width,
end_token=self.end_token,
- length_penalty_weight=self.length_penalty_weight)
+ length_penalty_weight=self.length_penalty_weight,
+ coverage_penalty_weight=self.coverage_penalty_weight)
with self.cached_session() as sess:
outputs_, next_state_, state_, log_probs_ = sess.run(
@@ -313,7 +316,8 @@ class TestBeamStep(test.TestCase):
lengths=ops.convert_to_tensor(
[[2, 1, 2], [2, 2, 1]], dtype=dtypes.int64),
finished=ops.convert_to_tensor(
- [[False, True, False], [False, False, True]], dtype=dtypes.bool))
+ [[False, True, False], [False, False, True]], dtype=dtypes.bool),
+ accumulated_attention_probs=())
logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size],
0.0001)
@@ -336,7 +340,8 @@ class TestBeamStep(test.TestCase):
batch_size=ops.convert_to_tensor(self.batch_size),
beam_width=self.beam_width,
end_token=self.end_token,
- length_penalty_weight=self.length_penalty_weight)
+ length_penalty_weight=self.length_penalty_weight,
+ coverage_penalty_weight=self.coverage_penalty_weight)
with self.cached_session() as sess:
outputs_, next_state_, state_, log_probs_ = sess.run(
@@ -372,6 +377,7 @@ class TestLargeBeamStep(test.TestCase):
self.vocab_size = 5
self.end_token = 0
self.length_penalty_weight = 0.6
+ self.coverage_penalty_weight = 0.0
def test_step(self):
@@ -411,7 +417,8 @@ class TestLargeBeamStep(test.TestCase):
cell_state=dummy_cell_state,
log_probs=log_probs,
lengths=_lengths,
- finished=_finished)
+ finished=_finished,
+ accumulated_attention_probs=())
logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size],
0.0001)
@@ -434,7 +441,8 @@ class TestLargeBeamStep(test.TestCase):
batch_size=ops.convert_to_tensor(self.batch_size),
beam_width=self.beam_width,
end_token=self.end_token,
- length_penalty_weight=self.length_penalty_weight)
+ length_penalty_weight=self.length_penalty_weight,
+ coverage_penalty_weight=self.coverage_penalty_weight)
with self.cached_session() as sess:
outputs_, next_state_, _, _ = sess.run(
@@ -476,7 +484,9 @@ class BeamSearchDecoderTest(test.TestCase):
embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
cell = rnn_cell.LSTMCell(cell_depth)
initial_state = cell.zero_state(batch_size, dtypes.float32)
+ coverage_penalty_weight = 0.0
if has_attention:
+ coverage_penalty_weight = 0.2
inputs = array_ops.placeholder_with_default(
np.random.randn(batch_size, decoder_max_time, input_depth).astype(
np.float32),
@@ -508,7 +518,8 @@ class BeamSearchDecoderTest(test.TestCase):
initial_state=cell_state,
beam_width=beam_width,
output_layer=output_layer,
- length_penalty_weight=0.0)
+ length_penalty_weight=0.0,
+ coverage_penalty_weight=coverage_penalty_weight)
final_outputs, final_state, final_sequence_lengths = (
decoder.dynamic_decode(
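The test updates above add the new `coverage_penalty_weight` argument and the new `accumulated_attention_probs` state field. Below is a minimal sketch, not part of this patch, of how a test might construct the extended `BeamSearchDecoderState` after this change; the batch and beam sizes are arbitrary illustration values, and `accumulated_attention_probs` is an empty tuple when the coverage penalty is disabled, mirroring the updated tests.

```python
# Minimal sketch assuming the tensorflow.contrib.seq2seq layout shown in this diff.
from tensorflow.contrib.seq2seq.python.ops import beam_search_decoder
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops

batch_size, beam_width = 2, 3
state = beam_search_decoder.BeamSearchDecoderState(
    cell_state=array_ops.zeros([batch_size, beam_width]),
    log_probs=array_ops.zeros([batch_size, beam_width]),
    lengths=array_ops.zeros([batch_size, beam_width], dtype=dtypes.int64),
    finished=array_ops.zeros([batch_size, beam_width], dtype=dtypes.bool),
    # Empty tuple when no attention / coverage penalty is used.
    accumulated_attention_probs=())
```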
diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
index 74741a7bd6..605e3143fd 100644
--- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import collections
import numpy as np
+from tensorflow.contrib.seq2seq.python.ops import attention_wrapper
from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.python.framework import dtypes
@@ -49,7 +50,8 @@ __all__ = [
class BeamSearchDecoderState(
collections.namedtuple("BeamSearchDecoderState",
- ("cell_state", "log_probs", "finished", "lengths"))):
+ ("cell_state", "log_probs", "finished", "lengths",
+ "accumulated_attention_probs"))):
pass
@@ -260,6 +262,10 @@ class BeamSearchDecoder(decoder.Decoder):
decoder_initial_state = decoder_initial_state.clone(
cell_state=tiled_encoder_final_state)
```
+
+   Meanwhile, when using `AttentionWrapper`, a coverage penalty is suggested
+   when computing scores (https://arxiv.org/pdf/1609.08144.pdf); it encourages
+   the translation to cover all inputs.
"""
def __init__(self,
@@ -271,6 +277,7 @@ class BeamSearchDecoder(decoder.Decoder):
beam_width,
output_layer=None,
length_penalty_weight=0.0,
+ coverage_penalty_weight=0.0,
reorder_tensor_arrays=True):
"""Initialize the BeamSearchDecoder.
@@ -286,6 +293,8 @@ class BeamSearchDecoder(decoder.Decoder):
`tf.layers.Dense`. Optional layer to apply to the RNN output prior
to storing the result or sampling.
length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
+ coverage_penalty_weight: Float weight to penalize the coverage of source
+ sentence. Disabled with 0.0.
reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the cell
state will be reordered according to the beam search path. If the
`TensorArray` can be reordered, the stacked form will be returned.
@@ -326,6 +335,7 @@ class BeamSearchDecoder(decoder.Decoder):
self._batch_size = array_ops.size(start_tokens)
self._beam_width = beam_width
self._length_penalty_weight = length_penalty_weight
+ self._coverage_penalty_weight = coverage_penalty_weight
self._initial_cell_state = nest.map_structure(
self._maybe_split_batch_beams, initial_state, self._cell.state_size)
self._start_tokens = array_ops.tile(
@@ -411,13 +421,18 @@ class BeamSearchDecoder(decoder.Decoder):
on_value=ops.convert_to_tensor(0.0, dtype=dtype),
off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype),
dtype=dtype)
+ init_attention_probs = get_attention_probs(
+ self._initial_cell_state, self._coverage_penalty_weight)
+ if init_attention_probs is None:
+ init_attention_probs = ()
initial_state = BeamSearchDecoderState(
cell_state=self._initial_cell_state,
log_probs=log_probs,
finished=finished,
lengths=array_ops.zeros(
- [self._batch_size, self._beam_width], dtype=dtypes.int64))
+ [self._batch_size, self._beam_width], dtype=dtypes.int64),
+ accumulated_attention_probs=init_attention_probs)
return (finished, start_inputs, initial_state)
@@ -631,6 +646,7 @@ class BeamSearchDecoder(decoder.Decoder):
beam_width = self._beam_width
end_token = self._end_token
length_penalty_weight = self._length_penalty_weight
+ coverage_penalty_weight = self._coverage_penalty_weight
with ops.name_scope(name, "BeamSearchDecoderStep", (time, inputs, state)):
cell_state = state.cell_state
@@ -655,7 +671,8 @@ class BeamSearchDecoder(decoder.Decoder):
batch_size=batch_size,
beam_width=beam_width,
end_token=end_token,
- length_penalty_weight=length_penalty_weight)
+ length_penalty_weight=length_penalty_weight,
+ coverage_penalty_weight=coverage_penalty_weight)
finished = beam_search_state.finished
sample_ids = beam_search_output.predicted_ids
@@ -667,7 +684,8 @@ class BeamSearchDecoder(decoder.Decoder):
def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
- beam_width, end_token, length_penalty_weight):
+ beam_width, end_token, length_penalty_weight,
+ coverage_penalty_weight):
"""Performs a single step of Beam Search Decoding.
Args:
@@ -684,6 +702,8 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
beam_width: Python int. The size of the beams.
end_token: The int32 end token.
length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
+ coverage_penalty_weight: Float weight to penalize the coverage of source
+ sentence. Disabled with 0.0.
Returns:
A new beam state.
@@ -693,6 +713,7 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
# Calculate the current lengths of the predictions
prediction_lengths = beam_state.lengths
previously_finished = beam_state.finished
+ not_finished = math_ops.logical_not(previously_finished)
# Calculate the total log probs for the new hypotheses
# Final Shape: [batch_size, beam_width, vocab_size]
@@ -708,16 +729,29 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
on_value=np.int64(0),
off_value=np.int64(1),
dtype=dtypes.int64)
- add_mask = math_ops.to_int64(math_ops.logical_not(previously_finished))
+ add_mask = math_ops.to_int64(not_finished)
lengths_to_add *= array_ops.expand_dims(add_mask, 2)
new_prediction_lengths = (
lengths_to_add + array_ops.expand_dims(prediction_lengths, 2))
+ # Calculate the accumulated attention probabilities if coverage penalty is
+ # enabled.
+ accumulated_attention_probs = None
+ attention_probs = get_attention_probs(
+ next_cell_state, coverage_penalty_weight)
+ if attention_probs is not None:
+ attention_probs *= array_ops.expand_dims(math_ops.to_float(not_finished), 2)
+ accumulated_attention_probs = (
+ beam_state.accumulated_attention_probs + attention_probs)
+
# Calculate the scores for each beam
scores = _get_scores(
log_probs=total_probs,
sequence_lengths=new_prediction_lengths,
- length_penalty_weight=length_penalty_weight)
+ length_penalty_weight=length_penalty_weight,
+ coverage_penalty_weight=coverage_penalty_weight,
+ finished=previously_finished,
+ accumulated_attention_probs=accumulated_attention_probs)
time = ops.convert_to_tensor(time, name="time")
# During the first time step we only consider the initial beam
@@ -775,6 +809,15 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
range_size=beam_width,
gather_shape=[-1])
next_prediction_len += lengths_to_add
+ next_accumulated_attention_probs = ()
+ if accumulated_attention_probs is not None:
+ next_accumulated_attention_probs = _tensor_gather_helper(
+ gather_indices=next_beam_ids,
+ gather_from=accumulated_attention_probs,
+ batch_size=batch_size,
+ range_size=beam_width,
+ gather_shape=[batch_size * beam_width, -1],
+ name="next_accumulated_attention_probs")
# Pick out the cell_states according to the next_beam_ids. We use a
# different gather_shape here because the cell_state tensors, i.e.
@@ -795,7 +838,8 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
cell_state=next_cell_state,
log_probs=next_beam_probs,
lengths=next_prediction_len,
- finished=next_finished)
+ finished=next_finished,
+ accumulated_attention_probs=next_accumulated_attention_probs)
output = BeamSearchDecoderOutput(
scores=next_beam_scores,
@@ -805,7 +849,53 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
return output, next_state
-def _get_scores(log_probs, sequence_lengths, length_penalty_weight):
+def get_attention_probs(next_cell_state, coverage_penalty_weight):
+ """Get attention probabilities from the cell state.
+
+ Args:
+ next_cell_state: The next state from the cell, e.g. an instance of
+ AttentionWrapperState if the cell is attentional.
+ coverage_penalty_weight: Float weight to penalize the coverage of source
+ sentence. Disabled with 0.0.
+
+ Returns:
+ The attention probabilities with shape `[batch_size, beam_width, max_time]`
+ if coverage penalty is enabled. Otherwise, returns None.
+
+ Raises:
+ ValueError: If no cell is attentional but coverage penalty is enabled.
+ """
+ if coverage_penalty_weight == 0.0:
+ return None
+
+ # Attention probabilities of each attention layer. Each with shape
+ # `[batch_size, beam_width, max_time]`.
+ probs_per_attn_layer = []
+ if isinstance(next_cell_state, attention_wrapper.AttentionWrapperState):
+ probs_per_attn_layer = [attention_probs_from_attn_state(next_cell_state)]
+ elif isinstance(next_cell_state, tuple):
+ for state in next_cell_state:
+ if isinstance(state, attention_wrapper.AttentionWrapperState):
+ probs_per_attn_layer.append(attention_probs_from_attn_state(state))
+
+ if not probs_per_attn_layer:
+ raise ValueError(
+ "coverage_penalty_weight must be 0.0 if no cell is attentional.")
+
+ if len(probs_per_attn_layer) == 1:
+ attention_probs = probs_per_attn_layer[0]
+ else:
+ # Calculate the average attention probabilities from all attention layers.
+ attention_probs = [
+ array_ops.expand_dims(prob, -1) for prob in probs_per_attn_layer]
+ attention_probs = array_ops.concat(attention_probs, -1)
+ attention_probs = math_ops.reduce_mean(attention_probs, -1)
+
+ return attention_probs
+
+
+def _get_scores(log_probs, sequence_lengths, length_penalty_weight,
+ coverage_penalty_weight, finished, accumulated_attention_probs):
"""Calculates scores for beam search hypotheses.
Args:
@@ -813,13 +903,78 @@ def _get_scores(log_probs, sequence_lengths, length_penalty_weight):
`[batch_size, beam_width, vocab_size]`.
sequence_lengths: The array of sequence lengths.
length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
+ coverage_penalty_weight: Float weight to penalize the coverage of source
+ sentence. Disabled with 0.0.
+ finished: A boolean tensor of shape `[batch_size, beam_width]` that
+ specifies which elements in the beam are finished already.
+ accumulated_attention_probs: Accumulated attention probabilities up to the
+ current time step, with shape `[batch_size, beam_width, max_time]` if
+ coverage_penalty_weight is not 0.0.
Returns:
- The scores normalized by the length_penalty.
+ The scores normalized by the length_penalty and coverage_penalty.
+
+ Raises:
+    ValueError: if accumulated_attention_probs is None while coverage penalty
+      is enabled.
"""
length_penalty_ = _length_penalty(
sequence_lengths=sequence_lengths, penalty_factor=length_penalty_weight)
- return log_probs / length_penalty_
+ scores = log_probs / length_penalty_
+
+ coverage_penalty_weight = ops.convert_to_tensor(
+ coverage_penalty_weight, name="coverage_penalty_weight")
+ if coverage_penalty_weight.shape.ndims != 0:
+ raise ValueError("coverage_penalty_weight should be a scalar, "
+ "but saw shape: %s" % coverage_penalty_weight.shape)
+
+ if tensor_util.constant_value(coverage_penalty_weight) == 0.0:
+ return scores
+
+ if accumulated_attention_probs is None:
+ raise ValueError(
+ "accumulated_attention_probs can be None only if coverage penalty is "
+ "disabled.")
+
+ # Add source sequence length mask before computing coverage penalty.
+ accumulated_attention_probs = array_ops.where(
+ math_ops.equal(accumulated_attention_probs, 0.0),
+ array_ops.ones_like(accumulated_attention_probs),
+ accumulated_attention_probs)
+
+ # coverage penalty =
+ # sum over `max_time` {log(min(accumulated_attention_probs, 1.0))}
+ coverage_penalty = math_ops.reduce_sum(
+ math_ops.log(math_ops.minimum(accumulated_attention_probs, 1.0)), 2)
+ # Apply coverage penalty to finished predictions.
+ coverage_penalty *= math_ops.to_float(finished)
+ weighted_coverage_penalty = coverage_penalty * coverage_penalty_weight
+ # Reshape from [batch_size, beam_width] to [batch_size, beam_width, 1]
+ weighted_coverage_penalty = array_ops.expand_dims(
+ weighted_coverage_penalty, 2)
+ return scores + weighted_coverage_penalty
+
+
+def attention_probs_from_attn_state(attention_state):
+ """Calculates the average attention probabilities.
+
+ Args:
+ attention_state: An instance of `AttentionWrapperState`.
+
+ Returns:
+ The attention probabilities in the given AttentionWrapperState.
+    If there are multiple attention mechanisms, returns the average value
+    over all of them.
+ """
+ # Attention probabilities over time steps, with shape
+ # `[batch_size, beam_width, max_time]`.
+ attention_probs = attention_state.alignments
+ if isinstance(attention_probs, tuple):
+ attention_probs = [
+ array_ops.expand_dims(prob, -1) for prob in attention_probs]
+ attention_probs = array_ops.concat(attention_probs, -1)
+ attention_probs = math_ops.reduce_mean(attention_probs, -1)
+ return attention_probs
def _length_penalty(sequence_lengths, penalty_factor):
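The coverage penalty added in `_get_scores` follows the GNMT formulation (https://arxiv.org/pdf/1609.08144.pdf): scores are the length-normalized log probabilities plus a weighted sum over time of log(min(accumulated attention, 1)), applied to finished beams. The NumPy sketch below is illustrative only; it assumes the standard ((5 + len) / 6) ** alpha length penalty, and all array values are made-up placeholders.

```python
import numpy as np

log_probs = np.random.randn(2, 3, 5)            # [batch, beam, vocab]
new_lengths = np.full((2, 3, 1), 4.0)           # prediction lengths after this step
acc_attn = np.random.rand(2, 3, 7)              # accumulated attention, [batch, beam, max_time]
finished = np.array([[0., 1., 0.], [0., 0., 1.]])  # finished beams as floats
length_penalty_weight = 0.6
coverage_penalty_weight = 0.2

# Length-normalized scores, assuming the ((5 + len) / 6) ** alpha penalty.
scores = log_probs / (((5.0 + new_lengths) / 6.0) ** length_penalty_weight)

# Coverage penalty: zero-attention (masked) positions are treated as 1.0,
# then summed over time as log(min(p, 1)); applied only to finished beams.
masked = np.where(acc_attn == 0.0, 1.0, acc_attn)
coverage = np.sum(np.log(np.minimum(masked, 1.0)), axis=2) * finished
scores = scores + coverage_penalty_weight * coverage[..., np.newaxis]
```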
diff --git a/tensorflow/contrib/slim/python/slim/data/dataset_data_provider_test.py b/tensorflow/contrib/slim/python/slim/data/dataset_data_provider_test.py
index 1bb6fbc570..795de6a408 100644
--- a/tensorflow/contrib/slim/python/slim/data/dataset_data_provider_test.py
+++ b/tensorflow/contrib/slim/python/slim/data/dataset_data_provider_test.py
@@ -88,7 +88,7 @@ class DatasetDataProviderTest(test.TestCase):
height = 300
width = 280
- with self.test_session():
+ with self.cached_session():
test_dataset = _create_tfrecord_dataset(dataset_dir)
provider = dataset_data_provider.DatasetDataProvider(test_dataset)
key, image, label = provider.get(['record_key', 'image', 'label'])
@@ -111,7 +111,7 @@ class DatasetDataProviderTest(test.TestCase):
height = 300
width = 280
- with self.test_session():
+ with self.cached_session():
provider = dataset_data_provider.DatasetDataProvider(
_create_tfrecord_dataset(dataset_dir))
[image] = provider.get(['image'])
@@ -128,7 +128,7 @@ class DatasetDataProviderTest(test.TestCase):
dataset_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(),
'tfrecord_dataset'))
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
dataset_data_provider.DatasetDataProvider(
_create_tfrecord_dataset(dataset_dir), record_key='image')
diff --git a/tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py b/tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py
index ea8cc0ff61..c457d44e07 100644
--- a/tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py
+++ b/tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py
@@ -39,7 +39,7 @@ class ParallelReaderTest(test.TestCase):
ops.reset_default_graph()
def _verify_all_data_sources_read(self, shared_queue):
- with self.test_session():
+ with self.cached_session():
tfrecord_paths = test_utils.create_tfrecord_files(
self.get_temp_dir(), num_files=3)
@@ -76,7 +76,7 @@ class ParallelReaderTest(test.TestCase):
self.assertEquals(count0 + count1 + count2, num_reads)
def _verify_read_up_to_out(self, shared_queue):
- with self.test_session():
+ with self.cached_session():
num_files = 3
num_records_per_file = 7
tfrecord_paths = test_utils.create_tfrecord_files(
@@ -161,7 +161,7 @@ class ParallelReadTest(test.TestCase):
ops.reset_default_graph()
def testTFRecordReader(self):
- with self.test_session():
+ with self.cached_session():
self._tfrecord_paths = test_utils.create_tfrecord_files(
self.get_temp_dir(), num_files=3)
@@ -188,7 +188,7 @@ class SinglePassReadTest(test.TestCase):
ops.reset_default_graph()
def testOutOfRangeError(self):
- with self.test_session():
+ with self.cached_session():
[tfrecord_path] = test_utils.create_tfrecord_files(
self.get_temp_dir(), num_files=1)
@@ -196,7 +196,7 @@ class SinglePassReadTest(test.TestCase):
tfrecord_path, reader_class=io_ops.TFRecordReader)
init_op = variables.local_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with queues.QueueRunners(sess):
num_reads = 11
@@ -205,7 +205,7 @@ class SinglePassReadTest(test.TestCase):
sess.run([key, value])
def testTFRecordReader(self):
- with self.test_session():
+ with self.cached_session():
[tfrecord_path] = test_utils.create_tfrecord_files(
self.get_temp_dir(), num_files=1)
@@ -213,7 +213,7 @@ class SinglePassReadTest(test.TestCase):
tfrecord_path, reader_class=io_ops.TFRecordReader)
init_op = variables.local_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with queues.QueueRunners(sess):
flowers = 0
diff --git a/tensorflow/contrib/slim/python/slim/data/prefetch_queue_test.py b/tensorflow/contrib/slim/python/slim/data/prefetch_queue_test.py
index 6c3e57c47d..7caa42dcb9 100644
--- a/tensorflow/contrib/slim/python/slim/data/prefetch_queue_test.py
+++ b/tensorflow/contrib/slim/python/slim/data/prefetch_queue_test.py
@@ -37,7 +37,7 @@ from tensorflow.python.training import queue_runner_impl
class PrefetchQueueTest(test.TestCase):
def testOneThread(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
image_size = 32
num_batches = 5
@@ -74,7 +74,7 @@ class PrefetchQueueTest(test.TestCase):
thread.join()
def testMultiThread(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
image_size = 32
num_batches = 5
@@ -114,7 +114,7 @@ class PrefetchQueueTest(test.TestCase):
thread.join()
def testMultipleDequeue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
image_size = 32
num_batches = 4
@@ -162,7 +162,7 @@ class PrefetchQueueTest(test.TestCase):
prefetch_queue.prefetch_queue([variable_tensor])
def testDynamicPad(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create 3 tensors of variable but compatible shapes.
var_shape = [None, 2]
p1 = constant_op.constant([[1, 2], [3, 4]])
diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
index 826242c9d7..3114949b82 100644
--- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
+++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
@@ -45,7 +45,7 @@ class TFExampleDecoderTest(test.TestCase):
int64_list=feature_pb2.Int64List(value=ndarray.flatten().tolist()))
def _EncodedBytesFeature(self, tf_encoded):
- with self.test_session():
+ with self.cached_session():
encoded = tf_encoded.eval()
def BytesList(value):
@@ -133,7 +133,7 @@ class TFExampleDecoderTest(test.TestCase):
tf_image = self.DecodeExample(serialized_example, item_handler,
image_format)
- with self.test_session():
+ with self.cached_session():
decoded_image = tf_image.eval()
# We need to recast them here to avoid some issues with uint8.
@@ -265,7 +265,7 @@ class TFExampleDecoderTest(test.TestCase):
serialized_example = example.SerializeToString()
- with self.test_session():
+ with self.cached_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
'labels':
@@ -296,7 +296,7 @@ class TFExampleDecoderTest(test.TestCase):
serialized_example = example.SerializeToString()
- with self.test_session():
+ with self.cached_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
'array': parsing_ops.FixedLenFeature(np_array.shape, dtypes.float32)
@@ -319,7 +319,7 @@ class TFExampleDecoderTest(test.TestCase):
serialized_example = example.SerializeToString()
- with self.test_session():
+ with self.cached_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
'array': parsing_ops.FixedLenFeature(np_array.shape, dtypes.int64)
@@ -342,7 +342,7 @@ class TFExampleDecoderTest(test.TestCase):
serialized_example = example.SerializeToString()
- with self.test_session():
+ with self.cached_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
'labels': parsing_ops.VarLenFeature(dtype=dtypes.int64),
@@ -366,7 +366,7 @@ class TFExampleDecoderTest(test.TestCase):
serialized_example = example.SerializeToString()
- with self.test_session():
+ with self.cached_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
'labels':
@@ -390,7 +390,7 @@ class TFExampleDecoderTest(test.TestCase):
serialized_example = example.SerializeToString()
- with self.test_session():
+ with self.cached_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
'labels': parsing_ops.VarLenFeature(dtype=dtypes.int64),
@@ -423,7 +423,7 @@ class TFExampleDecoderTest(test.TestCase):
serialized_example = example.SerializeToString()
- with self.test_session():
+ with self.cached_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
'image': parsing_ops.VarLenFeature(dtype=dtypes.float32),
@@ -468,7 +468,7 @@ class TFExampleDecoderTest(test.TestCase):
serialized_example = example.SerializeToString()
- with self.test_session():
+ with self.cached_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
'image': parsing_ops.VarLenFeature(dtype=dtypes.float32),
@@ -505,7 +505,7 @@ class TFExampleDecoderTest(test.TestCase):
serialized_example = example.SerializeToString()
- with self.test_session():
+ with self.cached_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
'indices': parsing_ops.VarLenFeature(dtype=dtypes.int64),
@@ -536,7 +536,7 @@ class TFExampleDecoderTest(test.TestCase):
serialized_example = example.SerializeToString()
- with self.test_session():
+ with self.cached_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
'indices': parsing_ops.VarLenFeature(dtype=dtypes.int64),
@@ -567,7 +567,7 @@ class TFExampleDecoderTest(test.TestCase):
serialized_example = example.SerializeToString()
- with self.test_session():
+ with self.cached_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
'indices': parsing_ops.VarLenFeature(dtype=dtypes.int64),
@@ -598,7 +598,7 @@ class TFExampleDecoderTest(test.TestCase):
serialized_example = example.SerializeToString()
- with self.test_session():
+ with self.cached_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
'indices': parsing_ops.VarLenFeature(dtype=dtypes.int64),
@@ -625,7 +625,7 @@ class TFExampleDecoderTest(test.TestCase):
serialized_example = example.SerializeToString()
- with self.test_session():
+ with self.cached_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
@@ -657,7 +657,7 @@ class TFExampleDecoderTest(test.TestCase):
serialized_example = example.SerializeToString()
- with self.test_session():
+ with self.cached_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
@@ -692,7 +692,7 @@ class TFExampleDecoderTest(test.TestCase):
image, serialized_example = self.GenerateImage(
image_format=image_encoding, image_shape=image_shape)
- with self.test_session():
+ with self.cached_session():
def ConditionalDecoding(keys_to_tensors):
"""See base class."""
@@ -759,7 +759,7 @@ class TFExampleDecoderTest(test.TestCase):
}))
serialized_example = example.SerializeToString()
- with self.test_session():
+ with self.cached_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
@@ -800,7 +800,7 @@ class TFExampleDecoderTest(test.TestCase):
}))
serialized_example = example.SerializeToString()
- with self.test_session():
+ with self.cached_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
@@ -837,7 +837,7 @@ class TFExampleDecoderTest(test.TestCase):
image, _ = self.GenerateImage(
image_format=image_format, image_shape=image_shape)
tf_encoded = self._Encoder(image, image_format)
- with self.test_session():
+ with self.cached_session():
tf_string = tf_encoded.eval()
example = example_pb2.Example(
@@ -852,7 +852,7 @@ class TFExampleDecoderTest(test.TestCase):
}))
serialized_example = example.SerializeToString()
- with self.test_session():
+ with self.cached_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
decoder = tfexample_decoder.TFExampleDecoder(
@@ -885,7 +885,7 @@ class TFExampleDecoderTest(test.TestCase):
table = lookup_ops.index_table_from_tensor(
constant_op.constant(['dog', 'guinea pig', 'cat']))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(lookup_ops.tables_initializer())
serialized_example = array_ops.reshape(serialized_example, shape=[])
@@ -943,7 +943,7 @@ class TFExampleDecoderTest(test.TestCase):
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
obtained_class_ids_each_example = []
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(lookup_ops.tables_initializer())
for example in [example1, example2, example3]:
serialized_example = array_ops.reshape(
diff --git a/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py b/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py
index 4707dc2229..8fcd7aeef6 100644
--- a/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py
+++ b/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py
@@ -47,7 +47,7 @@ def _get_lanczos_tests(dtype_, use_static_shape_, shape_, orthogonalize_,
low=-1.0, high=1.0, size=np.prod(shape_)).reshape(shape_).astype(dtype_)
tol = 1e-12 if dtype_ == np.float64 else 1e-5
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if use_static_shape_:
a = constant_op.constant(a_np)
else:
diff --git a/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py b/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py
index a73642716b..2a9100903a 100644
--- a/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py
+++ b/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py
@@ -47,7 +47,7 @@ def _get_least_squares_tests(dtype_, use_static_shape_, shape_):
low=-1.0, high=1.0, size=shape_[0]).astype(dtype_)
tol = 1e-12 if dtype_ == np.float64 else 1e-6
max_iter = 20
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if use_static_shape_:
a = constant_op.constant(a_np)
rhs = constant_op.constant(rhs_np)
diff --git a/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py b/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py
index a1282847be..a0e6eb87bc 100644
--- a/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py
+++ b/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py
@@ -54,7 +54,7 @@ def _get_linear_equations_tests(dtype_, use_static_shape_, shape_):
x_np = np.zeros_like(rhs_np)
tol = 1e-6 if dtype_ == np.float64 else 1e-3
max_iter = 20
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if use_static_shape_:
a = constant_op.constant(a_np)
rhs = constant_op.constant(rhs_np)
diff --git a/tensorflow/contrib/solvers/python/kernel_tests/util_test.py b/tensorflow/contrib/solvers/python/kernel_tests/util_test.py
index 5d7534657b..57b4996689 100644
--- a/tensorflow/contrib/solvers/python/kernel_tests/util_test.py
+++ b/tensorflow/contrib/solvers/python/kernel_tests/util_test.py
@@ -33,7 +33,7 @@ class UtilTest(test.TestCase):
a_np = np.array([[1., 2.], [3., 4.], [5., 6.]], dtype=dtype)
x_np = np.array([[2.], [-3.]], dtype=dtype)
y_np = np.array([[2], [-3.], [5.]], dtype=dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if use_static_shape_:
a = constant_op.constant(a_np, dtype=dtype)
x = constant_op.constant(x_np, dtype=dtype)
@@ -68,7 +68,7 @@ class UtilTest(test.TestCase):
a_np = np.array([[1., 2.], [3., 4.], [5., 6.]], dtype=dtype)
x_np = np.array([[2.], [-3.]], dtype=dtype)
y_np = np.array([[2], [-3.], [5.]], dtype=dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if use_static_shape_:
a = constant_op.constant(a_np, dtype=dtype)
x = constant_op.constant(x_np, dtype=dtype)
@@ -101,7 +101,7 @@ class UtilTest(test.TestCase):
self._testIdentityOperator(False)
def testL2Norm(self):
- with self.test_session():
+ with self.cached_session():
x_np = np.array([[2], [-3.], [5.]])
x_norm_np = np.linalg.norm(x_np)
x_normalized_np = x_np / x_norm_np
diff --git a/tensorflow/contrib/specs/python/specs_test.py b/tensorflow/contrib/specs/python/specs_test.py
index 9a4ad36793..b7ce6aa20a 100644
--- a/tensorflow/contrib/specs/python/specs_test.py
+++ b/tensorflow/contrib/specs/python/specs_test.py
@@ -38,7 +38,7 @@ def _rand(*size):
class SpecsTest(test.TestCase):
def testSimpleConv(self):
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(1, 18, 19, 5))
spec = "net = Cr(64, [5, 5])"
outputs = specs.create_net(spec, inputs)
@@ -53,7 +53,7 @@ class SpecsTest(test.TestCase):
def testUnary(self):
# This is just a quick and dirty check that these ops exist
# and work as unary ops.
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(17, 55))
spec = "net = Do(0.5) | Bn | Unit(1) | Relu | Sig | Tanh | Smax"
outputs = specs.create_net(spec, inputs)
@@ -63,7 +63,7 @@ class SpecsTest(test.TestCase):
self.assertEqual(tuple(result.shape), (17, 55))
def testAdd(self):
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(17, 55))
spec = "net = Fs(10) + Fr(10)"
outputs = specs.create_net(spec, inputs)
@@ -77,7 +77,7 @@ class SpecsTest(test.TestCase):
"<> variablev2 dot variablev2 biasadd relu add")
def testMpPower(self):
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(1, 64, 64, 5))
spec = "M2 = Mp([2, 2]); net = M2**3"
outputs = specs.create_net(spec, inputs)
@@ -90,7 +90,7 @@ class SpecsTest(test.TestCase):
"_ maxpool maxpool maxpool")
def testAbbrevPower(self):
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(1, 64, 64, 5))
spec = "C3 = Cr([3, 3]); M2 = Mp([2, 2]); net = (C3(5) | M2)**3"
outputs = specs.create_net(spec, inputs)
@@ -106,7 +106,7 @@ class SpecsTest(test.TestCase):
" biasadd relu maxpool")
def testAbbrevPower2(self):
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(1, 64, 64, 5))
spec = "C3 = Cr(_1=[3, 3]); M2 = Mp([2, 2]);"
spec += "net = (C3(_0=5) | M2)**3"
@@ -123,7 +123,7 @@ class SpecsTest(test.TestCase):
" maxpool")
def testConc(self):
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(10, 20))
spec = "net = Conc(1, Fs(20), Fs(10))"
outputs = specs.create_net(spec, inputs)
@@ -137,7 +137,7 @@ class SpecsTest(test.TestCase):
"<> variablev2 dot variablev2 biasadd sig _ concatv2")
def testImport(self):
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(10, 20))
spec = ("S = Import('from tensorflow.python.ops" +
" import math_ops; f = math_ops.sigmoid')")
@@ -150,7 +150,7 @@ class SpecsTest(test.TestCase):
self.assertEqual(summaries.tf_spec_structure(spec, inputs), "_ sig sig")
def testKeywordRestriction(self):
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(10, 20))
spec = "import re; net = Conc(1, Fs(20), Fs(10))"
self.assertRaises(ValueError, lambda: specs.create_net(spec, inputs))
@@ -179,7 +179,7 @@ class SpecsTest(test.TestCase):
# XXX: this code is overly clever and needs simplification.
# TODO: original author, please fix and re-enable this test.
def DISABLED_testVar(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with specs.ops:
# pylint: disable=undefined-variable
v = Var("test_var",
@@ -196,7 +196,7 @@ class SpecsTest(test.TestCase):
# XXX: this code is overly clever and needs simplification.
# TODO: original author, please fix and re-enable this test.
def DISABLED_testShared(self):
- with self.test_session():
+ with self.cached_session():
with specs.ops:
# pylint: disable=undefined-variable
f = Shared(Fr(100))
diff --git a/tensorflow/contrib/specs/python/summaries_test.py b/tensorflow/contrib/specs/python/summaries_test.py
index 34ff4bc8ca..b82ba06d3f 100644
--- a/tensorflow/contrib/specs/python/summaries_test.py
+++ b/tensorflow/contrib/specs/python/summaries_test.py
@@ -34,7 +34,7 @@ def _rand(*size):
class SummariesTest(test.TestCase):
def testStructure(self):
- with self.test_session():
+ with self.cached_session():
inputs_shape = (1, 18, 19, 5)
inputs = constant_op.constant(_rand(*inputs_shape))
spec = "net = Cr(64, [5, 5])"
@@ -48,7 +48,7 @@ class SummariesTest(test.TestCase):
"_ variablev2 conv variablev2 biasadd relu")
def testStructureFromTensor(self):
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(1, 18, 19, 5))
spec = "net = Cr(64, [5, 5])"
outputs = specs.create_net(spec, inputs)
@@ -60,7 +60,7 @@ class SummariesTest(test.TestCase):
"_ variablev2 conv variablev2 biasadd relu")
def testPrint(self):
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(1, 18, 19, 5))
spec = "net = Cr(64, [5, 5])"
outputs = specs.create_net(spec, inputs)
@@ -70,7 +70,7 @@ class SummariesTest(test.TestCase):
summaries.tf_spec_print(spec, inputs)
def testSummary(self):
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(1, 18, 19, 5))
spec = "net = Cr(64, [5, 5])"
outputs = specs.create_net(spec, inputs)
diff --git a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc
index f80a34ece6..fe2c91c104 100644
--- a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc
@@ -246,7 +246,8 @@ class ProcessInputOp : public OpKernel {
const Tensor& input_weights = context->input(7);
const Tensor& leaf_ids_tensor = context->input(8);
- std::unique_ptr<TensorDataSet> data_set(new TensorDataSet(input_spec_, 0));
+ std::unique_ptr<TensorDataSet> data_set(
+ new TensorDataSet(input_spec_, random_seed_));
data_set->set_input_tensors(input_data, sparse_input_indices,
sparse_input_values, sparse_input_shape);
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index 122a67a407..9e8979bce4 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -19,6 +19,7 @@ load(
"tf_gen_op_libs",
"tf_gen_op_wrapper_py",
)
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
@@ -181,7 +182,12 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":wrap_conversion",
+ "//tensorflow/python:graph_util",
+ "//tensorflow/python:session",
"//tensorflow/python:tf_optimizer",
+ "//tensorflow/python/saved_model:builder",
+ "//tensorflow/python/saved_model:loader",
+ "//tensorflow/python/saved_model:tag_constants",
],
)
@@ -410,6 +416,31 @@ py_library(
],
)
+cuda_py_test(
+ name = "trt_convert_test",
+ srcs = ["python/trt_convert_test.py"],
+ additional_deps = [
+ ":trt_convert_py",
+ ":trt_ops_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:graph_util",
+ "//tensorflow/python/saved_model:builder",
+ "//tensorflow/python/saved_model:loader",
+ "//tensorflow/python/saved_model:signature_constants",
+ "//tensorflow/python/saved_model:signature_def_utils",
+ "//tensorflow/python/saved_model:tag_constants",
+ "//tensorflow/python/saved_model:utils",
+ "//tensorflow/python/tools:freeze_graph_lib",
+ "//tensorflow/python/tools:saved_model_utils",
+ ],
+ tags = [
+ "no_cuda_on_cpu_tap",
+ "no_windows",
+ "nomac",
+ ],
+)
+
cuda_py_tests(
name = "tf_trt_integration_test",
srcs = [
diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py
index 4116f2fe30..369e73b5a6 100644
--- a/tensorflow/contrib/tensorrt/python/trt_convert.py
+++ b/tensorflow/contrib/tensorrt/python/trt_convert.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-# pylint: disable=unused-import,line-too-long
import six as _six
+# pylint: disable=unused-import,line-too-long
from tensorflow.contrib.tensorrt.wrap_conversion import add_test_value
from tensorflow.contrib.tensorrt.wrap_conversion import calib_convert
from tensorflow.contrib.tensorrt.wrap_conversion import clear_test_values
@@ -28,55 +28,179 @@ from tensorflow.contrib.tensorrt.wrap_conversion import get_linked_tensorrt_vers
from tensorflow.contrib.tensorrt.wrap_conversion import get_loaded_tensorrt_version
from tensorflow.contrib.tensorrt.wrap_conversion import get_test_value
from tensorflow.contrib.tensorrt.wrap_conversion import is_tensorrt_enabled
+# pylint: enable=unused-import,line-too-long
from tensorflow.core.framework import graph_pb2
+from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.client import session
from tensorflow.python.framework import errors_impl as _impl
+from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.platform import tf_logging
+from tensorflow.python.saved_model import builder
+from tensorflow.python.saved_model import loader_impl
+from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import saver
-# pylint: enable=unused-import,line-too-long
+
+if _six.PY2:
+ _to_bytes = lambda s: s
+ _to_string = lambda s: s
+else:
+ _to_bytes = lambda s: s.encode("utf-8", errors="surrogateescape")
+ _to_string = lambda s: s.decode("utf-8")
+
+
+class TrtPrecisionMode(object):
+ FP32 = "FP32"
+ FP16 = "FP16"
+ INT8 = "INT8"
+
+ @staticmethod
+ def supported_precision_modes():
+ return [TrtPrecisionMode.FP32, TrtPrecisionMode.FP16, TrtPrecisionMode.INT8]
+
+
+def tensorrt_rewriter_config(max_batch_size=1,
+ max_workspace_size_bytes=2 << 20,
+ precision_mode=TrtPrecisionMode.FP32,
+ minimum_segment_size=3,
+ is_dynamic_op=False,
+ maximum_cached_engines=1,
+ cached_engine_batch_sizes=None):
+ """Returns a RewriterConfig proto for TRT transformation.
+
+ Args:
+ max_batch_size: max size for the input batch
+ max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
+ engine can use at execution time. This corresponds to the 'workspaceSize'
+ parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
+ precision_mode: one of TrtPrecisionMode.supported_precision_modes().
+ minimum_segment_size: the minimum number of nodes required for a subgraph to
+ be replaced by TRTEngineOp.
+ is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT
+ network and engine at run time.
+ maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops.
+ If the number of cached engines is already at max but none of them can
+ serve the input, the TRTEngineOp will fall back to run the TF function
+ based on which the TRTEngineOp is created.
+ cached_engine_batch_sizes: a list of batch sizes used to create cached
+ engines, only used when is_dynamic_op is True. The length of the list
+      should not exceed maximum_cached_engines, and the dynamic TRT op will
+ use this list to determine the batch sizes of the cached engines, instead
+ of making the decision on the fly. This is useful when we know the most
+ common batch size(s) the application is going to generate.
+
+ Returns:
+ A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler.
+
+ Raises:
+    TypeError: if cached_engine_batch_sizes is provided but is not a list.
+    ValueError: if the precision mode is invalid, or if
+      len(cached_engine_batch_sizes) exceeds maximum_cached_engines.
+ """
+ if precision_mode.upper() not in TrtPrecisionMode.supported_precision_modes():
+    raise ValueError(("precision mode '{}' is not supported. "
+                      "It should be one of {}").format(
+                          precision_mode,
+                          TrtPrecisionMode.supported_precision_modes()))
+
+ rewriter_cfg = rewriter_config_pb2.RewriterConfig()
+ rewriter_cfg.optimizers.extend(["constfold", "layout"])
+ optimizer = rewriter_cfg.custom_optimizers.add()
+ optimizer.name = "TensorRTOptimizer"
+ optimizer.parameter_map["minimum_segment_size"].i = minimum_segment_size
+ optimizer.parameter_map["max_batch_size"].i = max_batch_size
+ optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op
+ optimizer.parameter_map[
+ "max_workspace_size_bytes"].i = max_workspace_size_bytes
+ optimizer.parameter_map["precision_mode"].s = _to_bytes(precision_mode)
+ optimizer.parameter_map["maximum_cached_engines"].i = maximum_cached_engines
+ if cached_engine_batch_sizes:
+ if not isinstance(cached_engine_batch_sizes, list):
+ raise TypeError("cached_engine_batch_sizes should be a list.")
+ if len(cached_engine_batch_sizes) > maximum_cached_engines:
+ raise ValueError("cached_engine_batch_sizes should not contain more than "
+ "maximum_cached_engines items.")
+ optimizer.parameter_map["cached_engine_batches"].list.i.extend(
+ cached_engine_batch_sizes)
+ return rewriter_cfg
def create_inference_graph(input_graph_def,
outputs,
max_batch_size=1,
max_workspace_size_bytes=2 << 20,
- precision_mode="FP32",
+ precision_mode=TrtPrecisionMode.FP32,
minimum_segment_size=3,
is_dynamic_op=False,
maximum_cached_engines=1,
- cached_engine_batches=None):
+ cached_engine_batch_sizes=None,
+ input_saved_model_dir=None,
+ input_saved_model_tags=None,
+ output_saved_model_dir=None,
+ session_config=None):
"""Python wrapper for the TRT transformation.
Args:
- input_graph_def: GraphDef object containing a model to be transformed.
- outputs: list of tensors or node names for the model outputs.
- max_batch_size: max size for the input batch
- max_workspace_size_bytes: parameter to control memory allocation (in Bytes)
- precision_mode: one of 'FP32', 'FP16' and 'INT8'
+ input_graph_def: a GraphDef object containing a model to be transformed. If
+ set to None, the graph will be read from the SavedModel loaded from
+ input_saved_model_dir.
+ outputs: list of tensors or node names for the model outputs. Only used when
+ input_graph_def is not None.
+ max_batch_size: max size for the input batch.
+ max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
+ engine can use at execution time. This corresponds to the 'workspaceSize'
+ parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
+ precision_mode: one of TrtPrecisionMode.supported_precision_modes().
minimum_segment_size: the minimum number of nodes required for a subgraph to
be replaced by TRTEngineOp.
is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT
network and engine at run time.
maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops.
- cached_engine_batches: batch sizes used to pre-create cached engines.
+ If the number of cached engines is already at max but none of them can
+ serve the input, the TRTEngineOp will fall back to run the TF function
+ based on which the TRTEngineOp is created.
+ cached_engine_batch_sizes: a list of batch sizes used to create cached
+ engines, only used when is_dynamic_op is True. The length of the list
+      should not exceed maximum_cached_engines, and the dynamic TRT op will
+ use this list to determine the batch sizes of the cached engines, instead
+ of making the decision on the fly. This is useful when we know the most
+ common batch size(s) the application is going to generate.
+ input_saved_model_dir: the directory to load the SavedModel which contains
+      the input graph to transform. Used only when input_graph_def is None.
+ input_saved_model_tags: list of tags to load the SavedModel.
+ output_saved_model_dir: if not None, construct a SavedModel using the
+ returned GraphDef and save it to the specified directory. This option only
+ works when the input graph is loaded from a SavedModel, i.e. when
+ input_saved_model_dir is specified and input_graph_def is None.
+ session_config: the ConfigProto used to create a Session. If not specified,
+ a default ConfigProto will be used.
Returns:
- New GraphDef with TRTEngineOps placed in graph replacing subgraphs.
+ A GraphDef transformed from input_graph_def (or the SavedModel graph def
+ loaded from input_saved_model_dir, if input_graph_def is not present), where
+ all TRT compatible subgraphs are replaced with TRTEngineOps, and a TF
+ function is added for each of the subgraphs.
+
+ If is_dynamic_op is True, each TRTEngineOp will contain a serialized
+ subgraph GraphDef, which will be converted to a TRT engine at execution time
+    and the TRT engine will be cached for future usage. A new TRT engine will be
+    created whenever none of the cached engines match the input shapes. If
+ it fails to execute the TRT engine or the number of cached engines reaches
+ maximum_cached_engines, the op will fall back to call the corresponding TF
+ function.
+
+ If is_dynamic_op is False, each TRTEngineOp will contain a serialized TRT
+ engine created from the corresponding subgraph. No more engines will be
+ created on the fly, and the op will fall back to call the corresponding TF
+ function when it fails to execute the engine.
Raises:
- ValueError: if the provided precision mode is invalid.
- RuntimeError: if the returned status message is malformed.
+ ValueError: if the combination of the parameters is invalid.
+ RuntimeError: if the TensorRT library version is incompatible.
"""
- supported_precision_modes = {"FP32": 0, "FP16": 1, "INT8": 2}
- if precision_mode.upper() not in supported_precision_modes:
- raise ValueError(("precision mode '{}' is not supported."
- "It should be one of {}").format(
- precision_mode, "{'FP32', 'FP16', 'INT8'}"))
- mode = supported_precision_modes[precision_mode.upper()]
compiled_version = get_linked_tensorrt_version()
loaded_version = get_loaded_tensorrt_version()
version_mismatch = False
@@ -101,61 +225,111 @@ def create_inference_graph(input_graph_def,
tf_logging.info("Running against TensorRT version %s" % ".".join(
[str(x) for x in loaded_version]))
- def py2bytes(inp):
- return inp
+ if session_config is None:
+ session_config = config_pb2.ConfigProto()
+
+ if input_saved_model_tags is None:
+ input_saved_model_tags = [tag_constants.SERVING]
+ saved_model_loader = None
+ grappler_meta_graph_def = None
- def py3bytes(inp):
- return inp.encode("utf-8", errors="surrogateescape")
+ if input_graph_def is None:
+ # Read from SavedModel and freeze the graph if necessary.
+ if input_saved_model_dir is None:
+      raise ValueError("input_graph_def and input_saved_model_dir cannot "
+                       "both be None")
+ with ops.Graph().as_default():
+ with session.Session(config=session_config) as sess:
+ saved_model_loader = loader_impl.SavedModelLoader(input_saved_model_dir)
+ input_meta_graph_def = saved_model_loader.load(sess,
+ input_saved_model_tags)
+ output_node_names = set()
- def py2string(inp):
- return inp
+ def _gather_names(tensor_info):
+ """Get the node names from a TensorInfo."""
+ return set(
+ [tensor_info[key].name.split(":")[0] for key in tensor_info])
- def py3string(inp):
- return inp.decode("utf-8")
+ # Get input and outputs from all SignatureDef.
+ for key in input_meta_graph_def.signature_def:
+ signature_def = input_meta_graph_def.signature_def[key]
+ output_node_names.update(_gather_names(signature_def.inputs))
+ output_node_names.update(_gather_names(signature_def.outputs))
- if _six.PY2:
- to_bytes = py2bytes
- to_string = py2string
+ # Freeze the variables in the SavedModel graph and copy the frozen
+ # graph over.
+ frozen_graph_def = graph_util.convert_variables_to_constants(
+ sess, sess.graph.as_graph_def(add_shapes=True),
+ list(output_node_names))
+ grappler_meta_graph_def = meta_graph_pb2.MetaGraphDef()
+ grappler_meta_graph_def.graph_def.CopyFrom(frozen_graph_def)
+
+ # Copy the collections that are not variables.
+ for key in input_meta_graph_def.collection_def:
+ # TODO(laigd): currently we use the collection key to filter out
+ # collections that depend on variable ops, but this may miss some
+ # other user-defined collections. A better way would be to use
+ # CollectionDef::NodeList for the filtering.
+ if key not in [
+ "variables", "local_variables", "model_variables",
+ "trainable_variables", "train_op", "table_initializer"
+ ]:
+ grappler_meta_graph_def.collection_def[key].CopyFrom(
+ input_meta_graph_def.collection_def[key])
+
+ # Copy other information.
+ grappler_meta_graph_def.meta_info_def.CopyFrom(
+ input_meta_graph_def.meta_info_def)
+ for key in input_meta_graph_def.signature_def:
+ grappler_meta_graph_def.signature_def[key].CopyFrom(
+ input_meta_graph_def.signature_def[key])
+ # TODO(laigd): maybe add back AssetFileDef.
else:
- to_bytes = py3bytes
- to_string = py3string
-
- # Create MetaGraphDef
- graph = ops.Graph()
- with graph.as_default():
- importer.import_graph_def(input_graph_def, name="")
- meta_graph = saver.export_meta_graph(
- graph_def=graph.as_graph_def(), graph=graph)
- if outputs:
- output_collection = meta_graph_pb2.CollectionDef()
- output_list = output_collection.node_list.value
- for i in outputs:
- if isinstance(i, ops.Tensor):
- output_list.append(to_bytes(i.name))
- else:
- output_list.append(to_bytes(i))
- meta_graph.collection_def["train_op"].CopyFrom(output_collection)
+ if output_saved_model_dir is not None:
+ raise ValueError("output_saved_model_dir cannot be set when "
+ "input_graph_def is set")
+ # Create MetaGraphDef from input graph.
+ graph = ops.Graph()
+ with graph.as_default():
+ importer.import_graph_def(input_graph_def, name="")
+ grappler_meta_graph_def = saver.export_meta_graph(
+ graph_def=graph.as_graph_def(add_shapes=True), graph=graph)
+ if outputs:
+ output_collection = meta_graph_pb2.CollectionDef()
+ output_list = output_collection.node_list.value
+ for i in outputs:
+ if isinstance(i, ops.Tensor):
+ output_list.append(_to_bytes(i.name))
+ else:
+ output_list.append(_to_bytes(i))
+ # TODO(laigd): use another key as the outputs are really not train_op.
+ grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
+ output_collection)
# Create RewriterConfig.
- rewriter_cfg = rewriter_config_pb2.RewriterConfig()
- rewriter_cfg.optimizers.extend(["constfold", "layout"])
- optimizer = rewriter_cfg.custom_optimizers.add()
- optimizer.name = "TensorRTOptimizer"
- optimizer.parameter_map["minimum_segment_size"].i = minimum_segment_size
- optimizer.parameter_map["max_batch_size"].i = max_batch_size
- optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op
- optimizer.parameter_map[
- "max_workspace_size_bytes"].i = max_workspace_size_bytes
- optimizer.parameter_map["precision_mode"].s = to_bytes(precision_mode)
- optimizer.parameter_map["maximum_cached_engines"].i = maximum_cached_engines
- if cached_engine_batches:
- if not isinstance(cached_engine_batches, list):
- raise TypeError("cached_engine_batches should be a list.")
- optimizer.parameter_map["cached_engine_batches"].list.i.extend(
- cached_engine_batches)
+ rewriter_cfg = tensorrt_rewriter_config(
+ max_batch_size, max_workspace_size_bytes, precision_mode,
+ minimum_segment_size, is_dynamic_op, maximum_cached_engines,
+ cached_engine_batch_sizes)
+
+ # Run Grappler.
+ transformed_graph_def = tf_optimizer.OptimizeGraph(
+ rewriter_cfg, grappler_meta_graph_def, graph_id=b"tf_graph")
- return tf_optimizer.OptimizeGraph(
- rewriter_cfg, meta_graph, graph_id=b"tf_graph")
+ # Optionally write the transformed GraphDef as a SavedModel.
+ if output_saved_model_dir is not None:
+ saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir)
+ with ops.Graph().as_default():
+ importer.import_graph_def(transformed_graph_def, name="")
+ with session.Session(config=session_config) as sess:
+ saved_model_builder.add_meta_graph_and_variables(
+ sess,
+ input_saved_model_tags,
+ signature_def_map=grappler_meta_graph_def.signature_def)
+ # Ignore other meta graphs from the input SavedModel.
+ saved_model_builder.save()
+
+ return transformed_graph_def
def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
@@ -164,22 +338,13 @@ def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
Args:
calibration_graph_def: the calibration GraphDef object with calibration data
    is_dynamic_op: whether to create dynamic engines (built at runtime) rather
      than static engines from the calibration graph
+
Returns:
New GraphDef with TRTEngineOps placed in graph replacing calibration nodes.
Raises:
RuntimeError: if the returned status message is malformed.
"""
- def py2string(inp):
- return inp
-
- def py3string(inp):
- return inp.decode("utf-8")
-
- if _six.PY2:
- to_string = py2string
- else:
- to_string = py3string
is_calib_graph = False
for n in calibration_graph_def.node:
if n.op == "TRTEngineOp":
@@ -190,7 +355,7 @@ def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
return None
graph_str = calibration_graph_def.SerializeToString()
out = calib_convert(graph_str, is_dynamic_op)
- status = to_string(out[0])
+ status = _to_string(out[0])
output_graph_def_string = out[1]
del graph_str # Save some memory
if len(status) < 2:
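
For context, a minimal usage sketch of the extended create_inference_graph API above, assuming TensorRT support is built in; the directory paths and conversion parameter values below are hypothetical.

from tensorflow.contrib.tensorrt.python import trt_convert

trt_graph_def = trt_convert.create_inference_graph(
    input_graph_def=None,       # unused when a SavedModel directory is given
    outputs=None,               # outputs are read from the SignatureDefs
    input_saved_model_dir="/tmp/my_saved_model",       # hypothetical path
    output_saved_model_dir="/tmp/my_trt_saved_model",  # hypothetical path
    max_batch_size=8,
    precision_mode="FP16",
    is_dynamic_op=True)
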
diff --git a/tensorflow/contrib/tensorrt/python/trt_convert_test.py b/tensorflow/contrib/tensorrt/python/trt_convert_test.py
new file mode 100644
index 0000000000..118a6680fd
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/python/trt_convert_test.py
@@ -0,0 +1,293 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities to test TF-TensorRT integration."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.contrib.tensorrt.python import trt_convert
+# pylint: disable=unused-import
+from tensorflow.contrib.tensorrt.python.ops import trt_engine_op
+# pylint: enable=unused-import
+from tensorflow.core.framework import graph_pb2
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import graph_util
+from tensorflow.python.framework import importer
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.saved_model import builder
+from tensorflow.python.saved_model import loader
+from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.saved_model import signature_def_utils
+from tensorflow.python.saved_model import tag_constants
+from tensorflow.python.saved_model import utils
+from tensorflow.python.tools import saved_model_utils
+
+
+class TrtConvertTest(test_util.TensorFlowTestCase):
+ """Class to test Tensorflow-TensorRT integration python API."""
+
+ def testTensorrtRewriterConfig(self):
+ """Test case for trt_convert.tensorrt_rewriter_config()."""
+ rewriter_cfg = trt_convert.tensorrt_rewriter_config(
+ max_batch_size=128,
+ max_workspace_size_bytes=1234,
+ precision_mode="INT8",
+ minimum_segment_size=10,
+ is_dynamic_op=True,
+ maximum_cached_engines=2,
+ cached_engine_batch_sizes=[1, 128])
+ trt_optimizer = None
+ for optimizer in rewriter_cfg.custom_optimizers:
+ if optimizer.name == "TensorRTOptimizer":
+ self.assertTrue(trt_optimizer is None)
+ trt_optimizer = optimizer
+ self.assertTrue(trt_optimizer is not None)
+ for key in [
+ "minimum_segment_size", "max_batch_size", "is_dynamic_op",
+ "max_workspace_size_bytes", "precision_mode", "maximum_cached_engines",
+ "cached_engine_batches"
+ ]:
+ self.assertTrue(key in trt_optimizer.parameter_map)
+ self.assertEqual(10, trt_optimizer.parameter_map["minimum_segment_size"].i)
+ self.assertEqual(128, trt_optimizer.parameter_map["max_batch_size"].i)
+ self.assertEqual(True, trt_optimizer.parameter_map["is_dynamic_op"].b)
+ self.assertEqual(1234,
+ trt_optimizer.parameter_map["max_workspace_size_bytes"].i)
+ self.assertEqual(
+ trt_convert._to_bytes("INT8"),
+ trt_optimizer.parameter_map["precision_mode"].s)
+ self.assertEqual(2, trt_optimizer.parameter_map["maximum_cached_engines"].i)
+ self.assertEqual(
+ [1, 128],
+ trt_optimizer.parameter_map["cached_engine_batches"].list.i)
+
+ def _GetConfigProto(self):
+ """Get ConfigProto for session creation."""
+ config = config_pb2.ConfigProto(
+ gpu_options=config_pb2.GPUOptions(allow_growth=True))
+ return config
+
+ def _GetGraph(self):
+ """Get the graph for testing."""
+ g = ops.Graph()
+ with g.as_default():
+ with g.device("/GPU:0"):
+ inp = array_ops.placeholder(
+ dtype=dtypes.float32, shape=[None, 1, 1], name="input")
+ var = variables.Variable([[[1.0]]], dtype=dtypes.float32, name="v1")
+ add = inp + var.value()
+ mul = inp * add
+ add = mul + add
+ out = array_ops.identity(add, name="output")
+ return g, var, inp, out
+
+ def _GetGraphDef(self):
+ """Get the graph def for testing."""
+ g, var, _, _ = self._GetGraph()
+ with self.test_session(graph=g, config=self._GetConfigProto()) as sess:
+ sess.run(var.initializer)
+ graph_def = graph_util.convert_variables_to_constants(
+ sess, g.as_graph_def(add_shapes=True), ["output"])
+ node_name_to_op = {node.name: node.op for node in graph_def.node}
+ self.assertEqual({
+ "v1": "Const",
+ "v1/read": "Identity",
+ "input": "Placeholder",
+ "add": "Add",
+ "mul": "Mul",
+ "add_1": "Add",
+ "output": "Identity"
+ }, node_name_to_op)
+ return graph_def
+
+ def _WriteInputSavedModel(self, input_saved_model_dir):
+ """Write the saved model as an input for testing."""
+ g, var, inp, out = self._GetGraph()
+ signature_def = signature_def_utils.build_signature_def(
+ inputs={"myinput": utils.build_tensor_info(inp)},
+ outputs={"myoutput": utils.build_tensor_info(out)},
+ method_name=signature_constants.PREDICT_METHOD_NAME)
+ saved_model_builder = builder.SavedModelBuilder(input_saved_model_dir)
+ with self.test_session(graph=g, config=self._GetConfigProto()) as sess:
+ sess.run(var.initializer)
+ saved_model_builder.add_meta_graph_and_variables(
+ sess, [tag_constants.SERVING],
+ signature_def_map={"mypredict": signature_def})
+ saved_model_builder.save()
+
+ def _TestCreateInferenceGraph(self,
+ input_saved_model_dir=None,
+ output_saved_model_dir=None):
+ """General method to test trt_convert.create_inference_graph()."""
+ input_graph_def = None if input_saved_model_dir else self._GetGraphDef()
+ output_graph_def = trt_convert.create_inference_graph(
+ input_graph_def, ["output"],
+ input_saved_model_dir=input_saved_model_dir,
+ output_saved_model_dir=output_saved_model_dir,
+ session_config=self._GetConfigProto())
+ graph_defs_to_verify = [output_graph_def]
+ if output_saved_model_dir is not None:
+ saved_model_graph_def = saved_model_utils.get_meta_graph_def(
+ output_saved_model_dir, tag_constants.SERVING).graph_def
+ self.assertTrue(isinstance(saved_model_graph_def, graph_pb2.GraphDef))
+ graph_defs_to_verify.append(saved_model_graph_def)
+
+ for graph_def in graph_defs_to_verify:
+ node_name_to_op = {node.name: node.op for node in graph_def.node}
+ self.assertEqual({
+ "input": "Placeholder",
+ "my_trt_op_0": "TRTEngineOp",
+ "output": "Identity"
+ }, node_name_to_op)
+
+ def testCreateInferenceGraph_BasicConversion(self):
+ """Test case for trt_convert.create_inference_graph()."""
+ if not trt_convert.is_tensorrt_enabled():
+ return
+
+ # Use GraphDef as input.
+ self._TestCreateInferenceGraph()
+
+ # Use SavedModel as input.
+ tmp_dir = self.get_temp_dir()
+ input_saved_model_dir = os.path.join(tmp_dir, "in_dir1")
+ output_saved_model_dir = os.path.join(tmp_dir, "out_dir1")
+ self._WriteInputSavedModel(input_saved_model_dir)
+ self._TestCreateInferenceGraph(input_saved_model_dir,
+ output_saved_model_dir)
+
+ def _TestRun(self, sess, batch_size, expect_engine_is_run):
+ trt_convert.clear_test_values("")
+ result = sess.run("output:0", feed_dict={"input:0": [[[1.0]]] * batch_size})
+ self.assertAllEqual([[[4.0]]] * batch_size, result)
+ execute_engine_test_value = ("done" if expect_engine_is_run else "")
+ execute_native_segment_test_value = ("" if expect_engine_is_run else "done")
+ self.assertEqual(execute_engine_test_value,
+ trt_convert.get_test_value("my_trt_op_0:ExecuteTrtEngine"))
+ self.assertEqual(
+ execute_native_segment_test_value,
+ trt_convert.get_test_value("my_trt_op_0:ExecuteNativeSegment"))
+
+ def testCreateInferenceGraph_MinimumSegmentSize(self):
+ if not trt_convert.is_tensorrt_enabled():
+ return
+ output_graph_def = trt_convert.create_inference_graph(
+ self._GetGraphDef(), ["output"],
+ minimum_segment_size=5,
+ is_dynamic_op=False)
+ node_name_to_op = {node.name: node.op for node in output_graph_def.node}
+ self.assertEqual({
+ "v1/read": "Const",
+ "input": "Placeholder",
+ "add": "Add",
+ "mul": "Mul",
+ "add_1": "Add",
+ "output": "Identity"
+ }, node_name_to_op)
+
+ def testCreateInferenceGraph_DynamicOp(self):
+ if not trt_convert.is_tensorrt_enabled():
+ return
+ trt_convert.enable_test_value()
+
+ tmp_dir = self.get_temp_dir()
+ input_saved_model_dir = os.path.join(tmp_dir, "in_dir2")
+ output_saved_model_dir = os.path.join(tmp_dir, "out_dir2")
+ self._WriteInputSavedModel(input_saved_model_dir)
+ output_graph_def = trt_convert.create_inference_graph(
+ None,
+ None,
+ is_dynamic_op=True,
+ maximum_cached_engines=2,
+ input_saved_model_dir=input_saved_model_dir,
+ output_saved_model_dir=output_saved_model_dir,
+ session_config=self._GetConfigProto())
+
+ # Test the output GraphDef.
+ with ops.Graph().as_default():
+ importer.import_graph_def(output_graph_def, name="")
+ with self.test_session(config=self._GetConfigProto()) as sess:
+ # Run with batch size 1, a new engine is created and cached.
+ self._TestRun(sess, 1, True)
+ # Run with batch size 2, a new engine is created and cached.
+ self._TestRun(sess, 2, True)
+ # Run with batch size 3; since the number of cached engines has reached
+ # the max, it should fall back to the TF function.
+ self._TestRun(sess, 3, False)
+
+ # Test the output SavedModel
+ with ops.Graph().as_default():
+ with self.test_session(config=self._GetConfigProto()) as sess:
+ loader.load(sess, [tag_constants.SERVING], output_saved_model_dir)
+ # Run with batch size 1, a new engine is created and cached.
+ self._TestRun(sess, 1, True)
+ # Run with batch size 2, a new engine is created and cached.
+ self._TestRun(sess, 2, True)
+ # Run with batch size 3; since the number of cached engines has reached
+ # the max, it should fall back to the TF function.
+ self._TestRun(sess, 3, False)
+
+ def testCreateInferenceGraph_StaticOp(self):
+ if not trt_convert.is_tensorrt_enabled():
+ return
+ trt_convert.enable_test_value()
+
+ tmp_dir = self.get_temp_dir()
+ input_saved_model_dir = os.path.join(tmp_dir, "in_dir3")
+ output_saved_model_dir = os.path.join(tmp_dir, "out_dir3")
+ self._WriteInputSavedModel(input_saved_model_dir)
+ output_graph_def = trt_convert.create_inference_graph(
+ None,
+ None,
+ max_batch_size=1,
+ is_dynamic_op=False,
+ maximum_cached_engines=2,  # This is a no-op, added just for testing.
+ input_saved_model_dir=input_saved_model_dir,
+ output_saved_model_dir=output_saved_model_dir,
+ session_config=self._GetConfigProto())
+
+ # Test the output GraphDef.
+ with ops.Graph().as_default():
+ importer.import_graph_def(output_graph_def, name="")
+ with self.test_session(config=self._GetConfigProto()) as sess:
+ # Run with batch size 1, the default engine embedded in the graphdef
+ # will be used.
+ self._TestRun(sess, 1, True)
+ # Run with batch size 2, which exceeds the max_batch_size, so it should
+ # fall back to the TF function.
+ self._TestRun(sess, 2, False)
+
+ # Test the output SavedModel
+ with ops.Graph().as_default():
+ with self.test_session(config=self._GetConfigProto()) as sess:
+ loader.load(sess, [tag_constants.SERVING], output_saved_model_dir)
+ # Run with batch size 1, the default engine embedded in the graphdef
+ # will be used.
+ self._TestRun(sess, 1, True)
+ # Run with batch size 2, which exceeds the max_batch_size, so it should
+ # fall back to the TF function.
+ self._TestRun(sess, 2, False)
+
+
+if __name__ == "__main__":
+ test.main()
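
The dynamic-op test above relies on the engine-cache behaviour described in its comments. A conceptual, pure-Python sketch of that policy (not the TRTEngineOp implementation) with maximum_cached_engines=2:

class EngineCache(object):
  """Toy model of a per-op TRT engine cache with a fixed capacity."""

  def __init__(self, max_cached_engines):
    self.max_cached_engines = max_cached_engines
    self.engines = {}  # batch_size -> placeholder "engine"

  def run(self, batch_size):
    if batch_size in self.engines:
      return "ExecuteTrtEngine"            # reuse a cached engine
    if len(self.engines) < self.max_cached_engines:
      self.engines[batch_size] = object()  # build and cache a new engine
      return "ExecuteTrtEngine"
    return "ExecuteNativeSegment"          # cache full: fall back to TF

cache = EngineCache(max_cached_engines=2)
assert cache.run(1) == "ExecuteTrtEngine"
assert cache.run(2) == "ExecuteTrtEngine"
assert cache.run(3) == "ExecuteNativeSegment"
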
diff --git a/tensorflow/contrib/tensorrt/test/test_tftrt.py b/tensorflow/contrib/tensorrt/test/test_tftrt.py
index 090aa8bdb0..d26f260086 100644
--- a/tensorflow/contrib/tensorrt/test/test_tftrt.py
+++ b/tensorflow/contrib/tensorrt/test/test_tftrt.py
@@ -191,7 +191,7 @@ def user(multi_engine,
minimum_segment_size=2, # minimum number of nodes in an engine
is_dynamic_op=False,
maximum_cached_engines=1,
- cached_engine_batches=[])
+ cached_engine_batch_sizes=[])
o1 = run_graph(orig_graph, dummy_input)
o2 = run_graph(trt_graph, dummy_input)
o3 = run_graph(trt_graph, dummy_input)
@@ -206,7 +206,7 @@ def user(multi_engine,
minimum_segment_size=2, # minimum number of nodes in an engine
is_dynamic_op=False,
maximum_cached_engines=1,
- cached_engine_batches=[])
+ cached_engine_batch_sizes=[])
int8_calib_gdef = trt.create_inference_graph(
input_graph_def=orig_graph,
outputs=["output"],
@@ -216,7 +216,7 @@ def user(multi_engine,
minimum_segment_size=2, # minimum number of nodes in an engine
is_dynamic_op=False,
maximum_cached_engines=1,
- cached_engine_batches=[])
+ cached_engine_batch_sizes=[])
o4 = run_graph(fp16_graph, dummy_input)
_ = run_calibration(int8_calib_gdef, dummy_input)
int8_graph = trt.calib_graph_to_infer_graph(int8_calib_gdef)
diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
index 65ca21cf37..fc647e4eb9 100644
--- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
+++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
@@ -30,7 +30,6 @@ from tensorflow.contrib.tensorrt.python import trt_convert
from tensorflow.contrib.tensorrt.python.ops import trt_engine_op
# pylint: enable=unused-import
from tensorflow.core.protobuf import config_pb2
-from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import importer
@@ -50,7 +49,7 @@ RunParams = namedtuple(
ConversionParams = namedtuple("ConversionParams", [
"max_batch_size", "max_workspace_size_bytes", "precision_mode",
"minimum_segment_size", "is_dynamic_op", "maximum_cached_engines",
- "cached_engine_batches"
+ "cached_engine_batch_sizes"
])
PRECISION_MODES = ["FP32", "FP16", "INT8"]
@@ -139,7 +138,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
minimum_segment_size=2,
is_dynamic_op=run_params.dynamic_engine,
maximum_cached_engines=1,
- cached_engine_batches=None)
+ cached_engine_batch_sizes=None)
def ShouldRunTest(self, run_params):
"""Whether to run the test."""
@@ -201,23 +200,12 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
def _GetConfigProto(self, run_params, graph_state):
"""Get config proto based on specific settings."""
if graph_state != GraphState.ORIGINAL and run_params.use_optimizer:
- rewriter_cfg = rewriter_config_pb2.RewriterConfig()
- rewriter_cfg.optimizers.extend(["constfold", "layout"])
- custom_op = rewriter_cfg.custom_optimizers.add()
- custom_op.name = "TensorRTOptimizer"
trt_params = self.GetConversionParams(run_params)
- custom_op.parameter_map["max_batch_size"].i = trt_params.max_batch_size
- custom_op.parameter_map["max_workspace_size_bytes"].i = (
- trt_params.max_workspace_size_bytes)
- custom_op.parameter_map["precision_mode"].s = trt_params.precision_mode
- custom_op.parameter_map["minimum_segment_size"].i = (
- trt_params.minimum_segment_size)
- custom_op.parameter_map["is_dynamic_op"].b = trt_params.is_dynamic_op
- custom_op.parameter_map["maximum_cached_engines"].i = (
- trt_params.maximum_cached_engines)
- if trt_params.cached_engine_batches:
- custom_op.parameter_map["cached_engine_batches"].list.i.extend(
- trt_params.cached_engine_batches)
+ rewriter_cfg = trt_convert.tensorrt_rewriter_config(
+ trt_params.max_batch_size, trt_params.max_workspace_size_bytes,
+ trt_params.precision_mode, trt_params.minimum_segment_size,
+ trt_params.is_dynamic_op, trt_params.maximum_cached_engines,
+ trt_params.cached_engine_batch_sizes)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg)
else:
@@ -308,7 +296,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
minimum_segment_size=trt_params.minimum_segment_size,
is_dynamic_op=trt_params.is_dynamic_op,
maximum_cached_engines=trt_params.maximum_cached_engines,
- cached_engine_batches=trt_params.cached_engine_batches)
+ cached_engine_batch_sizes=trt_params.cached_engine_batch_sizes)
def _WriteGraph(self, run_params, gdef, graph_state):
if graph_state == GraphState.ORIGINAL:
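
The test base above now builds its rewriter config through the shared trt_convert.tensorrt_rewriter_config helper. A minimal sketch of the same pattern outside the test harness; the parameter values are placeholders:

from tensorflow.contrib.tensorrt.python import trt_convert
from tensorflow.core.protobuf import config_pb2

rewriter_cfg = trt_convert.tensorrt_rewriter_config(
    max_batch_size=8,
    max_workspace_size_bytes=1 << 25,
    precision_mode="FP32",
    minimum_segment_size=2,
    is_dynamic_op=False,
    maximum_cached_engines=1,
    cached_engine_batch_sizes=None)
# Let Grappler run the TensorRT optimizer for every session using this config.
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg)
config = config_pb2.ConfigProto(graph_options=graph_options)
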
diff --git a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
index d808945334..1d27fffc62 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
@@ -264,10 +264,10 @@ class ARModel(model.TimeSeriesModel):
elif (not isinstance(periodicities, list) and
not isinstance(periodicities, tuple)):
periodicities = [periodicities]
- self._periods = [int(p) for p in periodicities]
- for p in self._periods:
+ self._periodicities = [int(p) for p in periodicities]
+ for p in self._periodicities:
assert p > 0
- assert len(self._periods) or self.input_window_size
+ assert len(self._periodicities) or self.input_window_size
assert output_window_size > 0
def initialize_graph(self, input_statistics=None):
@@ -364,9 +364,9 @@ class ARModel(model.TimeSeriesModel):
input_feature_size = 0
output_window_features = []
output_feature_size = 0
- if self._periods:
+ if self._periodicities:
_, time_features = self._compute_time_features(times)
- num_time_features = self._buckets * len(self._periods)
+ num_time_features = self._buckets * len(self._periodicities)
time_features = array_ops.reshape(
time_features,
[batch_size,
@@ -849,12 +849,12 @@ class ARModel(model.TimeSeriesModel):
def _compute_time_features(self, time):
"""Compute some features on the time value."""
batch_size = array_ops.shape(time)[0]
- num_periods = len(self._periods)
+ num_periods = len(self._periodicities)
# Reshape to 3D.
periods = constant_op.constant(
- self._periods, shape=[1, 1, num_periods, 1], dtype=time.dtype)
+ self._periodicities, shape=[1, 1, num_periods, 1], dtype=time.dtype)
time = array_ops.reshape(time, [batch_size, -1, 1, 1])
- window_offset = time / self._periods
+ window_offset = time / self._periodicities
# Cast to appropriate type and scale to [0, 1) range
mod = (math_ops.cast(time % periods, self.dtype) * self._buckets /
math_ops.cast(periods, self.dtype))
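
An illustrative numpy sketch (not part of the diff) of the periodic feature computation in _compute_time_features above: each time value is mapped to its position within every periodicity and scaled into `buckets` buckets. The values below are made up:

import numpy as np

def periodic_buckets(times, periodicities, buckets):
  times = np.asarray(times, dtype=np.float64)[:, None]            # [batch, 1]
  periods = np.asarray(periodicities, dtype=np.float64)[None, :]  # [1, P]
  return (times % periods) * buckets / periods                    # in [0, buckets)

print(periodic_buckets([0, 3, 7, 12], periodicities=[5, 12], buckets=10))
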
diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
index 461fe22210..83260fc59a 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
@@ -216,6 +216,15 @@ class TimeSeriesRegressorTest(test.TestCase):
exogenous_feature_columns=exogenous_feature_columns)
self._fit_restore_fit_test_template(_estimator_fn, dtype=dtype)
+ def test_structural_ensemble_numpy_input(self):
+ numpy_data = {"times": numpy.arange(50),
+ "values": numpy.random.normal(size=[50])}
+ estimators.StructuralEnsembleRegressor(
+ num_features=1, periodicities=[], model_dir=self.get_temp_dir(),
+ config=_SeedRunConfig()).train(
+ input_pipeline.WholeDatasetInputFn(
+ input_pipeline.NumpyReader(numpy_data)),
+ steps=1)
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head_test.py b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
index e65e7b74d4..647455ae42 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/head_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
@@ -122,7 +122,7 @@ class EvaluationMetricsTests(test.TestCase):
metric[1] for metric in outputs.eval_metric_ops.values()]
loss_mean, loss_update = metrics.mean(outputs.loss)
metric_update_ops.append(loss_update)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coordinator = coordinator_lib.Coordinator()
queue_runner_impl.start_queue_runners(sess, coord=coordinator)
variables.local_variables_initializer().run()
diff --git a/tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py b/tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py
index 703537abf0..f92148b788 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py
@@ -88,7 +88,7 @@ class RandomWindowInputFnTests(test.TestCase):
window_size=window_size, batch_size=batch_size)
result, _ = input_fn()
init_op = variables.local_variables_initializer()
- with self.test_session() as session:
+ with self.cached_session() as session:
coordinator = coordinator_lib.Coordinator()
queue_runner_impl.start_queue_runners(session, coord=coordinator)
session.run(init_op)
@@ -261,7 +261,7 @@ class WholeDatasetInputFnTests(test.TestCase):
def _whole_dataset_input_fn_test_template(
self, time_series_reader, num_features, num_samples):
result, _ = input_pipeline.WholeDatasetInputFn(time_series_reader)()
- with self.test_session() as session:
+ with self.cached_session() as session:
session.run(variables.local_variables_initializer())
coordinator = coordinator_lib.Coordinator()
queue_runner_impl.start_queue_runners(session, coord=coordinator)
@@ -340,7 +340,7 @@ class AllWindowInputFnTests(test.TestCase):
window_size=window_size)
features, _ = input_fn()
init_op = variables.local_variables_initializer()
- with self.test_session() as session:
+ with self.cached_session() as session:
coordinator = coordinator_lib.Coordinator()
queue_runner_impl.start_queue_runners(session, coord=coordinator)
session.run(init_op)
diff --git a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py
index 9b593fecbb..03da2b82e5 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py
@@ -896,8 +896,8 @@ class InputStatisticsFromMiniBatch(object):
statistics.total_observation_count,
math_ops.cast(
gen_math_ops.round(
- math_ops.cast(auxiliary_variables.max_time_seen -
- statistics.start_time + 1, self._dtype) /
+ math_ops.cast(max_time_seen_assign -
+ start_time_update + 1, self._dtype) /
inter_observation_duration_estimate), dtypes.int64))
per_chunk_stat_updates = control_flow_ops.group(
overall_feature_mean_update, overall_feature_var_update,
diff --git a/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py b/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py
index 02d2524b66..c0de42b15b 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py
@@ -55,7 +55,7 @@ class MathUtilsTest(test.TestCase):
running_sum = running_sum + current_contribution
# pylint: enable=g-no-augmented-assignment
transition_power = numpy.dot(transition, transition_power)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(result,
math_utils.power_sums_tensor(
array_size, transition, addition).eval())
@@ -66,7 +66,7 @@ class MathUtilsTest(test.TestCase):
result = []
for i in range(powers.shape[0]):
result.append(numpy.linalg.matrix_power(matrix, powers[i]))
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(result,
math_utils.matrix_to_powers(matrix, powers).eval(),
rtol=1e-5,
@@ -78,7 +78,7 @@ class MathUtilsTest(test.TestCase):
result = []
for i in range(batch.shape[0]):
result.append(numpy.linalg.matrix_power(batch[i], powers[i]))
- with self.test_session():
+ with self.cached_session():
# TODO(allenl): Numerical errors seem to be creeping in. Maybe it can be
# made slightly more stable?
self.assertAllClose(result,
@@ -91,7 +91,7 @@ class MathUtilsTest(test.TestCase):
left_transpose = numpy.transpose(left, [0, 2, 1])
right = numpy.random.normal(size=[2, 3]).astype(numpy.float32)
expected_result = numpy.dot(left, right)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(expected_result,
math_utils.batch_times_matrix(
left, right).eval())
@@ -114,7 +114,7 @@ class MathUtilsTest(test.TestCase):
right_transpose = numpy.transpose(right, [0, 2, 1])
expected_result = numpy.transpose(numpy.dot(right_transpose, left.T),
[0, 2, 1])
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(expected_result,
math_utils.matrix_times_batch(
left, right).eval())
@@ -132,7 +132,7 @@ class MathUtilsTest(test.TestCase):
adj_x=True, adj_y=True).eval())
def test_make_diagonal_undefined_shapes(self):
- with self.test_session():
+ with self.cached_session():
completely_undefined = array_ops.placeholder(dtype=dtypes.float32)
partly_undefined = array_ops.placeholder(
shape=[None, None], dtype=dtypes.float32)
@@ -152,7 +152,7 @@ class MathUtilsTest(test.TestCase):
[5., 6.]]}))
def test_make_diagonal_mostly_defined_shapes(self):
- with self.test_session():
+ with self.cached_session():
mostly_defined = array_ops.placeholder(
shape=[None, 2], dtype=dtypes.float32)
blocked = math_utils.block_diagonal([[[2.]],
@@ -192,7 +192,7 @@ class TestMakeToeplitzMatrix(test.TestCase):
def _test_make_toeplitz_matrix(self, inputs, output_expected):
output_tf = math_utils.make_toeplitz_matrix(inputs)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
output_tf_np = sess.run(output_tf)
self.assertAllClose(output_tf_np, output_expected)
@@ -201,13 +201,13 @@ class TestMakeCovarianceMatrix(test.TestCase):
def test_zero_size_matrix(self):
raw = numpy.zeros([0, 0])
- with self.test_session():
+ with self.cached_session():
constructed = math_utils.sign_magnitude_positive_definite(raw=raw).eval()
self.assertEqual((0, 0), constructed.shape)
def test_sign_magnitude_positive_definite(self):
for dtype in [dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
matrix_tensor = math_utils.sign_magnitude_positive_definite(
raw=constant_op.constant([[-1., -2.], [3., 4.]], dtype=dtype),
off_diagonal_scale=constant_op.constant(-1., dtype=dtype),
@@ -230,7 +230,8 @@ class TestLookupTable(test.TestCase):
name="test_lookup")
def stack_tensor(base_tensor):
return array_ops.stack([base_tensor + 1, base_tensor + 2])
- with self.test_session() as session:
+
+ with self.cached_session() as session:
((float_output, double_output), int_output) = session.run(
hash_table.lookup([2, 1, 0]))
def expected_output_before_insert(base_tensor):
diff --git a/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py b/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py
index cfd31cc70d..a049dbe773 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py
@@ -29,7 +29,7 @@ class ModelUtilsTest(test.TestCase):
def test_parameter_switching(self):
parameter = array_ops.constant(5)
overridden_parameter = array_ops.constant(3)
- with self.test_session():
+ with self.cached_session():
getter = model_utils.parameter_switch({overridden_parameter: 4})
self.assertEqual(5, getter(parameter))
self.assertEqual(4, getter(overridden_parameter))
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_management_test.py b/tensorflow/contrib/timeseries/python/timeseries/state_management_test.py
index 5f7e3da2db..42ba6e1c25 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_management_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_management_test.py
@@ -127,7 +127,7 @@ class ChainingStateManagerTest(test.TestCase):
chainer.initialize_graph(model=stub_model)
model_outputs = chainer.define_loss(
model=stub_model, features=features, mode=estimator_lib.ModeKeys.TRAIN)
- with self.test_session() as session:
+ with self.cached_session() as session:
variables.global_variables_initializer().run()
coordinator = coordinator_lib.Coordinator()
queue_runner_impl.start_queue_runners(session, coord=coordinator)
@@ -178,7 +178,7 @@ class ChainingStateManagerTest(test.TestCase):
result_model_outputs = chainer.define_loss(
model=stub_model, features=result_input_fn()[0],
mode=estimator_lib.ModeKeys.TRAIN)
- with self.test_session() as session:
+ with self.cached_session() as session:
variables.global_variables_initializer().run()
coordinator = coordinator_lib.Coordinator()
queue_runner_impl.start_queue_runners(session, coord=coordinator)
@@ -221,7 +221,7 @@ class ChainingStateManagerTest(test.TestCase):
chainer.initialize_graph(model=stub_model)
model_outputs = chainer.define_loss(
model=stub_model, features=features, mode=estimator_lib.ModeKeys.TRAIN)
- with self.test_session() as session:
+ with self.cached_session() as session:
variables.global_variables_initializer().run()
coordinator = coordinator_lib.Coordinator()
queue_runner_impl.start_queue_runners(session, coord=coordinator)
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor_test.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor_test.py
index 53d7340e85..a77c507d9b 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor_test.py
@@ -61,7 +61,7 @@ class FilteringStepPostprocessorTest(test.TestCase):
expected_state = [[[80.], [20.]],
[1., 6.],
[-1, -2]]
- with self.test_session():
+ with self.cached_session():
for interpolated, expected in zip(interpolated_state, expected_state):
self.assertAllClose(expected, interpolated.eval())
self.assertGreater(0., updated_outputs["anomaly_score"][0].eval())
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter_test.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter_test.py
index 57f29f3c7f..f636126a33 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter_test.py
@@ -98,7 +98,7 @@ class MultivariateTests(test.TestCase):
observation_model=observation_model,
predicted_observations=(observed_mean, observed_var),
observation_noise=observation_noise_covariance)
- with self.test_session() as session:
+ with self.cached_session() as session:
evaled_state = numpy.array([[1., 1., 1., 1.]])
evaled_state_var = numpy.eye(4)[None]
for i in range(500):
@@ -136,7 +136,7 @@ class KalmanFilterNonBatchTest(test.TestCase):
def test_observed_from_state(self):
"""Compare observation mean and noise to hand-computed values."""
- with self.test_session():
+ with self.cached_session():
state = constant_op.constant([[2., 1.]])
state_var = constant_op.constant([[[4., 0.], [0., 3.]]])
observed_mean, observed_var = self.kalman_filter.observed_from_state(
@@ -171,7 +171,7 @@ class KalmanFilterNonBatchTest(test.TestCase):
observation_model=observation_model,
predicted_observations=predicted_observations,
observation_noise=observation_noise))
- with self.test_session() as session:
+ with self.cached_session() as session:
evaled_state, evaled_state_var = session.run([state, state_var])
for _ in range(300):
evaled_state, evaled_state_var = session.run(
@@ -231,7 +231,7 @@ class KalmanFilterNonBatchTest(test.TestCase):
def test_predict_state_mean(self):
"""Compare state mean transitions with simple hand-computed values."""
- with self.test_session():
+ with self.cached_session():
state = constant_op.constant([[4., 2.]])
state = self.kalman_filter.predict_state_mean(
state, self.transition_fn([1]))
@@ -245,7 +245,7 @@ class KalmanFilterNonBatchTest(test.TestCase):
def test_predict_state_var(self):
"""Compare a variance transition with simple hand-computed values."""
- with self.test_session():
+ with self.cached_session():
state_var = constant_op.constant([[[1., 0.], [0., 2.]]])
state_var = self.kalman_filter.predict_state_var(
state_var, self.transition_fn([1]), self.power_sum_fn([1]))
@@ -259,7 +259,7 @@ class KalmanFilterNonBatchTest(test.TestCase):
Tests that correct values have high probability and incorrect values
have low probability when there is low uncertainty.
"""
- with self.test_session():
+ with self.cached_session():
state = constant_op.constant([[4., 2.]])
state_var = constant_op.constant([[[0.0001, 0.], [0., 0.0001]]])
observation = constant_op.constant([[
@@ -289,7 +289,7 @@ class KalmanFilterNonBatchTest(test.TestCase):
self.assertGreater(first_log_prob.eval()[0], numpy.log(0.99))
def test_predict_n_ahead_mean(self):
- with self.test_session():
+ with self.cached_session():
original_state = constant_op.constant([[4., 2.]])
n = 5
iterative_state = original_state
@@ -304,7 +304,7 @@ class KalmanFilterNonBatchTest(test.TestCase):
self.transition_fn([1]))
def test_predict_n_ahead_var(self):
- with self.test_session():
+ with self.cached_session():
original_var = constant_op.constant([[[2., 3.], [4., 5.]]])
n = 5
iterative_var = original_var
@@ -330,7 +330,7 @@ class KalmanFilterBatchTest(test.TestCase):
Tests that correct values have high probability and incorrect values
have low probability when there is low uncertainty.
"""
- with self.test_session():
+ with self.cached_session():
state = constant_op.constant([[4., 2.], [5., 3.], [6., 4.]])
state_var = constant_op.constant(3 * [[[0.0001, 0.], [0., 0.0001]]])
observation = constant_op.constant([
@@ -378,7 +378,7 @@ class KalmanFilterBatchTest(test.TestCase):
self.assertLess(third_log_prob.sum(), numpy.log(0.01))
def test_predict_n_ahead_mean(self):
- with self.test_session():
+ with self.cached_session():
kf = kalman_filter.KalmanFilter()
transition_fn, _ = _powers_and_sums_from_transition_matrix(
state_transition=STATE_TRANSITION,
@@ -396,7 +396,7 @@ class KalmanFilterBatchTest(test.TestCase):
self.assertAllClose(state2.eval()[2], batch_eval[2])
def test_predict_n_ahead_var(self):
- with self.test_session():
+ with self.cached_session():
kf = kalman_filter.KalmanFilter()
transition_fn, power_sum_fn = _powers_and_sums_from_transition_matrix(
state_transition=STATE_TRANSITION,
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py
index c2eaa78493..80126ac786 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py
@@ -96,7 +96,7 @@ class ConstructionTests(test.TestCase):
},
mode=estimator_lib.ModeKeys.TRAIN)
initializer = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([initializer])
outputs.loss.eval()
@@ -114,7 +114,7 @@ class ConstructionTests(test.TestCase):
},
mode=estimator_lib.ModeKeys.TRAIN)
initializer = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([initializer])
outputs.loss.eval()
@@ -144,7 +144,7 @@ class GapTests(test.TestCase):
state=math_utils.replicate_state(
start_state=random_model.get_start_state(),
batch_size=array_ops.shape(times)[0]))
- with self.test_session() as session:
+ with self.cached_session() as session:
variables.global_variables_initializer().run()
coordinator = coordinator_lib.Coordinator()
queue_runner_impl.start_queue_runners(session, coord=coordinator)
@@ -250,7 +250,7 @@ class StateSpaceEquivalenceTests(test.TestCase):
self.assertAllClose(combined_value, split_predict[prediction_key])
def _equivalent_to_single_model_test_template(self, model_generator):
- with self.test_session() as session:
+ with self.cached_session() as session:
random_model = RandomStateSpaceModel(
state_dimension=5,
state_noise_dimension=4,
@@ -374,7 +374,7 @@ class PredictionTests(test.TestCase):
math_utils.replicate_state(
start_state=random_model.get_start_state(), batch_size=1)
})
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
predicted_mean = prediction_dict["mean"].eval()
predicted_covariance = prediction_dict["covariance"].eval()
@@ -404,7 +404,7 @@ class PredictionTests(test.TestCase):
feature_keys.PredictionFeatures.TIMES: [[5, 7, 8]],
feature_keys.PredictionFeatures.STATE_TUPLE: model_outputs.end_state
})
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
predicted_mean = predictions["mean"].eval()
predicted_covariance = predictions["covariance"].eval()
@@ -428,7 +428,7 @@ class ExogenousTests(test.TestCase):
state=[
array_ops.ones(shape=[1, 5]), original_covariance[None], [0]
])
- with self.test_session() as session:
+ with self.cached_session() as session:
variables.global_variables_initializer().run()
evaled_new_covariance, evaled_original_covariance = session.run(
[new_covariance[0], original_covariance])
@@ -454,7 +454,7 @@ class ExogenousTests(test.TestCase):
-array_ops.ones(shape=[1, 5], dtype=dtype),
original_covariance[None], [0]
])
- with self.test_session() as session:
+ with self.cached_session() as session:
variables.global_variables_initializer().run()
evaled_new_covariance, evaled_original_covariance = session.run(
[new_covariance[0], original_covariance])
@@ -519,7 +519,7 @@ class PosteriorTests(test.TestCase):
model=stub_model, data=data, true_parameters=true_params)
def test_exact_posterior_recovery_no_transition_noise(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
stub_model, data, true_params = self._get_single_model()
input_fn = input_pipeline.WholeDatasetInputFn(
input_pipeline.NumpyReader(data))
@@ -559,7 +559,7 @@ class PosteriorTests(test.TestCase):
posterior_times)
def test_chained_exact_posterior_recovery_no_transition_noise(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
stub_model, data, true_params = self._get_single_model()
chunk_size = 10
input_fn = test_utils.AllWindowInputFn(
@@ -748,7 +748,7 @@ class MultivariateTests(test.TestCase):
},
mode=estimator_lib.ModeKeys.TRAIN)
initializer = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([initializer])
outputs.loss.eval()
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma_test.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma_test.py
index 84885d5c9a..e8875f4eb9 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma_test.py
@@ -46,7 +46,7 @@ class MakeModelTest(test.TestCase):
},
mode=estimator_lib.ModeKeys.TRAIN)
initializer = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([initializer])
outputs.loss.eval()
@@ -65,7 +65,7 @@ class MakeModelTest(test.TestCase):
},
mode=estimator_lib.ModeKeys.TRAIN)
initializer = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([initializer])
outputs.loss.eval()
@@ -85,7 +85,7 @@ class MakeModelTest(test.TestCase):
TrainEvalFeatures.VALUES: constant_op.constant([[[1.], [2.]]])},
mode=estimator_lib.ModeKeys.TRAIN)
initializer = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([initializer])
outputs.loss.eval()
diff --git a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc
index 98cc31f18d..b4b06a40a2 100644
--- a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc
+++ b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc
@@ -142,9 +142,8 @@ Status WriteTensorboardTPUProfile(const string& logdir, const string& run,
TF_RETURN_IF_ERROR(DumpTraceToLogDirectory(profile_run_dir, host_prefix,
response.encoded_trace(), os));
}
- if (response.has_op_profile() &&
- (response.op_profile().has_by_program_structure() ||
- response.op_profile().has_by_category())) {
+ if (response.has_op_profile() && (response.op_profile().has_by_program() ||
+ response.op_profile().has_by_category())) {
TF_RETURN_IF_ERROR(DumpOpProfileToLogDirectory(profile_run_dir, host_prefix,
response.op_profile(), os));
}
diff --git a/tensorflow/contrib/tpu/profiler/op_profile.proto b/tensorflow/contrib/tpu/profiler/op_profile.proto
index feb177a7da..68cf510e71 100644
--- a/tensorflow/contrib/tpu/profiler/op_profile.proto
+++ b/tensorflow/contrib/tpu/profiler/op_profile.proto
@@ -4,12 +4,14 @@ package tensorflow.tpu.op_profile;
// Profile is the top-level data that summarizes a program.
message Profile {
+ reserved 2;
+ reserved "by_program_structure";
+ reserved 3;
+ reserved "per_program";
// Root of a profile broken down by instruction category.
Node by_category = 1;
- // Root of a profile broken down by program structure.
- Node by_program_structure = 2;
- // Per program profile, indexed by hlo module name of the program.
- map<string, Node> per_program = 3;
+ // Root of a profile broken down by program.
+ Node by_program = 4;
}
// An entry in the profile tree. (An instruction, or set of instructions).
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index c1f90c3963..0f9f7cd91b 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -654,13 +654,16 @@ def split_compile_and_replicate(computation,
# variables.
# Partitioned variables is not supported (b/112311320).
def custom_getter(getter, name, *args, **kwargs):
+ """Variables on TPU have a few restrictions."""
partitioner = kwargs["partitioner"]
- if partitioner is None:
- return getter(name, *args, **kwargs)
- else:
- raise ValueError(
+ if partitioner is not None:
+ kwargs["partitioner"] = None
+ logging.warning(
"Partitioned variables are not supported on TPU. Got "
- "`partitioner` that is {}.".format(partitioner))
+ "`partitioner` that is {} for variable {}. "
+ "Setting `partitioner` to `None`."
+ .format(partitioner, name))
+ return getter(name, *args, **kwargs)
vscope = variable_scope.get_variable_scope()
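
A standalone sketch of the relaxed policy above (drop the unsupported partitioner and warn instead of raising), written against plain TF 1.x variable scopes rather than the TPU rewriting machinery; the scope and variable names are illustrative:

import tensorflow as tf

def drop_partitioner_getter(getter, name, *args, **kwargs):
  """Strips any partitioner, warns, and defers to the wrapped getter."""
  partitioner = kwargs.get("partitioner")
  if partitioner is not None:
    kwargs["partitioner"] = None
    tf.logging.warning(
        "Partitioned variables are not supported here. Dropping partitioner "
        "%s for variable %s.", partitioner, name)
  return getter(name, *args, **kwargs)

with tf.variable_scope("scope", custom_getter=drop_partitioner_getter):
  v = tf.get_variable(
      "v", shape=[16, 4],
      partitioner=tf.fixed_size_partitioner(num_shards=4))
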
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 1ff04f5c26..23c54511ca 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -1774,18 +1774,19 @@ class ExamplesPerSecondHook(basic_session_run_hooks.StepCounterHook):
summary_writer=summary_writer)
def _log_and_record(self, elapsed_steps, elapsed_time, global_step):
- global_step_per_sec = elapsed_steps / elapsed_time
- examples_per_sec = self._batch_size * global_step_per_sec
+ global_steps_per_sec = elapsed_steps / elapsed_time
+ examples_per_sec = self._batch_size * global_steps_per_sec
if self._summary_writer is not None:
global_step_summary = Summary(value=[
- Summary.Value(tag='global_step/sec', simple_value=global_step_per_sec)
+ Summary.Value(tag='global_steps/sec',
+ simple_value=global_steps_per_sec)
])
example_summary = Summary(value=[
Summary.Value(tag='examples/sec', simple_value=examples_per_sec)
])
self._summary_writer.add_summary(global_step_summary, global_step)
self._summary_writer.add_summary(example_summary, global_step)
- logging.info('global_step/sec: %g', global_step_per_sec)
+ logging.info('global_steps/sec: %g', global_steps_per_sec)
logging.info('examples/sec: %g', examples_per_sec)
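
For reference, the renamed metrics above are simple arithmetic; a plain-Python sketch with made-up values:

elapsed_steps, elapsed_time, batch_size = 100, 2.0, 1024   # illustrative values
global_steps_per_sec = elapsed_steps / elapsed_time        # 50.0 steps/sec
examples_per_sec = batch_size * global_steps_per_sec       # 51200.0 examples/sec
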
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 79ad3b8e54..1a86bff5cd 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -168,6 +168,7 @@ COMMON_PROTO_SRCS = [
"example/example.proto",
"example/feature.proto",
"framework/allocation_description.proto",
+ "framework/api_def.proto",
"framework/attr_value.proto",
"framework/cost_graph.proto",
"framework/device_attributes.proto",
@@ -177,9 +178,9 @@ COMMON_PROTO_SRCS = [
"framework/iterator.proto",
"framework/kernel_def.proto",
"framework/log_memory.proto",
+ "framework/model.proto",
"framework/node_def.proto",
"framework/op_def.proto",
- "framework/api_def.proto",
"framework/reader_base.proto",
"framework/remote_fused_graph_execute_info.proto",
"framework/resource_handle.proto",
@@ -299,6 +300,7 @@ filegroup(
name = "platform_base_hdrs",
srcs = [
"platform/byte_order.h",
+ "platform/cord.h",
"platform/env_time.h",
"platform/logging.h",
"platform/macros.h",
@@ -720,6 +722,7 @@ cc_library(
name = "abi",
srcs = ["platform/abi.cc"],
hdrs = ["platform/abi.h"],
+ deps = [":platform_base"],
)
cc_library(
@@ -839,6 +842,7 @@ tf_cuda_library(
"framework/log_memory.h",
"framework/lookup_interface.h",
"framework/memory_types.h",
+ "framework/model.h",
"framework/node_def_builder.h",
"framework/node_def_util.h",
"framework/numeric_op.h",
@@ -874,7 +878,6 @@ tf_cuda_library(
"util/bcast.h",
"util/cuda_kernel_helper.h",
"util/device_name_utils.h",
- "util/env_var.h",
"util/events_writer.h",
"util/example_proto_fast_parsing.h",
"util/example_proto_helper.h",
@@ -1918,6 +1921,13 @@ tf_pyclif_proto_library(
)
tf_pyclif_proto_library(
+ name = "protobuf/config_pyclif",
+ proto_lib = ":protos_all_cc",
+ proto_srcfile = "protobuf/config.proto",
+ visibility = ["//visibility:public"],
+)
+
+tf_pyclif_proto_library(
name = "protobuf/device_properties_pyclif",
proto_lib = ":protos_all_cc",
proto_srcfile = "protobuf/device_properties.proto",
@@ -2056,6 +2066,7 @@ LIB_INTERNAL_PUBLIC_HEADERS = tf_additional_lib_hdrs() + [
"platform/snappy.h",
"platform/tensor_coding.h",
"platform/tracing.h",
+ "util/env_var.h",
]
# Replicated for lib_internal and lib_internal_impl.
@@ -2095,6 +2106,7 @@ cc_library(
"platform/*.cc",
"platform/profile_utils/**/*.cc",
"framework/resource_handle.cc",
+ "util/env_var.cc",
],
exclude = [
"**/*test*",
@@ -2450,7 +2462,6 @@ FRAMEWORK_INTERNAL_PUBLIC_HEADERS = [
"framework/unique_tensor_references.h",
"framework/variant.h",
"util/command_line_flags.h",
- "util/env_var.h",
"util/equal_graph_def.h",
"util/presized_cuckoo_map.h",
"util/tensor_slice_set.h",
@@ -2526,6 +2537,7 @@ tf_cuda_library(
"util/memmapped_file_system_writer.*",
"util/stats_calculator.*",
"util/version_info.cc",
+ "util/env_var.cc",
],
) + select({
"//tensorflow:windows": [],
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt
new file mode 100644
index 0000000000..cdaeb5091c
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt
@@ -0,0 +1,34 @@
+op {
+ graph_op_name: "BoostedTreesBucketize"
+ visibility: HIDDEN
+ in_arg {
+ name: "float_values"
+ description: <<END
+float; List of Rank 2 Tensors, each containing float values for a single feature.
+END
+ }
+ in_arg {
+ name: "bucket_boundaries"
+ description: <<END
+float; List of Rank 1 Tensors each containing the bucket boundaries for a single
+feature.
+END
+ }
+ out_arg {
+ name: "buckets"
+ description: <<END
+int; List of Rank 2 Tensors each containing the bucketized values for a single feature.
+END
+ }
+ attr {
+ name: "num_features"
+ description: <<END
+inferred int; number of features.
+END
+ }
+ summary: "Bucketize each feature based on bucket boundaries."
+ description: <<END
+An op that returns a list of integer tensors, where each tensor contains the
+bucketized values for a single feature.
+END
+}
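
An illustrative numpy sketch (not the op's kernel) of the bucketization this op describes: each feature's values are digitized against that feature's own boundaries. The data below is made up:

import numpy as np

float_values = [np.array([[0.1], [2.5], [7.0]]),      # feature 0, rank 2
                np.array([[10.0], [20.0], [35.0]])]   # feature 1, rank 2
bucket_boundaries = [np.array([1.0, 5.0]),            # boundaries for feature 0
                     np.array([15.0, 30.0])]          # boundaries for feature 1

buckets = [np.digitize(values, boundaries)
           for values, boundaries in zip(float_values, bucket_boundaries)]
# buckets[0] -> [[0], [1], [2]], buckets[1] -> [[0], [1], [2]]
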
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateQuantileStreamResource.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateQuantileStreamResource.pbtxt
new file mode 100644
index 0000000000..20da1295f6
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateQuantileStreamResource.pbtxt
@@ -0,0 +1,29 @@
+op {
+ graph_op_name: "BoostedTreesCreateQuantileStreamResource"
+ visibility: HIDDEN
+ in_arg {
+ name: "quantile_stream_resource_handle"
+ description: <<END
+resource; Handle to quantile stream resource.
+END
+ }
+ in_arg {
+ name: "epsilon"
+ description: <<END
+float; The required approximation error of the stream resource.
+END
+ }
+ in_arg {
+ name: "num_streams"
+ description: <<END
+int; The number of streams managed by the resource that shares the same epsilon.
+END
+ }
+ attr {
+ name: "max_elements"
+ description: <<END
+int; The maximum number of data points that can be fed to the stream.
+END
+ }
+ summary: "Create the Resource for Quantile Streams."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeQuantileSummaries.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeQuantileSummaries.pbtxt
new file mode 100644
index 0000000000..ca111af312
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeQuantileSummaries.pbtxt
@@ -0,0 +1,40 @@
+op {
+ graph_op_name: "BoostedTreesMakeQuantileSummaries"
+ visibility: HIDDEN
+ in_arg {
+ name: "float_values"
+ description: <<END
+float; List of Rank 2 Tensors each containing values for a single feature.
+END
+ }
+ in_arg {
+ name: "example_weights"
+ description: <<END
+float; Rank 1 Tensor with weights per instance.
+END
+ }
+ in_arg {
+ name: "epsilon"
+ description: <<END
+float; The required maximum approximation error.
+END
+ }
+ out_arg {
+ name: "summaries"
+ description: <<END
+float; List of Rank 2 Tensors each containing the quantile summary (value, weight,
+min_rank, max_rank) of a single feature.
+END
+ }
+ attr {
+ name: "num_features"
+ description: <<END
+int; Inferred from the size of float_values.
+The number of float features.
+END
+ }
+ summary: "Makes the summary of quantiles for the batch."
+ description: <<END
+An op that takes a list of tensors and outputs the quantile summaries for each tensor.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceAddSummaries.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceAddSummaries.pbtxt
new file mode 100644
index 0000000000..bbeecbf32b
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceAddSummaries.pbtxt
@@ -0,0 +1,22 @@
+op {
+ graph_op_name: "BoostedTreesQuantileStreamResourceAddSummaries"
+ visibility: HIDDEN
+ in_arg {
+ name: "quantile_stream_resource_handle"
+ description: <<END
+resource handle referring to a QuantileStreamResource.
+END
+ }
+ in_arg {
+ name: "summaries"
+ description: <<END
+string; List of Rank 2 Tensors, each containing the summaries for a single feature.
+END
+ }
+ summary: "Add the quantile summaries to each quantile stream resource."
+ description: <<END
+An op that adds a list of quantile summaries to a quantile stream resource. Each
+summary Tensor is rank 2, containing summaries (value, weight, min_rank, max_rank)
+for a single feature.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceFlush.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceFlush.pbtxt
new file mode 100644
index 0000000000..2fd94efa10
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceFlush.pbtxt
@@ -0,0 +1,31 @@
+op {
+ graph_op_name: "BoostedTreesQuantileStreamResourceFlush"
+ visibility: HIDDEN
+ in_arg {
+ name: "quantile_stream_resource_handle"
+ description: <<END
+resource handle referring to a QuantileStreamResource.
+END
+ }
+ in_arg {
+ name: "num_buckets",
+ description: <<END
+int; approximate number of buckets unless using generate_quantiles.
+END
+ }
+ attr {
+ name: "generate_quantiles"
+ description: <<END
+bool; If True, the output will be the num_quantiles for each stream, where the
+ith entry is the ith quantile of the input with an approximation error of
+epsilon. Duplicate values may be present.
+If False, the output will be the histogram boundaries, which roughly translate
+to 1/epsilon boundaries without any duplicates.
+Defaults to False.
+END
+ }
+ summary: "Flush the summaries for a quantile stream resource."
+ description: <<END
+An op that flushes the summaries for a quantile stream resource.
+END
+}
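
The generate_quantiles attribute above amounts to asking the stream for approximate quantile boundaries. A rough numpy analogue using exact percentiles (the stream itself only guarantees an epsilon approximation; num_buckets and the boundary count below are illustrative):

import numpy as np

values = np.random.RandomState(0).normal(size=1000)
num_buckets = 4
# Exact stand-in for the approximate boundaries: evenly spaced quantiles.
boundaries = np.percentile(values, np.linspace(0.0, 100.0, num_buckets + 1))
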
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceGetBucketBoundaries.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceGetBucketBoundaries.pbtxt
new file mode 100644
index 0000000000..206672802f
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceGetBucketBoundaries.pbtxt
@@ -0,0 +1,27 @@
+op {
+ graph_op_name: "BoostedTreesQuantileStreamResourceGetBucketBoundaries"
+ visibility: HIDDEN
+ in_arg {
+ name: "quantile_stream_resource_handle"
+ description: <<END
+resource handle referring to a QuantileStreamResource.
+END
+ }
+ out_arg {
+ name: "bucket_boundaries"
+ description: <<END
+float; List of Rank 1 Tensors each containing the bucket boundaries for a feature.
+END
+ }
+ attr {
+ name: "num_features"
+ description: <<END
+inferred int; number of features to get bucket boundaries for.
+END
+ }
+ summary: "Generate the bucket boundaries for each feature based on accumulated summaries."
+ description: <<END
+An op that returns a list of float tensors for a quantile stream resource. Each
+tensor is Rank 1 containing bucket boundaries for a single feature.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceHandleOp.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceHandleOp.pbtxt
new file mode 100644
index 0000000000..cb7786c051
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceHandleOp.pbtxt
@@ -0,0 +1,5 @@
+op {
+ graph_op_name: "BoostedTreesQuantileStreamResourceHandleOp"
+ visibility: HIDDEN
+ summary: "Creates a handle to a BoostedTreesQuantileStreamResource."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_DecodeCSV.pbtxt b/tensorflow/core/api_def/base_api/api_def_DecodeCSV.pbtxt
index e39213cbc7..440800704e 100644
--- a/tensorflow/core/api_def/base_api/api_def_DecodeCSV.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_DecodeCSV.pbtxt
@@ -11,7 +11,8 @@ END
name: "record_defaults"
description: <<END
One tensor per column of the input record, with either a
-scalar default value for that column or empty if the column is required.
+scalar default value for that column or an empty vector if the column is
+required.
END
}
out_arg {
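
To make the clarified record_defaults convention concrete, here is a hedged C++ sketch of what the per-column default tensors could look like: a scalar tensor supplies a default for an optional column, while an empty rank-1 tensor marks the column as required. The column types and values are assumptions for illustration:

    #include <vector>
    #include "tensorflow/core/framework/tensor.h"
    #include "tensorflow/core/framework/tensor_shape.h"
    #include "tensorflow/core/framework/types.pb.h"

    std::vector<tensorflow::Tensor> MakeRecordDefaults() {
      // Column 0: optional, with a scalar default of 0.0f.
      tensorflow::Tensor col0(tensorflow::DT_FLOAT, tensorflow::TensorShape({}));
      col0.scalar<float>()() = 0.0f;
      // Column 1: required, signalled by an empty vector (shape {0}).
      tensorflow::Tensor col1(tensorflow::DT_STRING, tensorflow::TensorShape({0}));
      return {col0, col1};
    }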
diff --git a/tensorflow/core/api_def/base_api/api_def_IsBoostedTreesQuantileStreamResourceInitialized.pbtxt b/tensorflow/core/api_def/base_api/api_def_IsBoostedTreesQuantileStreamResourceInitialized.pbtxt
new file mode 100644
index 0000000000..758eeb96f0
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_IsBoostedTreesQuantileStreamResourceInitialized.pbtxt
@@ -0,0 +1,20 @@
+op {
+ graph_op_name: "IsBoostedTreesQuantileStreamResourceInitialized"
+ visibility: HIDDEN
+ in_arg {
+ name: "quantile_stream_resource_handle"
+ description: <<END
+resource; The reference to the quantile stream resource handle.
+END
+ }
+ out_arg {
+ name: "is_initialized"
+ description: <<END
+bool; True if the resource is initialized, False otherwise.
+END
+ }
+ summary: "Checks whether a quantile stream has been initialized."
+ description: <<END
+An op that checks whether a quantile stream resource is initialized.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ModelDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ModelDataset.pbtxt
new file mode 100644
index 0000000000..171add16d4
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ModelDataset.pbtxt
@@ -0,0 +1,14 @@
+op {
+ graph_op_name: "ModelDataset"
+ visibility: HIDDEN
+ in_arg {
+ name: "input_dataset"
+ description: <<END
+A variant tensor representing the input dataset.
+END
+ }
+ summary: "Identity transformation that models performance."
+ description: <<END
+Identity transformation that models performance.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt b/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt
index 8fc1e5cba3..5246090ab3 100644
--- a/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt
@@ -32,8 +32,10 @@ For each string in the input `Tensor`, creates a substring starting at index
If `len` defines a substring that would extend beyond the length of the input
string, then as many characters as possible are used.
-If `pos` is negative or specifies a character index larger than any of the input
-strings, then an `InvalidArgumentError` is thrown.
+A negative `pos` indicates distance within the string backwards from the end.
+
+If `pos` specifies an index which is out of range for any of the input strings,
+then an `InvalidArgumentError` is thrown.
`pos` and `len` must have the same shape, otherwise a `ValueError` is thrown on
Op creation.
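
A standalone C++ sketch of the documented indexing rules follows; it mirrors only the description above (a negative pos counts back from the end, an out-of-range pos is an error, and len is clamped), and is not the actual Substr kernel:

    #include <algorithm>
    #include <stdexcept>
    #include <string>

    std::string SubstrSketch(const std::string& input, long long pos,
                             long long len) {
      const long long size = static_cast<long long>(input.size());
      if (pos < 0) pos += size;  // negative pos: distance back from the end
      if (pos < 0 || pos >= size) {
        throw std::invalid_argument("pos is out of range");  // InvalidArgumentError
      }
      // If len would run past the end, use as many characters as possible.
      len = std::min(len, size - pos);
      return input.substr(static_cast<size_t>(pos), static_cast<size_t>(len));
    }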
diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc
index f8cb854b52..cf3d1f0b79 100644
--- a/tensorflow/core/common_runtime/copy_tensor.cc
+++ b/tensorflow/core/common_runtime/copy_tensor.cc
@@ -358,7 +358,7 @@ static Status WrappedTensorDeviceCopy(
#define REGISTER_WRAPPED_TENSOR_COPY(DIRECTION) \
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
- Tensor, DIRECTION, "tensorflow::Tensor", WrappedTensorDeviceCopy)
+ Tensor, DIRECTION, WrappedTensorDeviceCopy)
REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index eb388202fa..b4d8e285bd 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -1228,7 +1228,7 @@ Status DirectSession::CreateExecutors(
}
};
- optimizer.Optimize(lib, options_.env, device, &iter->second,
+ optimizer.Optimize(lib, options_.env, device, &partition_graph,
/*shape_map=*/nullptr);
// TensorFlow Debugger (tfdbg) inserts debug nodes in the graph.
diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc
index 3f2355e530..65e816c202 100644
--- a/tensorflow/core/common_runtime/direct_session_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_test.cc
@@ -1255,7 +1255,7 @@ TEST(DirectSessionTest, RunHandleTest) {
ASSERT_TRUE(s.ok());
ASSERT_EQ(1, outputs.size());
- ResourceHandle resource_handle = outputs[0].scalar<ResourceHandle>()();
+ const ResourceHandle& resource_handle = outputs[0].scalar<ResourceHandle>()();
Tensor string_handle(DT_STRING, {});
string_handle.flat<string>().setConstant(resource_handle.name());
@@ -1308,7 +1308,7 @@ TEST(DirectSessionTest, RunHandleTest_Callable) {
ASSERT_TRUE(s.ok());
ASSERT_EQ(1, outputs.size());
- ResourceHandle resource_handle = outputs[0].scalar<ResourceHandle>()();
+ const ResourceHandle& resource_handle = outputs[0].scalar<ResourceHandle>()();
Tensor string_handle(DT_STRING, {});
string_handle.flat<string>().setConstant(resource_handle.name());
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 879a794368..263467a5b6 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -56,6 +56,7 @@ EagerContext::EagerContext(const SessionOptions& opts,
log_device_placement_(opts.config.log_device_placement()),
num_active_steps_(0),
async_default_(async),
+ log_memory_(LogMemory::IsEnabled()),
env_(opts.env),
use_send_tensor_rpc_(false) {
if (device_mgr_owned) {
@@ -65,13 +66,9 @@ EagerContext::EagerContext(const SessionOptions& opts,
local_unowned_device_manager_ = device_mgr;
}
InitDeviceMapAndAsync();
- if (opts.config.inter_op_parallelism_threads() > 0) {
- runner_ = [this](std::function<void()> closure) {
- this->thread_pool_->Schedule(closure);
- };
- } else {
- runner_ = [](std::function<void()> closure) { closure(); };
- }
+ runner_ = [this](std::function<void()> closure) {
+ this->thread_pool_->Schedule(closure);
+ };
}
void EagerContext::InitDeviceMapAndAsync() {
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index eb6eb0d55a..5ed6057ec6 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#endif
+#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
@@ -141,6 +142,7 @@ class EagerContext {
void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel);
bool LogDevicePlacement() { return log_device_placement_; }
+ bool LogMemory() { return log_memory_; }
Rendezvous* GetRendezvous() { return rendezvous_; }
@@ -261,6 +263,8 @@ class EagerContext {
std::unordered_map<std::thread::id, bool> thread_local_async_
GUARDED_BY(async_map_mu_);
+ const bool log_memory_;
+
Env* const env_;
#ifndef __ANDROID__
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 5b3a64ba98..1da1326a9a 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -296,7 +296,7 @@ Status EagerLocalExecute(EagerOperation* op,
LOG(INFO) << "Executing op " << ndef.op() << " in device "
<< device->name();
}
- kernel = new KernelAndDevice(ctx->GetRendezvous());
+ kernel = new KernelAndDevice(ctx->GetRendezvous(), ctx->LogMemory());
auto* flr = ctx->func_lib(device);
if (flr == nullptr) {
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
index 3d61ff4dc2..83d8425477 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
@@ -32,21 +32,6 @@ limitations under the License.
namespace tensorflow {
// static
-Status KernelAndDevice::InitOp(Device* device, const NodeDef& ndef,
- KernelAndDevice* out) {
- OpKernel* k = nullptr;
- Status s = CreateOpKernel(device->device_type().c_str(), device,
- device->GetAllocator(AllocatorAttributes()),
- nullptr, ndef, TF_GRAPH_DEF_VERSION, &k);
- out->device_ = device;
- out->kernel_.reset(k);
- out->flib_ = nullptr;
- out->runner_ = nullptr;
- out->default_runner_ = [](std::function<void()> f) { f(); };
- return s;
-}
-
-// static
Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
std::function<void(std::function<void()>)>* runner,
KernelAndDevice* out) {
@@ -95,6 +80,7 @@ Status KernelAndDevice::Run(ScopedStepContainer* step_container,
params.slice_reader_cache = &slice_reader_cache_;
params.rendezvous = rendez_;
params.cancellation_manager = &cm_;
+ params.log_memory = log_memory_;
if (stats != nullptr) {
params.track_allocations = true;
}
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h
index 0ef419cbaa..04151a1171 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.h
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h
@@ -52,12 +52,12 @@ class KernelAndDevice {
static Status Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
std::function<void(std::function<void()>)>* runner,
KernelAndDevice* out);
- // TODO(ashankar): Remove this
- static Status InitOp(Device* device, const NodeDef& ndef,
- KernelAndDevice* out);
- KernelAndDevice(tensorflow::Rendezvous* rendez)
- : device_(nullptr), flib_(nullptr), rendez_(rendez) {}
+ KernelAndDevice(tensorflow::Rendezvous* rendez, bool log_memory)
+ : device_(nullptr),
+ flib_(nullptr),
+ rendez_(rendez),
+ log_memory_(log_memory) {}
// TODO(ashankar): Handle list-valued inputs.
Status Run(std::vector<Tensor>* inputs, std::vector<Tensor>* outputs,
@@ -87,6 +87,7 @@ class KernelAndDevice {
DataTypeVector output_dtypes_;
std::function<void(std::function<void()>)>* runner_;
std::function<void(std::function<void()>)> default_runner_;
+ const bool log_memory_;
};
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc b/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc
index 6abe98f53c..da280b2317 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc
@@ -104,7 +104,7 @@ void BM_KernelAndDeviceInit(int iters) {
.NumInputs(2)
.BuildNodeDef());
TestEnv env;
- KernelAndDevice k(nullptr);
+ KernelAndDevice k(nullptr, false);
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
TF_CHECK_OK(KernelAndDevice::Init(ndef, env.function_library_runtime(),
@@ -127,7 +127,7 @@ void BM_KernelAndDeviceRun(int iters) {
.NumInputs(inputs.size())
.BuildNodeDef());
TestEnv env;
- KernelAndDevice kernel(nullptr);
+ KernelAndDevice kernel(nullptr, false);
TF_CHECK_OK(KernelAndDevice::Init(ndef, env.function_library_runtime(),
nullptr, &kernel));
tensorflow::testing::StartTiming();
diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc
index 7f260b3139..4475fa979e 100644
--- a/tensorflow/core/common_runtime/graph_execution_state.cc
+++ b/tensorflow/core/common_runtime/graph_execution_state.cc
@@ -561,6 +561,10 @@ Status GraphExecutionState::OptimizeGraph(
grappler::GrapplerItem item;
item.id = "tf_graph";
graph_->ToGraphDef(&item.graph);
+ // TODO(b/114748242): Add a unit test to test this bug fix.
+ if (flib_def_) {
+ *item.graph.mutable_library() = flib_def_->ToProto();
+ }
item.fetch.insert(item.fetch.end(),
options.callable_options.fetch().begin(),
diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
index 6b76e7e0e7..df9c3a686c 100644
--- a/tensorflow/core/common_runtime/mkl_cpu_allocator.h
+++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
@@ -24,9 +24,11 @@ limitations under the License.
#include <cstdlib>
#include "tensorflow/core/common_runtime/bfc_allocator.h"
#include "tensorflow/core/common_runtime/visitable_allocator.h"
+#include "tensorflow/core/framework/allocator_registry.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/mem.h"
+#include "tensorflow/core/platform/mutex.h"
#ifndef INTEL_MKL_DNN_ONLY
#include "i_malloc.h"
@@ -48,6 +50,125 @@ class MklSubAllocator : public SubAllocator {
void Free(void* ptr, size_t num_bytes) override { port::AlignedFree(ptr); }
};
+// CPU allocator that handles small-size allocations by calling the suballocator
+// directly. Mostly, it is just a wrapper around a suballocator (which calls
+// malloc and free directly) with support for bookkeeping.
+class MklSmallSizeAllocator : public VisitableAllocator {
+ public:
+ MklSmallSizeAllocator(SubAllocator* sub_allocator, size_t total_memory,
+ const string& name)
+ : sub_allocator_(sub_allocator), name_(name) {
+ stats_.bytes_limit = total_memory;
+ }
+ ~MklSmallSizeAllocator() override {}
+
+ TF_DISALLOW_COPY_AND_ASSIGN(MklSmallSizeAllocator);
+
+ inline string Name() override { return name_; }
+
+ void* AllocateRaw(size_t alignment, size_t num_bytes) override {
+ void* ptr = sub_allocator_->Alloc(alignment, num_bytes);
+ if (ptr != nullptr) {
+ std::pair<void*, size_t> map_val(ptr, num_bytes);
+ mutex_lock l(mutex_);
+ // Check that insertion in the hash map was successful.
+ CHECK(map_.insert(map_val).second);
+ // Increment statistics for small-size allocations.
+ IncrementStats(num_bytes);
+ // Call alloc visitors.
+ for (const auto& visitor : alloc_visitors_) {
+ visitor(ptr, num_bytes);
+ }
+ }
+ return ptr;
+ }
+
+ void DeallocateRaw(void* ptr) override {
+ if (ptr == nullptr) {
+ LOG(ERROR) << "tried to deallocate nullptr";
+ return;
+ }
+
+ mutex_lock l(mutex_);
+ auto map_iter = map_.find(ptr);
+ if (map_iter != map_.end()) {
+ // Call free visitors.
+ size_t dealloc_bytes = map_iter->second;
+ for (const auto& visitor : free_visitors_) {
+ visitor(ptr, dealloc_bytes);
+ }
+ sub_allocator_->Free(ptr, dealloc_bytes);
+ DecrementStats(dealloc_bytes);
+ map_.erase(map_iter);
+ } else {
+ LOG(ERROR) << "tried to deallocate invalid pointer";
+ return;
+ }
+ }
+
+ inline bool IsSmallSizeAllocation(const void* ptr) const {
+ mutex_lock l(mutex_);
+ return map_.find(ptr) != map_.end();
+ }
+
+ void GetStats(AllocatorStats* stats) override {
+ mutex_lock l(mutex_);
+ *stats = stats_;
+ }
+
+ void ClearStats() override {
+ mutex_lock l(mutex_);
+ stats_.Clear();
+ }
+
+ void AddAllocVisitor(Visitor visitor) override {
+ mutex_lock l(mutex_);
+ alloc_visitors_.push_back(visitor);
+ }
+
+ void AddFreeVisitor(Visitor visitor) override {
+ mutex_lock l(mutex_);
+ free_visitors_.push_back(visitor);
+ }
+
+ private:
+ // Increment statistics for the allocator handling small allocations.
+ inline void IncrementStats(size_t alloc_size)
+ EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
+ ++stats_.num_allocs;
+ stats_.bytes_in_use += alloc_size;
+ stats_.max_bytes_in_use =
+ std::max(stats_.max_bytes_in_use, stats_.bytes_in_use);
+ stats_.max_alloc_size =
+ std::max(alloc_size, static_cast<size_t>(stats_.max_alloc_size));
+ }
+
+ // Decrement statistics for the allocator handling small allocations.
+ inline void DecrementStats(size_t dealloc_size)
+ EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
+ stats_.bytes_in_use -= dealloc_size;
+ }
+
+ SubAllocator* sub_allocator_; // Not owned by this class.
+
+ // Mutex for protecting updates to map of allocations.
+ mutable mutex mutex_;
+
+ // Allocator name
+ string name_;
+
+ // Hash map to keep track of "small" allocations
+ // We do not use BFC allocator for small allocations.
+ std::unordered_map<const void*, size_t> map_ GUARDED_BY(mutex_);
+
+ // Allocator stats for small allocs
+ AllocatorStats stats_ GUARDED_BY(mutex_);
+
+ // Visitors
+ std::vector<Visitor> alloc_visitors_ GUARDED_BY(mutex_);
+ std::vector<Visitor> free_visitors_ GUARDED_BY(mutex_);
+};
+
/// CPU allocator for MKL that wraps BFC allocator and intercepts
/// and redirects memory allocation calls from MKL.
class MklCPUAllocator : public VisitableAllocator {
@@ -62,7 +183,10 @@ class MklCPUAllocator : public VisitableAllocator {
MklCPUAllocator() { TF_CHECK_OK(Initialize()); }
- ~MklCPUAllocator() override { delete allocator_; }
+ ~MklCPUAllocator() override {
+ delete small_size_allocator_;
+ delete large_size_allocator_;
+ }
Status Initialize() {
VLOG(2) << "MklCPUAllocator: In MklCPUAllocator";
@@ -96,8 +220,15 @@ class MklCPUAllocator : public VisitableAllocator {
}
VLOG(1) << "MklCPUAllocator: Setting max_mem_bytes: " << max_mem_bytes;
- allocator_ = new BFCAllocator(new MklSubAllocator, max_mem_bytes,
- kAllowGrowth, kName);
+
+ sub_allocator_ = new MklSubAllocator();
+
+ // SubAllocator is owned by BFCAllocator, so we do not need to deallocate
+ // it in MklSmallSizeAllocator.
+ small_size_allocator_ =
+ new MklSmallSizeAllocator(sub_allocator_, max_mem_bytes, kName);
+ large_size_allocator_ =
+ new BFCAllocator(sub_allocator_, max_mem_bytes, kAllowGrowth, kName);
#ifndef INTEL_MKL_DNN_ONLY
// For redirecting all allocations from MKL to this allocator
// From: http://software.intel.com/en-us/node/528565
@@ -112,23 +243,55 @@ class MklCPUAllocator : public VisitableAllocator {
inline string Name() override { return kName; }
inline void* AllocateRaw(size_t alignment, size_t num_bytes) override {
- return allocator_->AllocateRaw(alignment, num_bytes);
+ // If the allocation size is less than threshold, call small allocator,
+ // otherwise call large-size allocator (BFC). We found that BFC allocator
+ // does not deliver good performance for small allocations when
+ // inter_op_parallelism_threads is high.
+ return (num_bytes < kSmallAllocationsThreshold)
+ ? small_size_allocator_->AllocateRaw(alignment, num_bytes)
+ : large_size_allocator_->AllocateRaw(alignment, num_bytes);
}
inline void DeallocateRaw(void* ptr) override {
- allocator_->DeallocateRaw(ptr);
+ // Check if ptr is for "small" allocation. If it is, then call Free
+ // directly. Otherwise, call BFC to handle free.
+ if (small_size_allocator_->IsSmallSizeAllocation(ptr)) {
+ small_size_allocator_->DeallocateRaw(ptr);
+ } else {
+ large_size_allocator_->DeallocateRaw(ptr);
+ }
}
- void GetStats(AllocatorStats* stats) override { allocator_->GetStats(stats); }
+ void GetStats(AllocatorStats* stats) override {
+ AllocatorStats l_stats, s_stats;
+ small_size_allocator_->GetStats(&s_stats);
+ large_size_allocator_->GetStats(&l_stats);
+
+ // Combine statistics from small-size and large-size allocator.
+ stats->num_allocs = l_stats.num_allocs + s_stats.num_allocs;
+ stats->bytes_in_use = l_stats.bytes_in_use + s_stats.bytes_in_use;
+ stats->max_bytes_in_use =
+ l_stats.max_bytes_in_use + s_stats.max_bytes_in_use;
+
+ // Since small-size allocations go to MklSmallSizeAllocator,
+ // max_alloc_size from large_size_allocator would be the maximum
+ // size allocated by MklCPUAllocator.
+ stats->max_alloc_size = l_stats.max_alloc_size;
+ }
- void ClearStats() override { allocator_->ClearStats(); }
+ void ClearStats() override {
+ small_size_allocator_->ClearStats();
+ large_size_allocator_->ClearStats();
+ }
void AddAllocVisitor(Visitor visitor) override {
- allocator_->AddAllocVisitor(visitor);
+ small_size_allocator_->AddAllocVisitor(visitor);
+ large_size_allocator_->AddAllocVisitor(visitor);
}
void AddFreeVisitor(Visitor visitor) override {
- allocator_->AddFreeVisitor(visitor);
+ small_size_allocator_->AddFreeVisitor(visitor);
+ large_size_allocator_->AddFreeVisitor(visitor);
}
private:
@@ -148,26 +311,33 @@ class MklCPUAllocator : public VisitableAllocator {
Status s = Status(error::Code::UNIMPLEMENTED,
"Unimplemented case for hooking MKL function.");
TF_CHECK_OK(s); // way to assert with an error message
- return nullptr; // return a value and make static code analyzers happy
+ return nullptr; // return a value and make static code analyzers happy
}
static inline void* ReallocHook(void* ptr, size_t size) {
Status s = Status(error::Code::UNIMPLEMENTED,
"Unimplemented case for hooking MKL function.");
TF_CHECK_OK(s); // way to assert with an error message
- return nullptr; // return a value and make static code analyzers happy
+ return nullptr; // return a value and make static code analyzers happy
}
- /// Do we allow growth in BFC Allocator
+ // Do we allow growth in BFC Allocator
static const bool kAllowGrowth = true;
- /// Name
+ // Name
static constexpr const char* kName = "mklcpu";
- /// The alignment that we need for the allocations
+ // The alignment that we need for the allocations
static constexpr const size_t kAlignment = 64;
- VisitableAllocator* allocator_; // owned by this class
+ VisitableAllocator* large_size_allocator_; // owned by this class
+ MklSmallSizeAllocator* small_size_allocator_; // owned by this class.
+
+ SubAllocator* sub_allocator_; // not owned by this class
+
+ // Size in bytes that defines the upper bound for "small" allocations.
+ // Any allocation below this threshold is a "small" allocation.
+ static constexpr const size_t kSmallAllocationsThreshold = 4096;
// Prevent copying and assignment
TF_DISALLOW_COPY_AND_ASSIGN(MklCPUAllocator);
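
The dispatch introduced above (a bookkeeping allocator for requests under the 4096-byte threshold, a BFC arena for the rest, chosen because BFC underperforms on small allocations when inter_op_parallelism_threads is high) can be summarized with the following hedged, self-contained sketch; the class below is illustrative and uses plain malloc in place of both the suballocator and BFC:

    #include <cstdlib>
    #include <mutex>
    #include <unordered_map>

    class TwoTierAllocatorSketch {
     public:
      static constexpr size_t kSmallThreshold = 4096;  // same threshold as the patch

      void* Allocate(size_t num_bytes) {
        void* ptr = std::malloc(num_bytes);
        if (ptr != nullptr && num_bytes < kSmallThreshold) {
          std::lock_guard<std::mutex> lock(mu_);
          small_allocs_[ptr] = num_bytes;  // bookkeeping for the small path
        }
        return ptr;
      }

      void Deallocate(void* ptr) {
        {
          std::lock_guard<std::mutex> lock(mu_);
          auto it = small_allocs_.find(ptr);
          if (it != small_allocs_.end()) {
            small_allocs_.erase(it);  // small allocation: free directly
            std::free(ptr);
            return;
          }
        }
        std::free(ptr);  // everything else: hand back to the "arena"
      }

     private:
      std::mutex mu_;
      std::unordered_map<void*, size_t> small_allocs_;
    };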
diff --git a/tensorflow/core/common_runtime/rendezvous_util.cc b/tensorflow/core/common_runtime/rendezvous_util.cc
index 1e3fed0d6f..43ca3f1e3e 100644
--- a/tensorflow/core/common_runtime/rendezvous_util.cc
+++ b/tensorflow/core/common_runtime/rendezvous_util.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/rendezvous_util.h"
+#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/util/reffed_status_callback.h"
diff --git a/tensorflow/core/common_runtime/single_threaded_cpu_device.h b/tensorflow/core/common_runtime/single_threaded_cpu_device.h
index 04d5af9087..22650b0d83 100644
--- a/tensorflow/core/common_runtime/single_threaded_cpu_device.h
+++ b/tensorflow/core/common_runtime/single_threaded_cpu_device.h
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/threadpool.h"
namespace tensorflow {
diff --git a/tensorflow/core/example/example.proto b/tensorflow/core/example/example.proto
index e7142a4ef9..e36e51d8d5 100644
--- a/tensorflow/core/example/example.proto
+++ b/tensorflow/core/example/example.proto
@@ -199,7 +199,13 @@ message Example {
// to determine if all features within the FeatureList must
// have the same size. The same holds for this FeatureList across multiple
// examples.
-//
+// - For sequence modeling, e.g.:
+// http://colah.github.io/posts/2015-08-Understanding-LSTMs/
+// https://github.com/tensorflow/nmt
+// the feature lists represent a sequence of frames.
+// In this scenario, all FeatureLists in a SequenceExample have the same
+// number of Feature messages, so that the ith element in each FeatureList
+// is part of the ith frame (or time step).
// Examples of conformant and non-conformant examples' FeatureLists:
//
// Conformant FeatureLists:
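
As a hedged illustration of the frame convention added above, the snippet below builds a SequenceExample in which every FeatureList holds exactly one Feature per frame (three hypothetical frames; the feature names are made up):

    #include "tensorflow/core/example/example.pb.h"

    tensorflow::SequenceExample MakeThreeFrameExample() {
      tensorflow::SequenceExample seq;
      auto* lists = seq.mutable_feature_lists()->mutable_feature_list();
      for (int frame = 0; frame < 3; ++frame) {
        // The i-th Feature in each FeatureList belongs to the i-th frame.
        (*lists)["token_ids"].add_feature()->mutable_int64_list()->add_value(frame);
        (*lists)["weights"].add_feature()->mutable_float_list()->add_value(1.0f);
      }
      return seq;
    }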
diff --git a/tensorflow/core/framework/allocator.cc b/tensorflow/core/framework/allocator.cc
index 888ed0c57b..2a7ee16a16 100644
--- a/tensorflow/core/framework/allocator.cc
+++ b/tensorflow/core/framework/allocator.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/allocator_registry.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/tracking_allocator.h"
+#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
@@ -56,6 +57,14 @@ void RunResourceDtor(ResourceHandle* p, size_t n) {
for (size_t i = 0; i < n; ++p, ++i) p->~ResourceHandle();
}
+void Allocator::RunVariantCtor(Variant* p, size_t n) {
+ for (size_t i = 0; i < n; ++p, ++i) new (p) Variant();
+}
+
+void Allocator::RunVariantDtor(Variant* p, size_t n) {
+ for (size_t i = 0; i < n; ++p, ++i) p->~Variant();
+}
+
// If true, cpu allocator collects more stats.
static bool cpu_allocator_collect_stats = false;
// If true, cpu allocator collects full stats.
diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h
index 774b1fe137..ded120b704 100644
--- a/tensorflow/core/framework/allocator.h
+++ b/tensorflow/core/framework/allocator.h
@@ -23,12 +23,13 @@ limitations under the License.
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/type_traits.h"
-#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
+class Variant;
+
// Attributes for a single allocation call. Different calls to the same
// allocator could potentially have different allocation attributes.
struct AllocationAttributes {
@@ -228,13 +229,9 @@ class Allocator {
for (size_t i = 0; i < n; ++p, ++i) p->~ResourceHandle();
}
- virtual void RunVariantCtor(Variant* p, size_t n) {
- for (size_t i = 0; i < n; ++p, ++i) new (p) Variant();
- }
+ virtual void RunVariantCtor(Variant* p, size_t n);
- virtual void RunVariantDtor(Variant* p, size_t n) {
- for (size_t i = 0; i < n; ++p, ++i) p->~Variant();
- }
+ virtual void RunVariantDtor(Variant* p, size_t n);
// TODO(jeff): Maybe provide some interface to give info about
// current allocation state (total number of bytes available for
diff --git a/tensorflow/core/framework/allocator_registry.h b/tensorflow/core/framework/allocator_registry.h
index 24f282ce84..e907c52ba9 100644
--- a/tensorflow/core/framework/allocator_registry.h
+++ b/tensorflow/core/framework/allocator_registry.h
@@ -21,6 +21,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/numa.h"
namespace tensorflow {
diff --git a/tensorflow/core/framework/attr_value_util_test.cc b/tensorflow/core/framework/attr_value_util_test.cc
index 1a3994736c..4ffd732f8e 100644
--- a/tensorflow/core/framework/attr_value_util_test.cc
+++ b/tensorflow/core/framework/attr_value_util_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <numeric>
#include <vector>
#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index 4e51fba048..4ee6749eea 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/dataset_stateful_op_whitelist.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -291,6 +292,9 @@ class IteratorContext {
// The Allocator to be used to allocate the output of an iterator.
std::function<Allocator*(AllocatorAttributes)> allocator_getter = nullptr;
+
+ // If non-null, identifies the object used for performance modeling.
+ std::shared_ptr<model::Model> model = nullptr;
};
explicit IteratorContext(Params params) : params_(std::move(params)) {}
@@ -342,6 +346,10 @@ class IteratorContext {
return params_.stats_aggregator_getter;
}
+ std::shared_ptr<model::Model> model() { return params_.model; }
+
+ Params params() { return params_; }
+
private:
Params params_;
};
@@ -376,7 +384,11 @@ class SerializationContext {
// defined below.
class IteratorBase {
public:
- virtual ~IteratorBase() {}
+ virtual ~IteratorBase() {
+ for (auto rit = cleanup_fns_.rbegin(); rit != cleanup_fns_.rend(); ++rit) {
+ (*rit)();
+ }
+ }
// Gets the next output from the range that this iterator is traversing.
//
@@ -410,6 +422,10 @@ class IteratorBase {
// in the outputs of this iterator.
virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;
+ // Returns a string that identifies the sequence of iterators leading up to
+ // this iterator.
+ virtual const string& prefix() const = 0;
+
// Performs initialization that needs to happen outside of a constructor to
// properly propagate errors.
virtual Status Initialize(IteratorContext* ctx) { return Status::OK(); }
@@ -449,6 +465,18 @@ class IteratorBase {
IteratorStateReader* reader) {
return errors::Unimplemented("RestoreInternal");
}
+
+ private:
+ friend class DatasetBase; // for access to `AddCleanupFunction`
+
+ // Registers a cleanup function to be called upon object destruction.
+ //
+ // Registered functions are invoked in the reverse order of registration.
+ void AddCleanupFunction(std::function<void()>&& cleanup_fn) {
+ cleanup_fns_.push_back(std::move(cleanup_fn));
+ }
+
+ std::vector<std::function<void()>> cleanup_fns_;
};
// Represents runtime information needed to construct a dataset.
@@ -498,6 +526,27 @@ class DatasetBase : public core::RefCounted {
Status MakeIterator(IteratorContext* ctx, const string& prefix,
std::unique_ptr<IteratorBase>* iterator) const {
*iterator = MakeIteratorInternal(prefix);
+ if (ctx->model()) {
+ // The prefix might contain an index. We need to strip it to make it
+ // possible for the model to successfully identify the output node.
+ string sanitized_prefix = prefix;
+ if (str_util::EndsWith(prefix, "]")) {
+ sanitized_prefix = prefix.substr(0, prefix.rfind('['));
+ }
+ std::shared_ptr<model::Node> node =
+ ctx->model()->AddNode((*iterator)->prefix(), sanitized_prefix);
+ std::vector<string> tokens =
+ str_util::Split((*iterator)->prefix(), ':', str_util::SkipEmpty());
+ node->set_name(tokens[tokens.size() - 1]);
+ std::shared_ptr<model::Model> model = ctx->model();
+ const string& prefix = (*iterator)->prefix();
+ (*iterator)->AddCleanupFunction([model, node, prefix]() {
+ if (node->output()) {
+ node->output()->remove_input(node);
+ }
+ model->RemoveNode(prefix);
+ });
+ }
return (*iterator)->Initialize(ctx);
}
@@ -524,6 +573,8 @@ class DatasetBase : public core::RefCounted {
IteratorStateWriter* writer) const;
protected:
+ friend class DatasetToGraphOp; // For access to graph related members.
+
class DatasetGraphDefBuilder : public GraphDefBuilderWrapper {
public:
DatasetGraphDefBuilder(GraphDefBuilder* b) : GraphDefBuilderWrapper(b) {}
@@ -541,8 +592,6 @@ class DatasetBase : public core::RefCounted {
virtual std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const = 0;
- friend class DatasetToGraphOp; // For access to graph related members.
-
private:
const string name_;
};
@@ -565,7 +614,7 @@ class DatasetBaseIterator : public IteratorBase {
~DatasetBaseIterator() override { params_.dataset->Unref(); }
// The sequence of iterators leading up to this iterator.
- const string& prefix() const { return params_.prefix; }
+ const string& prefix() const override { return params_.prefix; }
const DataTypeVector& output_dtypes() const override {
return params_.dataset->output_dtypes();
@@ -578,7 +627,23 @@ class DatasetBaseIterator : public IteratorBase {
Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) final {
tracing::ScopedActivity activity(params_.prefix);
- Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
+ Status s;
+ if (ctx->model()) {
+ std::shared_ptr<model::Node> node =
+ ctx->model()->LookupNode(params_.prefix);
+ if (node->output()) {
+ node->output()->stop_work();
+ }
+ node->start_work();
+ s = GetNextInternal(ctx, out_tensors, end_of_sequence);
+ node->stop_work();
+ node->add_element();
+ if (node->output()) {
+ node->output()->start_work();
+ }
+ } else {
+ s = GetNextInternal(ctx, out_tensors, end_of_sequence);
+ }
if (TF_PREDICT_FALSE(errors::IsOutOfRange(s) && !*end_of_sequence)) {
s = errors::Internal(
"Iterator \"", params_.prefix,
@@ -605,6 +670,39 @@ class DatasetBaseIterator : public IteratorBase {
return strings::StrCat(params_.prefix, ":", name);
}
+ // When performance modeling is enabled, this method sets a metadata entry for
+ // the model node corresponding to this iterator.
+ void SetMetadata(IteratorContext* ctx, const string& key, int64 value) {
+ if (ctx->model()) {
+ std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix());
+ if (node) {
+ node->set_metadata(key, value);
+ }
+ }
+ }
+
+ // When performance modeling is enabled, this method records the fact that
+ // a thread of this iterator has started work.
+ void StartWork(IteratorContext* ctx) {
+ if (ctx->model()) {
+ std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix());
+ if (node) {
+ node->start_work();
+ }
+ }
+ }
+
+ // When performance modeling is enabled, this method records the fact that
+ // a thread of this iterator has stopped work.
+ void StopWork(IteratorContext* ctx) {
+ if (ctx->model()) {
+ std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix());
+ if (node) {
+ node->stop_work();
+ }
+ }
+ }
+
private:
BaseParams params_;
};
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index 26f32677af..d979353d2f 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -1154,6 +1154,17 @@ Status FunctionLibraryDefinition::LookUp(
return default_registry_->LookUp(op, op_reg_data);
}
+string FunctionLibraryDefinition::UniqueFunctionName(StringPiece prefix) const {
+ tf_shared_lock l(mu_);
+ int index = 0;
+ string name = strings::StrCat(prefix, index);
+ while (function_defs_.find(name) != function_defs_.end()) {
+ ++index;
+ name = strings::StrCat(prefix, index);
+ }
+ return name;
+}
+
const FunctionDef* FunctionLibraryDefinition::GetAttrImpl(
const NodeDef& ndef) const {
if (ndef.op() != kGradientOp) {
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index 03296a7761..e01eb7503d 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -358,6 +358,10 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
const OpRegistrationData** op_reg_data) const override
LOCKS_EXCLUDED(mu_);
+ // Generates new function name with the specified prefix that is unique
+ // across this library.
+ string UniqueFunctionName(StringPiece prefix) const LOCKS_EXCLUDED(mu_);
+
// Ops created for function arguments bear the name given by `kArgOp`; those
// created for return values bear the name given by `kRetOp`.
static constexpr const char* const kArgOp = "_Arg";
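
A brief usage sketch for the new helper; the library instance and the "inlined_" prefix are assumptions for illustration:

    #include <string>
    #include "tensorflow/core/framework/function.h"

    // Picks a function name starting with "inlined_" that is not already
    // registered in the given library (e.g. "inlined_0", "inlined_1", ...).
    std::string PickUniqueName(const tensorflow::FunctionLibraryDefinition& lib) {
      return lib.UniqueFunctionName("inlined_");
    }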
diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc
index 46b169dddc..c5a4f661d2 100644
--- a/tensorflow/core/framework/function_testlib.cc
+++ b/tensorflow/core/framework/function_testlib.cc
@@ -110,6 +110,22 @@ FunctionDef XTimesTwo() {
});
}
+FunctionDef XAddX() {
+ return FDH::Define(
+ // Name
+ "XAddX",
+ // Args
+ {"x: T"},
+ // Return values
+ {"y: T"},
+ // Attr def
+ {"T: {float, double, int32, int64}"},
+ // Nodes
+ {
+ {{"y"}, "Add", {"x", "x"}, {{"T", "$T"}}},
+ });
+}
+
FunctionDef XTimesTwoInt32() {
const Tensor kTwo = test::AsScalar<int64>(2);
return FDH::Define(
diff --git a/tensorflow/core/framework/function_testlib.h b/tensorflow/core/framework/function_testlib.h
index 6d6476b936..ad61a76f16 100644
--- a/tensorflow/core/framework/function_testlib.h
+++ b/tensorflow/core/framework/function_testlib.h
@@ -63,6 +63,9 @@ GraphDef GDef(gtl::ArraySlice<NodeDef> nodes,
// x:T -> x * 2.
FunctionDef XTimesTwo();
+// x:T -> x + x.
+FunctionDef XAddX();
+
// x:T -> x * 2, where x is int32.
FunctionDef XTimesTwoInt32();
diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc
new file mode 100644
index 0000000000..250b006641
--- /dev/null
+++ b/tensorflow/core/framework/model.cc
@@ -0,0 +1,396 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/model.h"
+
+namespace tensorflow {
+namespace data {
+namespace model {
+
+// TODO(jsimsa): Use `Node` subclassing instead of types and node statements.
+void Node::CollectKnobs(std::vector<Node::Knob>* knobs) {
+ mutex_lock l(mu_);
+ switch (type_) {
+ case Type::PARALLEL_INTERLEAVE_V2: {
+ for (auto input : inputs_) {
+ input->CollectKnobs(knobs);
+ }
+ int64 processing_time = static_cast<int64>(
+ static_cast<double>(ProcessingTimeLocked() -
+ inputs_.front()->ProcessingTime()) /
+ static_cast<double>(inputs_.size() - 1));
+ knobs->emplace_back(
+ Node::Knob{this, processing_time, metadata_["parallelism"]});
+ return;
+ }
+ case Type::MAP_AND_BATCH:
+ case Type::PARALLEL_MAP: {
+ for (auto input : inputs_) {
+ input->CollectKnobs(knobs);
+ }
+ knobs->emplace_back(
+ Node::Knob{this, NanosPerElementLocked(), metadata_["parallelism"]});
+ return;
+ }
+ case Type::BATCH:
+ case Type::CACHE:
+ case Type::CONCATENATE:
+ case Type::FILTER:
+ case Type::FLAT_MAP:
+ case Type::INTERLEAVE:
+ case Type::MAP:
+ case Type::PADDED_BATCH:
+ case Type::PARALLEL_INTERLEAVE:
+ case Type::PREFETCH:
+ case Type::REPEAT:
+ case Type::SHUFFLE:
+ case Type::SKIP:
+ case Type::TAKE:
+ case Type::ZIP: {
+ for (auto input : inputs_) {
+ input->CollectKnobs(knobs);
+ }
+ return;
+ }
+ default:
+ return;
+ }
+}
+
+int64 Node::ProcessingTimeLocked() {
+ switch (type_) {
+ case Type::BATCH:
+ case Type::MAP_AND_BATCH:
+ case Type::PADDED_BATCH: {
+ int64 batch_size = metadata_["batch_size"];
+ return NanosPerElementLocked() + batch_size * ProcessingTimeForInputs();
+ }
+ case Type::FILTER: {
+ std::shared_ptr<Node> input = inputs_.front();
+ double ratio = static_cast<double>(input->num_elements()) /
+ static_cast<double>(num_elements_);
+ return NanosPerElementLocked() +
+ static_cast<int64>(ratio *
+ static_cast<double>(ProcessingTimeForInputs()));
+ }
+ case Type::FLAT_MAP:
+ case Type::INTERLEAVE:
+ case Type::PARALLEL_INTERLEAVE:
+ case Type::PARALLEL_INTERLEAVE_V2: {
+ // TODO(jsimsa): model the first input
+ // TODO(jsimsa): use processing time history as a prior for future inputs
+ if (inputs_.size() <= 1) {
+ return NanosPerElementLocked();
+ }
+ int64 processing_time =
+ ProcessingTimeForInputs() - inputs_.front()->ProcessingTime();
+ return NanosPerElementLocked() +
+ static_cast<double>(processing_time) /
+ static_cast<double>(inputs_.size() - 1);
+ }
+ case Type::CACHE:
+ case Type::CONCATENATE:
+ case Type::MAP:
+ case Type::PARALLEL_MAP:
+ case Type::PREFETCH:
+ // TODO(jsimsa): use processing time history as a prior for future inputs
+ case Type::REPEAT:
+ case Type::SHUFFLE:
+ case Type::SKIP:
+ case Type::TAKE:
+ case Type::ZIP: {
+ return NanosPerElementLocked() + ProcessingTimeForInputs();
+ }
+ default:
+ return NanosPerElementLocked();
+ }
+}
+
+int64 Node::OutputTimeLocked(std::vector<int64>* input_times) {
+ switch (type_) {
+ case Type::BATCH:
+ case Type::PADDED_BATCH: {
+ double batch_size = metadata_["batch_size"];
+ int64 old_value = (*input_times)[input_times->size() - 1];
+ (*input_times)[input_times->size() - 1] = static_cast<int64>(
+ static_cast<double>(old_value + NanosPerElementLocked()) /
+ batch_size);
+ auto cleanup = gtl::MakeCleanup([input_times, old_value]() {
+ (*input_times)[input_times->size() - 1] = old_value;
+ });
+ return NanosPerElementLocked() +
+ batch_size * OutputTimeForInputs(input_times);
+ }
+ case Type::FILTER: {
+ std::shared_ptr<Node> input = inputs_.front();
+ int64 old_value = (*input_times)[input_times->size() - 1];
+ double ratio = static_cast<double>(input->num_elements()) /
+ static_cast<double>(num_elements_);
+ (*input_times)[input_times->size() - 1] = static_cast<int64>(
+ static_cast<double>(old_value + NanosPerElementLocked()) / ratio);
+ auto cleanup = gtl::MakeCleanup([input_times, old_value]() {
+ (*input_times)[input_times->size() - 1] = old_value;
+ });
+ return NanosPerElementLocked() +
+ static_cast<int64>(
+ static_cast<double>(OutputTimeForInputs(input_times)) * ratio);
+ }
+ case Type::FLAT_MAP:
+ case Type::INTERLEAVE: {
+ // TODO(jsimsa): model the first input
+ // TODO(jsimsa): use cycle length metadata instead of `inputs_.size() - 1`
+ if (inputs_.size() <= 1) {
+ return NanosPerElementLocked();
+ }
+ int64 delta =
+ static_cast<int64>(static_cast<double>(NanosPerElementLocked()) *
+ static_cast<double>(inputs_.size() - 1));
+ (*input_times)[input_times->size() - 1] += delta;
+ auto cleanup = gtl::MakeCleanup([input_times, delta]() {
+ (*input_times)[input_times->size() - 1] -= delta;
+ });
+ int64 output_time = OutputTimeForInputs(input_times) -
+ inputs_.front()->OutputTime(input_times);
+ return NanosPerElementLocked() +
+ static_cast<double>(output_time) /
+ static_cast<double>(inputs_.size() - 1);
+ }
+ case Type::MAP_AND_BATCH: {
+ double batch_size = metadata_["batch_size"];
+ double parallelism = metadata_["parallelism"];
+ int64 delta =
+ static_cast<int64>(static_cast<double>(NanosPerElementLocked()) /
+ (batch_size * parallelism));
+ input_times->push_back(delta);
+ auto cleanup =
+ gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
+ int64 output_time = static_cast<int64>(
+ static_cast<double>(NanosPerElementLocked()) / parallelism +
+ batch_size * OutputTimeForInputs(input_times));
+ return std::max(0LL,
+ output_time - input_times->at(input_times->size() - 2));
+ }
+ case Type::PARALLEL_INTERLEAVE:
+ case Type::PARALLEL_INTERLEAVE_V2: {
+ // TODO(jsimsa): model the first input
+ if (inputs_.size() <= 1) {
+ return NanosPerElementLocked();
+ }
+ int64 delta =
+ static_cast<int64>(static_cast<double>(NanosPerElementLocked()) *
+ static_cast<double>(inputs_.size() - 1));
+ input_times->push_back(delta);
+ auto cleanup =
+ gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
+ int64 inputs_output_time = OutputTimeForInputs(input_times) -
+ inputs_.front()->OutputTime(input_times);
+ double parallelism = std::min(port::NumSchedulableCPUs(),
+ static_cast<int>(metadata_["parallelism"]));
+ int64 output_time =
+ NanosPerElementLocked() + ((static_cast<double>(inputs_output_time) /
+ static_cast<double>(inputs_.size() - 1)) /
+ parallelism);
+ return std::max(0LL,
+ output_time - input_times->at(input_times->size() - 2));
+ }
+ case Type::PARALLEL_MAP: {
+ double parallelism = std::min(port::NumSchedulableCPUs(),
+ static_cast<int>(metadata_["parallelism"]));
+ int64 delta = static_cast<int64>(
+ static_cast<double>(NanosPerElementLocked()) / parallelism);
+ input_times->push_back(delta);
+ auto cleanup =
+ gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
+ int64 output_time =
+ static_cast<double>(NanosPerElementLocked()) / parallelism +
+ OutputTimeForInputs(input_times);
+ return std::max(0LL,
+ output_time - input_times->at(input_times->size() - 2));
+ }
+ case Type::PREFETCH: {
+ int64 delta = NanosPerElementLocked();
+ input_times->push_back(delta);
+ auto cleanup =
+ gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
+ return std::max(0LL, NanosPerElementLocked() +
+ OutputTimeForInputs(input_times) -
+ input_times->at(input_times->size() - 2));
+ }
+ case Type::CACHE:
+ case Type::CONCATENATE:
+ case Type::MAP:
+ case Type::REPEAT:
+ case Type::SHUFFLE:
+ case Type::SKIP:
+ case Type::TAKE:
+ case Type::ZIP: {
+ int64 delta = NanosPerElementLocked();
+ (*input_times)[input_times->size() - 1] += delta;
+ auto cleanup = gtl::MakeCleanup([input_times, delta]() {
+ (*input_times)[input_times->size() - 1] -= delta;
+ });
+ return NanosPerElementLocked() + OutputTimeForInputs(input_times);
+ }
+ default:
+ return NanosPerElementLocked();
+ }
+}
+
+Model::Model(const proto::Model& model_proto) {
+ id_counter_ = model_proto.id_counter();
+ std::map<int64, std::shared_ptr<Node>> lookup_table;
+ for (auto node_proto : model_proto.node()) {
+ std::shared_ptr<Node> node(new Node(node_proto));
+ lookup_table[node_proto.id()] = node;
+ }
+ for (auto node_proto : model_proto.node()) {
+ std::shared_ptr<Node> node = lookup_table[node_proto.id()];
+ for (int64 id : node_proto.input()) {
+ node->add_input(lookup_table[id]);
+ }
+ node->set_output(lookup_table[node_proto.output()]);
+ }
+ output_ = lookup_table[model_proto.output()];
+}
+
+std::shared_ptr<Node> Model::AddNode(const string& name,
+ const string& output_name) {
+ mutex_lock l(mu_);
+ std::shared_ptr<Node> output;
+ auto it = lookup_table_.find(output_name);
+ if (it != lookup_table_.end()) {
+ output = it->second;
+ }
+ std::shared_ptr<Node> node(new Node(id_counter_++, output));
+ if (!output_) {
+ output_ = node;
+ }
+ if (output) {
+ output->add_input(node);
+ }
+ lookup_table_.insert(std::make_pair(name, node));
+ return node;
+}
+
+std::shared_ptr<Node> Model::LookupNode(const string& name) {
+ tf_shared_lock l(mu_);
+ std::shared_ptr<Node> result;
+ auto it = lookup_table_.find(name);
+ if (it != lookup_table_.end()) {
+ result = it->second;
+ }
+ return result;
+}
+
+void Model::Optimize() {
+ mutex_lock l(mu_);
+ int64 processing_time = ProcessingTime();
+ int64 num_cpus = port::NumSchedulableCPUs();
+ std::vector<Node::Knob> knobs = CollectKnobs();
+ // The optimization algorithm starts by setting all parallelism knobs to 1. It
+ // then repeatedly identifies the knob that, when turned up by 1, decreases
+ // the output time the most. This process is repeated until all knobs reach
+ // the number of schedulable CPUs or the projected output time is less than or
+ // equal to the processing time needed to produce an element divided by the
+ // number of schedulable CPUs.
+ for (auto& knob : knobs) {
+ LOG(INFO) << knob.node->name() << " " << knob.processing_time;
+ knob.value = 1;
+ knob.node->set_metadata("parallelism", knob.value);
+ }
+ while (true) {
+ int64 output_time = OutputTime();
+ bool all_knobs = true;
+ for (auto knob : knobs) {
+ if (knob.value < num_cpus) {
+ all_knobs = false;
+ break;
+ }
+ }
+ if (output_time < processing_time / num_cpus || all_knobs) {
+ break;
+ }
+ int64 best_delta = -1;
+ int best_knob = -1;
+ for (int i = 0; i < knobs.size(); ++i) {
+ if (knobs[i].value == num_cpus) {
+ continue;
+ }
+ knobs[i].node->set_metadata("parallelism", knobs[i].value + 1);
+ int64 delta = output_time - OutputTime();
+ if (delta > best_delta) {
+ best_delta = delta;
+ best_knob = i;
+ }
+ knobs[i].node->set_metadata("parallelism", knobs[i].value);
+ }
+ knobs[best_knob].value++;
+ knobs[best_knob].node->set_metadata("parallelism", knobs[best_knob].value);
+ }
+ for (auto knob : knobs) {
+ LOG(INFO) << knob.node->name() << " " << knob.value;
+ }
+ LOG(INFO) << "output time: " << OutputTime();
+ LOG(INFO) << "processing time: " << ProcessingTime();
+}
+
+void Model::OutputToFile() {
+ proto::Model model_proto;
+ ToProto(&model_proto);
+ string filename;
+ Env::Default()->LocalTempFilename(&filename);
+ TF_CHECK_OK(WriteStringToFile(Env::Default(), filename,
+ model_proto.SerializeAsString()));
+ LOG(INFO) << filename;
+}
+
+void Model::RemoveNode(const string& prefix) {
+ mutex_lock l(mu_);
+ lookup_table_.erase(prefix);
+}
+
+void Model::ToProto(proto::Model* model_proto) {
+ mutex_lock l(mu_);
+ model_proto->set_id_counter(id_counter_);
+ model_proto->set_output(output_->id());
+ AddNodeToProto(output_, model_proto);
+}
+
+// static
+void Model::AddNodeToProto(const std::shared_ptr<Node>& node,
+ proto::Model* model_proto) {
+ proto::Node* node_proto = model_proto->add_node();
+ node->ToProto(node_proto);
+ for (const std::shared_ptr<Node>& input : node->inputs()) {
+ AddNodeToProto(input, model_proto);
+ }
+}
+
+std::vector<Node::Knob> Model::CollectKnobs() {
+ std::vector<Node::Knob> knobs;
+ output_->CollectKnobs(&knobs);
+ return knobs;
+}
+
+int64 Model::OutputTime() {
+ std::vector<int64> input_times(1, 0);
+ return output_->OutputTime(&input_times);
+}
+
+int64 Model::ProcessingTime() { return output_->ProcessingTime(); }
+
+} // namespace model
+} // namespace data
+} // namespace tensorflow
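
Putting the new modeling pieces together, here is a hedged sketch of how a pipeline stage could be registered and instrumented through the Model/Node interface declared in model.h below. The node names are hypothetical, and in the real code this wiring is performed by DatasetBase::MakeIterator and DatasetBaseIterator rather than by hand:

    #include <memory>
    #include "tensorflow/core/framework/model.h"

    void ModelUsageSketch() {
      using tensorflow::data::model::Model;
      using tensorflow::data::model::Node;

      Model model;
      // Register a root node and one of its inputs (hypothetical prefixes).
      std::shared_ptr<Node> map = model.AddNode("Iterator::Map", "Iterator");
      map->set_name("Map");
      map->set_metadata("parallelism", 4);
      std::shared_ptr<Node> range =
          model.AddNode("Iterator::Map::Range", "Iterator::Map");
      range->set_name("Range");

      // Account for the time spent producing one element, as GetNext does.
      map->start_work();
      // ... GetNextInternal would run here ...
      map->stop_work();
      map->add_element();

      // Let the model tune the parallelism knobs it has collected.
      model.Optimize();
    }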
diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h
new file mode 100644
index 0000000000..98172909bf
--- /dev/null
+++ b/tensorflow/core/framework/model.h
@@ -0,0 +1,396 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_FRAMEWORK_MODEL_H_
+#define TENSORFLOW_CORE_FRAMEWORK_MODEL_H_
+
+#include <list>
+#include <memory>
+#include <string>
+#include <thread> // (b/114492873): move this include into core/platform
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/framework/model.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace tensorflow {
+namespace data {
+namespace model {
+
+class Model;
+class Node;
+
+// Abstract representation of a TensorFlow input pipeline node. It collects
+// information about inputs to this node, the processing time spent executing
+// the node logic, the number of elements produced by the node, and various
+// other information (e.g. batch size or execution parallelism).
+//
+// Developers of tf.data transformations are not expected to interact with this
+// class directly. Boilerplate code for creating the abstract representation of
+// the input pipeline and collecting common information has been added to the
+// implementation of `DatasetBase` and `DatasetBaseIterator` respectively.
+//
+// In addition, `DatasetBaseIterator` provides wrappers that can be used for
+// transformation-specific information collection. The `SetMetadata` wrapper can
+// be used to pass arbitrary metadata to the modeling framework, while the
+// `StartWork` and `StopWork` wrappers should be used to correctly account for
+// processing time of multi-threaded transformations that yield the CPU; such
+// transformations should invoke `StartWork()` when a transformation thread
+// starts executing (e.g. when created or woken up) and `StopWork()` when a
+// transformation thread stops executing (e.g. when returning or waiting).
+//
+// TODO(jsimsa): Create an API to capture the abstract semantics of each
+// tf.data transformation and replace switch-case blocks with inheritance.
+class Node {
+ public:
+ Node(int64 id, std::shared_ptr<Node> output) : id_(id), output_(output) {}
+
+ explicit Node(const proto::Node& node_proto) : id_(node_proto.id()) {
+ name_ = node_proto.name();
+ type_ = TypeFromName(node_proto.name());
+ processing_time_ = node_proto.processing_time();
+ num_elements_ = node_proto.num_elements();
+ metadata_.insert(node_proto.metadata().begin(),
+ node_proto.metadata().end());
+ }
+
+ // Records that the node produced an element.
+ void add_element() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ num_elements_++;
+ }
+
+ // Adds an input.
+ void add_input(std::shared_ptr<Node> node) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ inputs_.push_back(node);
+ }
+
+ // Increments the aggregate processing time by the given delta.
+ void add_processing_time(int64 delta) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ processing_time_ += delta;
+ }
+
+ // Returns the unique node ID.
+ int64 id() LOCKS_EXCLUDED(mu_) { return id_; }
+
+ // Returns the node inputs.
+ std::list<std::shared_ptr<Node>> inputs() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return inputs_;
+ }
+
+ // Returns the node name.
+ const string& name() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return name_;
+ }
+
+ // Returns the number of elements produced by the node.
+ int64 num_elements() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return num_elements_;
+ }
+
+ // Returns the node output.
+ std::shared_ptr<Node> output() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return output_;
+ }
+
+ // Removes an input.
+ void remove_input(std::shared_ptr<Node> input) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ inputs_.remove(input);
+ }
+
+ // Adds the given key-value pair to the node metadata.
+ void set_metadata(const string& key, int64 value) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ metadata_[key] = value;
+ }
+
+ // Sets the node name.
+ void set_name(const string& name) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ name_ = name;
+ type_ = TypeFromName(name);
+ }
+
+ // Set the node output.
+ void set_output(std::shared_ptr<Node> output) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ output_ = output;
+ }
+
+ // Records that a node thread has started work.
+ void start_work() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ work_start_[std::this_thread::get_id()] = Env::Default()->NowNanos();
+ }
+
+ // Records that a node thread has stopped work.
+ void stop_work() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ auto iter = work_start_.find(std::this_thread::get_id());
+ CHECK(work_start_.end() != iter)
+ << "Encountered a stop event that was not preceded by a start event.";
+ processing_time_ += Env::Default()->NowNanos() - iter->second;
+ work_start_.erase(iter);
+ }
+
+ private:
+ // Represents a performance knob.
+ struct Knob {
+ Node* node;
+ int64 processing_time;
+ int64 value;
+ };
+
+ enum class Type {
+ BATCH = 0,
+ CACHE,
+ CONCATENATE,
+ FILTER,
+ FLAT_MAP,
+ INTERLEAVE,
+ MAP,
+ MAP_AND_BATCH,
+ PADDED_BATCH,
+ PARALLEL_INTERLEAVE,
+ PARALLEL_INTERLEAVE_V2,
+ PARALLEL_MAP,
+ PREFETCH,
+ REPEAT,
+ SHUFFLE,
+ SKIP,
+ TAKE,
+ ZIP,
+ UNKNOWN,
+ };
+
+ // Collects performance knobs in the subtree rooted in this node.
+ void CollectKnobs(std::vector<Node::Knob>* knobs) LOCKS_EXCLUDED(mu_);
+
+ // Returns the per-element processing time spent in this node.
+ int64 NanosPerElement() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ return NanosPerElementLocked();
+ }
+
+ int64 NanosPerElementLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (num_elements_ == 0) {
+ return 0;
+ }
+ return (int64)((double)processing_time_ / (double)num_elements_);
+ }
+
+ // Returns the per-element output time for this node.
+ int64 OutputTime(std::vector<int64>* input_times) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ return OutputTimeLocked(input_times);
+ }
+
+ int64 OutputTimeLocked(std::vector<int64>* input_times)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ int64 OutputTimeForInputs(std::vector<int64>* input_times)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ int64 sum = 0;
+ for (auto input : inputs_) {
+ sum += input->OutputTime(input_times);
+ }
+ return sum;
+ }
+
+ // Returns the per-element processing time spent in the subtree rooted in this
+ // node.
+ int64 ProcessingTime() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ return ProcessingTimeLocked();
+ }
+
+ int64 ProcessingTimeLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Returns the per-element processing time spent in the inputs of this node.
+ int64 ProcessingTimeForInputs() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ int64 sum = 0;
+ for (auto input : inputs_) {
+ sum += input->ProcessingTimeLocked();
+ }
+ return sum;
+ }
+
+ // Serializes the node state into the given proto.
+ void ToProto(proto::Node* node_proto) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ node_proto->set_id(id_);
+ node_proto->set_name(name_);
+ node_proto->set_num_elements(num_elements_);
+ node_proto->set_processing_time(processing_time_);
+ for (const std::shared_ptr<Node>& input : inputs_) {
+ node_proto->add_input(input->id());
+ }
+ if (output_) {
+ node_proto->set_output(output_->id());
+ }
+ node_proto->mutable_metadata()->insert(metadata_.begin(), metadata_.end());
+ }
+
+ Type TypeFromName(const string& name) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (name_ == "Batch") {
+ return Type::BATCH;
+ }
+ if (str_util::EndsWith(name_, "Cache")) {
+ return Type::CACHE;
+ }
+ if (name_ == "Concatenate") {
+ return Type::CONCATENATE;
+ }
+ if (name_ == "Filter") {
+ return Type::FILTER;
+ }
+ if (name_ == "FlatMap") {
+ return Type::FLAT_MAP;
+ }
+ if (name_ == "Interleave") {
+ return Type::INTERLEAVE;
+ }
+ if (name_ == "Map") {
+ return Type::MAP;
+ }
+ if (name_ == "MapAndBatch") {
+ return Type::MAP_AND_BATCH;
+ }
+ if (name_ == "PaddedBatch") {
+ return Type::PADDED_BATCH;
+ }
+ if (name_ == "ParallelInterleave") {
+ return Type::PARALLEL_INTERLEAVE;
+ }
+ if (name_ == "ParallelInterleaveV2") {
+ return Type::PARALLEL_INTERLEAVE_V2;
+ }
+ if (name_ == "ParallelMap") {
+ return Type::PARALLEL_MAP;
+ }
+ if (name_ == "Prefetch") {
+ return Type::PREFETCH;
+ }
+ if (str_util::EndsWith(name_, "Repeat")) {
+ return Type::REPEAT;
+ }
+ if (name_ == "Shuffle") {
+ return Type::SHUFFLE;
+ }
+ if (str_util::EndsWith(name_, "Skip")) {
+ return Type::SKIP;
+ }
+ if (str_util::EndsWith(name_, "Take")) {
+ return Type::TAKE;
+ }
+ if (name_ == "Zip") {
+ return Type::ZIP;
+ }
+ return Type::UNKNOWN;
+ }
+
+ mutex mu_;
+ const int64 id_;
+ Type type_ GUARDED_BY(mu_);
+ string name_ GUARDED_BY(mu_);
+ int64 processing_time_ GUARDED_BY(mu_) = 0;
+ int64 num_elements_ GUARDED_BY(mu_) = 0;
+ std::map<std::thread::id, int64> work_start_ GUARDED_BY(mu_);
+ std::map<string, int64> metadata_ GUARDED_BY(mu_);
+ std::list<std::shared_ptr<Node>> inputs_ GUARDED_BY(mu_);
+ std::shared_ptr<Node> output_ GUARDED_BY(mu_);
+
+ friend class Model;
+};
+
+// Abstract representation of a TensorFlow input pipeline that is used for
+// collecting runtime information and optimizing performance. It records
+// runtime statistics about the execution of the input pipeline, builds a
+// performance model from them, and uses that model to identify optimal values
+// of performance knobs.
+//
+// Developers of tf.data transformations are not expected to interact with this
+// class directly. Boilerplate code for creating the abstract representation of
+// the input pipeline and collecting runtime information has been added to the
+// implementations of `DatasetBase` and `DatasetBaseIterator`, respectively.
+//
+// TODO(jsimsa): Add a mechanism for feeding the result of the optimization
+// into the input pipeline.
+class Model {
+ public:
+ Model() = default;
+ explicit Model(const proto::Model& model_proto);
+
+ ~Model() {}
+
+ // Returns the model output node.
+ std::shared_ptr<Node> output() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return output_;
+ }
+
+  // Adds a node with the given name whose output is the node identified by
+  // `output_name`.
+ std::shared_ptr<Node> AddNode(const string& name, const string& output_name)
+ LOCKS_EXCLUDED(mu_);
+
+ // Looks up the node using the given name.
+ std::shared_ptr<Node> LookupNode(const string& name) LOCKS_EXCLUDED(mu_);
+
+ // Runs optimization.
+ void Optimize() LOCKS_EXCLUDED(mu_);
+
+ // Outputs the state of a model to a file.
+ //
+ // TODO(jsimsa): Remove this method once the optimization loop is closed.
+ void OutputToFile() LOCKS_EXCLUDED(mu_);
+
+ // Removes the node identified by the given name.
+ void RemoveNode(const string& prefix) LOCKS_EXCLUDED(mu_);
+
+ // Serializes the model state to the given proto.
+ void ToProto(proto::Model* model_proto) LOCKS_EXCLUDED(mu_);
+
+ private:
+ static void AddNodeToProto(const std::shared_ptr<Node>& node,
+ proto::Model* model_proto);
+
+ std::vector<Node::Knob> CollectKnobs() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ int64 OutputTime() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ int64 ProcessingTime() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ mutex mu_;
+ int64 id_counter_ GUARDED_BY(mu_) = 1;
+ std::shared_ptr<Node> output_ GUARDED_BY(mu_);
+ std::map<string, std::shared_ptr<Node>> lookup_table_ GUARDED_BY(mu_);
+};
+
+} // namespace model
+} // namespace data
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_FRAMEWORK_MODEL_H_
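[Editor's note, not part of the patch] As a usage illustration: each `Node` accumulates `processing_time_` across its worker threads via `start_work()`/`stop_work()` and divides it by `num_elements_` to obtain the per-element cost, while `Model` wires nodes together through `AddNode(name, output_name)` and exposes `Optimize()`, `OutputToFile()`, and `ToProto()`. The sketch below is hypothetical; the node names, the metadata key, and the assumption that a node named "Prefetch" was registered earlier are illustrative only.

    #include <memory>

    #include "tensorflow/core/framework/model.h"

    namespace model = tensorflow::data::model;

    // Sketch: register one transformation node, account for one unit of work,
    // and run the (currently file-based) optimization step.
    void RecordOneUnitOfWork(model::Model* m) {
      // Hypothetical: assumes a node named "Prefetch" was added earlier and
      // acts as this node's output.
      std::shared_ptr<model::Node> node = m->AddNode("ParallelMap", "Prefetch");
      node->set_metadata("parallelism", 4);  // illustrative knob value

      node->start_work();
      // ... the user-defined map function would run here ...
      node->stop_work();  // accumulates processing_time_ for this thread

      m->Optimize();      // derive knob values from the collected statistics
      m->OutputToFile();  // per the TODO above, to be replaced by a closed loop
    }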
diff --git a/tensorflow/core/framework/model.proto b/tensorflow/core/framework/model.proto
new file mode 100644
index 0000000000..26000007af
--- /dev/null
+++ b/tensorflow/core/framework/model.proto
@@ -0,0 +1,30 @@
+syntax = "proto3";
+
+package tensorflow.data.model.proto;
+option cc_enable_arenas = true;
+
+message Model {
+ // Counter used for generating new node IDs.
+ int64 id_counter = 1;
+ // Nodes of this model.
+ repeated Node node = 2;
+ // The ID of the output node.
+ int64 output = 3;
+};
+
+message Node {
+ // The node ID.
+ int64 id = 1;
+ // The node name.
+ string name = 2;
+ // Input node IDs.
+ repeated int64 input = 3;
+ // Output node ID.
+ int64 output = 4;
+ // Number of elements produced by the node.
+ int64 num_elements = 5;
+  // The aggregate CPU time spent by the threads executing this node.
+ int64 processing_time = 6;
+ // Key-value store for node metadata (e.g. batch size or parallelism).
+  map<string, int64> metadata = 7;
+};
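[Editor's note, not part of the patch] A hypothetical round trip through this proto, assuming the generated header lands at `tensorflow/core/framework/model.pb.h` (the BUILD wiring is not shown in this section):

    #include "tensorflow/core/framework/model.h"
    #include "tensorflow/core/framework/model.pb.h"  // assumed generated header
    #include "tensorflow/core/platform/logging.h"

    // Sketch: serialize the collected statistics and print them per node.
    void DumpModel(tensorflow::data::model::Model* m) {
      tensorflow::data::model::proto::Model model_proto;
      m->ToProto(&model_proto);
      for (const auto& node : model_proto.node()) {
        LOG(INFO) << node.name() << ": " << node.num_elements()
                  << " elements, " << node.processing_time()
                  << " ns of processing time";
      }
    }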
diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc
index 0a19861efd..ebdaaec153 100644
--- a/tensorflow/core/framework/resource_mgr.cc
+++ b/tensorflow/core/framework/resource_mgr.cc
@@ -271,7 +271,7 @@ string ContainerInfo::DebugString() const {
"]");
}
-ResourceHandle HandleFromInput(OpKernelContext* ctx, int input) {
+const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input) {
return ctx->input(input).flat<ResourceHandle>()(0);
}
diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h
index f8a587c9b5..d58deaa3fc 100644
--- a/tensorflow/core/framework/resource_mgr.h
+++ b/tensorflow/core/framework/resource_mgr.h
@@ -79,7 +79,7 @@ class ResourceBase : public core::RefCounted {
virtual string DebugString() = 0;
// Returns memory used by this resource.
- virtual int64 MemoryUsed() const { return 0; };
+ virtual int64 MemoryUsed() const { return 0; }
};
// Container used for per-step resources.
@@ -234,7 +234,7 @@ ResourceHandle MakePerStepResourceHandle(OpKernelContext* ctx,
const string& name);
// Returns a resource handle from a numbered op input.
-ResourceHandle HandleFromInput(OpKernelContext* ctx, int input);
+const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input);
Status HandleFromInput(OpKernelContext* ctx, StringPiece input,
ResourceHandle* handle);
@@ -348,6 +348,8 @@ class ResourceHandleOp : public OpKernel {
void Compute(OpKernelContext* ctx) override;
+ bool IsExpensive() override { return false; }
+
private:
string container_;
string name_;
diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h
index 1b19ab5da3..696fd277cd 100644
--- a/tensorflow/core/framework/tensor.h
+++ b/tensorflow/core/framework/tensor.h
@@ -37,11 +37,12 @@ namespace tensorflow {
class AllocationDescription;
class Allocator;
class OpKernelContext;
+class Tensor;
class TensorBuffer;
class TensorCApi;
class TensorDescription;
class TensorProto;
-class VariantTensorData;
+
namespace batch_util {
Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index);
Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64 index);
diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc
index 84a373c196..9a78cdc91e 100644
--- a/tensorflow/core/framework/tensor_test.cc
+++ b/tensorflow/core/framework/tensor_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/math/math_util.h"
diff --git a/tensorflow/core/framework/tensor_util.h b/tensorflow/core/framework/tensor_util.h
index 4bda8f9eb8..a7cf600bab 100644
--- a/tensorflow/core/framework/tensor_util.h
+++ b/tensorflow/core/framework/tensor_util.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include <vector>
diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h
index 15b1add2c1..2e96b05787 100644
--- a/tensorflow/core/framework/types.h
+++ b/tensorflow/core/framework/types.h
@@ -30,7 +30,6 @@ limitations under the License.
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/types.pb.h"
-#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -39,6 +38,8 @@ limitations under the License.
namespace tensorflow {
+class Variant;
+
// MemoryType is used to describe whether input or output Tensors of
// an OpKernel should reside in "Host memory" (e.g., CPU memory) or
// "Device" Memory (CPU memory for CPU devices, GPU memory for GPU
diff --git a/tensorflow/core/framework/variant.cc b/tensorflow/core/framework/variant.cc
index 5a507804b0..d43e3c72ec 100644
--- a/tensorflow/core/framework/variant.cc
+++ b/tensorflow/core/framework/variant.cc
@@ -23,11 +23,11 @@ limitations under the License.
namespace tensorflow {
-bool Variant::TryDecode(Variant* out) const {
- const VariantTensorDataProto* p = get<VariantTensorDataProto>();
- if (p == nullptr) return false;
- VariantTensorData data(*p);
- return out->Decode(data);
+bool Variant::Decode(VariantTensorData data) {
+ if (!is_empty()) {
+ return value_->Decode(std::move(data));
+ }
+ return true;
}
template <>
@@ -54,13 +54,12 @@ string TypeNameVariant(const VariantTensorDataProto& value) {
template <>
void EncodeVariant(const VariantTensorDataProto& value,
VariantTensorData* data) {
- data->FromProto(value);
+ data->FromConstProto(value);
}
template <>
-bool DecodeVariant(const VariantTensorData& data,
- VariantTensorDataProto* value) {
- data.ToProto(value);
+bool DecodeVariant(VariantTensorData* data, VariantTensorDataProto* value) {
+ data->ToProto(value);
return true;
}
@@ -70,8 +69,8 @@ void EncodeVariant(const VariantTensorDataProto& value, string* buf) {
}
template <>
-bool DecodeVariant(const string& buf, VariantTensorDataProto* value) {
- return value->ParseFromString(buf);
+bool DecodeVariant(string* buf, VariantTensorDataProto* value) {
+ return value->ParseFromString(*buf);
}
void EncodeVariantList(const Variant* variant_array, int64 n,
@@ -93,8 +92,10 @@ bool DecodeVariantList(std::unique_ptr<port::StringListDecoder> d,
if (variant_array[i].is_empty()) {
variant_array[i] = VariantTensorDataProto();
}
+ // TODO(ebrevdo): Replace with StringPiece? Any way to make this a
+ // zero-copy operation that keeps a reference to the data in d?
string str(d->Data(sizes[i]), sizes[i]);
- if (!variant_array[i].Decode(str)) return false;
+ if (!variant_array[i].Decode(std::move(str))) return false;
if (!DecodeUnaryVariant(&variant_array[i])) {
LOG(ERROR) << "Could not decode variant with type_name: \""
<< variant_array[i].TypeName()
diff --git a/tensorflow/core/framework/variant.h b/tensorflow/core/framework/variant.h
index 52732801a0..10eabbc85f 100644
--- a/tensorflow/core/framework/variant.h
+++ b/tensorflow/core/framework/variant.h
@@ -23,7 +23,6 @@ limitations under the License.
#include <unordered_map>
#include <utility>
-#include "tensorflow/core/framework/tensor.pb.h" // TODO(b/62899350): Remove
#include "tensorflow/core/framework/type_index.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/core/status.h"
@@ -38,17 +37,19 @@ string TypeNameVariant(const T& value);
template <typename T>
string DebugStringVariant(const T& value);
+// Allows for specializations of Variant Decoding. `data` may be modified in
+// the process of decoding to `value`.
template <typename T>
-void EncodeVariant(const T& value, VariantTensorData* data);
+bool DecodeVariant(VariantTensorData* data, T* value);
template <typename T>
-bool DecodeVariant(const VariantTensorData& data, T* value);
+bool DecodeVariant(string* buf, T* value);
template <typename T>
-void EncodeVariant(const T& value, string* buf);
+void EncodeVariant(const T& value, VariantTensorData* data);
template <typename T>
-bool DecodeVariant(const string& buf, T* value);
+void EncodeVariant(const T& value, string* buf);
// This is an implementation of a type-erased container that can store an
// object of any type. The implementation is very similar to std::any, but has
@@ -67,7 +68,7 @@ bool DecodeVariant(const string& buf, T* value);
//
// string TypeName() const;
// void Encode(VariantTensorData* data) const;
-// void Decode(const VariantTensorData& data);
+// void Decode(VariantTensorData data);
//
// Simple POD types can elide the Encode/Decode functions, they are provided by
// helper methods.
@@ -121,7 +122,7 @@ bool DecodeVariant(const string& buf, T* value);
// x.Encode(&serialized_f);
//
// Variant y = Foo(); // default constructed Foo.
-// y.Decode(&serialized_f);
+// y.Decode(std::move(serialized_f));
// EXPECT_EQ(*x.get<Foo>(), *y.get<Foo>());
//
//
@@ -145,10 +146,6 @@ bool DecodeVariant(const string& buf, T* value);
// EXPECT_EQ(x.TypeName(), y_type_unknown.TypeName()); // Looks like Foo.
// EXPECT_EQ(MakeTypeIndex<VariantTensorDataProto>(),
// y_type_unknown.TypeId());
-// // Decode and get y_type_unknown; compare to value in x.
-// Foo f_decoded;
-// EXPECT_TRUE(x.MaybeDecodeAndCopy(&f_decoded));
-// EXPECT_EQ(f_decoded, f);
//
class Variant {
public:
@@ -241,12 +238,7 @@ class Variant {
}
// Deserialize `data` and update the stored object.
- bool Decode(const VariantTensorData& data) {
- if (!is_empty()) {
- return value_->Decode(data);
- }
- return true;
- }
+ bool Decode(VariantTensorData data);
// Helper methods to directly serialize/deserialize from strings.
void Encode(string* buf) const {
@@ -254,31 +246,13 @@ class Variant {
value_->Encode(buf);
}
}
- bool Decode(const string& buf) {
+ bool Decode(string buf) {
if (!is_empty()) {
- return value_->Decode(buf);
+ return value_->Decode(std::move(buf));
}
return true;
}
- template <typename T>
- bool MaybeDecodeAndCopy(T* out) const {
- const T* ret = get<T>();
- if (ret != nullptr) {
- *out = std::move(*ret);
- return true;
- };
- Variant decoded = T();
- if (!TryDecode(&decoded)) return false;
- T* decoded_ret = decoded.get<T>();
- CHECK_NOTNULL(decoded_ret);
- *out = std::move(*decoded_ret);
- return true;
- }
-
- private:
- bool TryDecode(Variant* out) const;
-
private:
struct in_place_t {};
static constexpr in_place_t in_place{};
@@ -292,9 +266,9 @@ class Variant {
virtual string TypeName() const = 0;
virtual string DebugString() const = 0;
virtual void Encode(VariantTensorData* data) const = 0;
- virtual bool Decode(const VariantTensorData& data) = 0;
+ virtual bool Decode(VariantTensorData data) = 0;
virtual void Encode(string* buf) const = 0;
- virtual bool Decode(const string& data) = 0;
+ virtual bool Decode(string data) = 0;
};
template <typename T>
@@ -325,15 +299,13 @@ class Variant {
EncodeVariant(value, data);
}
- bool Decode(const VariantTensorData& data) override {
- return DecodeVariant(data, &value);
+ bool Decode(VariantTensorData data) override {
+ return DecodeVariant(&data, &value);
}
void Encode(string* buf) const override { EncodeVariant(value, buf); }
- bool Decode(const string& buf) override {
- return DecodeVariant(buf, &value);
- }
+ bool Decode(string buf) override { return DecodeVariant(&buf, &value); }
T value;
};
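[Editor's note, not part of the patch] To make the new by-value contract concrete, the sketch below mirrors the `TensorList` test type updated later in this patch; the name `MyTensorList` is illustrative and not part of the change. Because `Decode` now receives its `VariantTensorData` by value, the stored tensors can be moved out rather than copied.

    #include <utility>
    #include <vector>

    #include "tensorflow/core/framework/tensor.h"
    #include "tensorflow/core/framework/variant.h"
    #include "tensorflow/core/framework/variant_encode_decode.h"
    #include "tensorflow/core/framework/variant_tensor_data.h"

    namespace tensorflow {

    struct MyTensorList {
      void Encode(VariantTensorData* data) const { data->tensors_ = vec; }

      // The by-value parameter lets the implementation move the payload.
      bool Decode(VariantTensorData data) {
        vec = std::move(data.tensors_);
        return true;
      }

      string TypeName() const { return "MyTensorList"; }

      std::vector<Tensor> vec;
    };

    void RoundTrip() {
      Variant x = MyTensorList();
      VariantTensorData serialized;
      x.Encode(&serialized);

      Variant y = MyTensorList();
      y.Decode(std::move(serialized));  // consumes `serialized`
    }

    }  // namespace tensorflow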
diff --git a/tensorflow/core/framework/variant_encode_decode.h b/tensorflow/core/framework/variant_encode_decode.h
index f155aa4892..5e08e5a7a6 100644
--- a/tensorflow/core/framework/variant_encode_decode.h
+++ b/tensorflow/core/framework/variant_encode_decode.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/type_index.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/abi.h"
@@ -81,7 +82,7 @@ void EncodeVariantImpl(const T& value,
// Specialization for POD type
template <typename T>
-bool DecodeVariantImpl(const VariantTensorData& data,
+bool DecodeVariantImpl(VariantTensorData data,
TypeResolver<T, true /* is_pod */, false /* Tensor */,
false /* protobuf */>,
T* value) {
@@ -90,7 +91,7 @@ bool DecodeVariantImpl(const VariantTensorData& data,
// Specialization for tensorflow::Tensor
template <typename T>
-bool DecodeVariantImpl(const VariantTensorData& data,
+bool DecodeVariantImpl(VariantTensorData data,
TypeResolver<T, false /* is_pod */, true /* Tensor */,
false /* protobuf */>,
T* value) {
@@ -100,7 +101,7 @@ bool DecodeVariantImpl(const VariantTensorData& data,
// Specialization for protobuf
template <typename T>
-bool DecodeVariantImpl(const VariantTensorData& data,
+bool DecodeVariantImpl(VariantTensorData data,
TypeResolver<T, false /* is_pod */, false /* Tensor */,
true /* protobuf */>,
T* value) {
@@ -111,11 +112,11 @@ bool DecodeVariantImpl(const VariantTensorData& data,
// Specialization for other types
template <typename T>
-bool DecodeVariantImpl(const VariantTensorData& data,
+bool DecodeVariantImpl(VariantTensorData data,
TypeResolver<T, false /* is_pod */, false /* Tensor */,
false /* protobuf */>,
T* value) {
- return value->Decode(data);
+ return value->Decode(std::move(data));
}
template <typename C, typename = void>
@@ -224,8 +225,8 @@ void EncodeVariant(const T& value, VariantTensorData* data) {
}
template <typename T>
-bool DecodeVariant(const VariantTensorData& data, T* value) {
- return DecodeVariantImpl(data, TypeResolver<T>(), value);
+bool DecodeVariant(VariantTensorData* data, T* value) {
+ return DecodeVariantImpl(std::move(*data), TypeResolver<T>(), value);
}
template <typename T>
@@ -238,26 +239,31 @@ void EncodeVariant(const T& value, string* buf) {
}
template <typename T>
-bool DecodeVariant(const string& buf, T* value) {
+bool DecodeVariant(string* buf, T* value) {
VariantTensorData data;
- if (!data.ParseFromString(buf)) return false;
- if (!DecodeVariantImpl(data, TypeResolver<T>(), value)) return false;
+ if (!data.ParseFromString(*buf)) return false;
+ if (!DecodeVariantImpl(std::move(data), TypeResolver<T>(), value)) {
+ return false;
+ }
return true;
}
// Specializations for VariantTensorDataProto
template <>
string TypeNameVariant(const VariantTensorDataProto& value);
+
template <>
void EncodeVariant(const VariantTensorDataProto& value,
VariantTensorData* data);
+
template <>
-bool DecodeVariant(const VariantTensorData& data,
- VariantTensorDataProto* value);
+bool DecodeVariant(VariantTensorData* data, VariantTensorDataProto* value);
+
template <>
void EncodeVariant(const VariantTensorDataProto& value, string* buf);
+
template <>
-bool DecodeVariant(const string& buf, VariantTensorDataProto* value);
+bool DecodeVariant(string* buf, VariantTensorDataProto* value);
// Encodes an array of Variant objects in to the given StringListEncoder.
// `variant_array` is assumed to point to an array of `n` Variant objects.
diff --git a/tensorflow/core/framework/variant_op_copy_test.cc b/tensorflow/core/framework/variant_op_copy_test.cc
index 60fa7bd559..daa744e877 100644
--- a/tensorflow/core/framework/variant_op_copy_test.cc
+++ b/tensorflow/core/framework/variant_op_copy_test.cc
@@ -90,15 +90,15 @@ REGISTER_UNARY_VARIANT_DECODE_FUNCTION(StoredTensorValue, "StoredTensorValue");
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(
StoredTensorValue, VariantDeviceCopyDirection::HOST_TO_DEVICE,
- "StoredTensorValue", StoredTensorValue::CopyCPUToGPU);
+ StoredTensorValue::CopyCPUToGPU);
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(
StoredTensorValue, VariantDeviceCopyDirection::DEVICE_TO_HOST,
- "StoredTensorValue", StoredTensorValue::CopyGPUToCPU);
+ StoredTensorValue::CopyGPUToCPU);
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(
StoredTensorValue, VariantDeviceCopyDirection::DEVICE_TO_DEVICE,
- "StoredTensorValue", StoredTensorValue::CopyGPUToGPU);
+ StoredTensorValue::CopyGPUToGPU);
REGISTER_OP("CreateTestVariant")
.Input("input: T")
diff --git a/tensorflow/core/framework/variant_op_registry.cc b/tensorflow/core/framework/variant_op_registry.cc
index ee07db1aee..ef5b240aea 100644
--- a/tensorflow/core/framework/variant_op_registry.cc
+++ b/tensorflow/core/framework/variant_op_registry.cc
@@ -38,21 +38,19 @@ UnaryVariantOpRegistry* UnaryVariantOpRegistry::Global() {
}
UnaryVariantOpRegistry::VariantShapeFn* UnaryVariantOpRegistry::GetShapeFn(
- StringPiece type_name) {
- auto found = shape_fns.find(type_name);
+ const TypeIndex& type_index) {
+ auto found = shape_fns.find(type_index);
if (found == shape_fns.end()) return nullptr;
return &found->second;
}
-void UnaryVariantOpRegistry::RegisterShapeFn(const string& type_name,
+void UnaryVariantOpRegistry::RegisterShapeFn(const TypeIndex& type_index,
const VariantShapeFn& shape_fn) {
- CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantShape";
- VariantShapeFn* existing = GetShapeFn(type_name);
+ VariantShapeFn* existing = GetShapeFn(type_index);
CHECK_EQ(existing, nullptr)
- << "Unary VariantShapeFn for type_name: " << type_name
- << " already registered";
- shape_fns.insert(std::pair<StringPiece, VariantShapeFn>(
- GetPersistentStringPiece(type_name), shape_fn));
+ << "Unary VariantShapeFn for type_index: "
+ << port::MaybeAbiDemangle(type_index.name()) << " already registered";
+ shape_fns.insert(std::pair<TypeIndex, VariantShapeFn>(type_index, shape_fn));
}
Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape) {
@@ -60,11 +58,11 @@ Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape) {
CHECK_EQ(variant_tensor.dims(), 0);
const Variant& v = variant_tensor.scalar<Variant>()();
UnaryVariantOpRegistry::VariantShapeFn* shape_fn =
- UnaryVariantOpRegistry::Global()->GetShapeFn(v.TypeName());
+ UnaryVariantOpRegistry::Global()->GetShapeFn(v.TypeId());
if (shape_fn == nullptr) {
return errors::Internal(
- "No unary variant shape function found for Variant type_name: ",
- v.TypeName());
+ "No unary variant shape function found for Variant type_index: ",
+ port::MaybeAbiDemangle(v.TypeId().name()));
}
return (*shape_fn)(v, shape);
}
@@ -79,7 +77,7 @@ Status ScalarShape(const T&, TensorShape* shape) {
} // namespace
#define REGISTER_VARIANT_SHAPE_TYPE(T) \
- REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, TF_STR(T), ScalarShape<T>);
+ REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, ScalarShape<T>);
// No encode/shape registered for std::complex<> and Eigen::half
// objects yet.
@@ -143,25 +141,24 @@ REGISTER_VARIANT_DECODE_TYPE(double);
UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn*
UnaryVariantOpRegistry::GetDeviceCopyFn(
- const VariantDeviceCopyDirection direction, StringPiece type_name) {
- auto found = device_copy_fns.find(std::make_pair(direction, type_name));
+ const VariantDeviceCopyDirection direction, const TypeIndex& type_index) {
+ auto found = device_copy_fns.find(std::make_pair(direction, type_index));
if (found == device_copy_fns.end()) return nullptr;
return &found->second;
}
void UnaryVariantOpRegistry::RegisterDeviceCopyFn(
- const VariantDeviceCopyDirection direction, const string& type_name,
+ const VariantDeviceCopyDirection direction, const TypeIndex& type_index,
const AsyncVariantDeviceCopyFn& device_copy_fn) {
- CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantDeviceCopy";
- AsyncVariantDeviceCopyFn* existing = GetDeviceCopyFn(direction, type_name);
+ AsyncVariantDeviceCopyFn* existing = GetDeviceCopyFn(direction, type_index);
CHECK_EQ(existing, nullptr)
<< "UnaryVariantDeviceCopy for direction: " << direction
- << " and type_name: " << type_name << " already registered";
+ << " and type_index: " << port::MaybeAbiDemangle(type_index.name())
+ << " already registered";
device_copy_fns.insert(
- std::pair<std::pair<VariantDeviceCopyDirection, StringPiece>,
- AsyncVariantDeviceCopyFn>(
- std::make_pair(direction, GetPersistentStringPiece(type_name)),
- device_copy_fn));
+ std::pair<std::pair<VariantDeviceCopyDirection, TypeIndex>,
+ AsyncVariantDeviceCopyFn>(std::make_pair(direction, type_index),
+ device_copy_fn));
}
Status VariantDeviceCopy(
@@ -170,35 +167,34 @@ Status VariantDeviceCopy(
const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn) {
UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn* device_copy_fn =
UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(direction,
- from.TypeName());
+ from.TypeId());
if (device_copy_fn == nullptr) {
return errors::Internal(
"No unary variant device copy function found for direction: ",
- direction, " and Variant type_name: ", from.TypeName());
+ direction, " and Variant type_index: ",
+ port::MaybeAbiDemangle(from.TypeId().name()));
}
return (*device_copy_fn)(from, to, copy_fn);
}
// Special casing UnaryOpFn per op and per device.
UnaryVariantOpRegistry::VariantUnaryOpFn* UnaryVariantOpRegistry::GetUnaryOpFn(
- VariantUnaryOp op, StringPiece device, StringPiece type_name) {
- auto found = unary_op_fns.find({op, device, type_name});
+ VariantUnaryOp op, StringPiece device, const TypeIndex& type_index) {
+ auto found = unary_op_fns.find({op, device, type_index});
if (found == unary_op_fns.end()) return nullptr;
return &found->second;
}
void UnaryVariantOpRegistry::RegisterUnaryOpFn(
- VariantUnaryOp op, const string& device, const string& type_name,
+ VariantUnaryOp op, const string& device, const TypeIndex& type_index,
const VariantUnaryOpFn& unary_op_fn) {
- CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantUnaryOp";
- VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_name);
+ VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_index);
CHECK_EQ(existing, nullptr)
- << "Unary VariantUnaryOpFn for type_name: " << type_name
+ << "Unary VariantUnaryOpFn for type_index: "
+ << port::MaybeAbiDemangle(type_index.name())
<< " already registered for device type: " << device;
unary_op_fns.insert(std::pair<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn>(
- {op, GetPersistentStringPiece(device),
- GetPersistentStringPiece(type_name)},
- unary_op_fn));
+ {op, GetPersistentStringPiece(device), type_index}, unary_op_fn));
}
namespace {
@@ -212,7 +208,7 @@ Status ZerosLikeVariantPrimitiveType(OpKernelContext* ctx, const T& t,
#define REGISTER_VARIANT_ZEROS_LIKE_TYPE(T) \
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, \
- DEVICE_CPU, T, TF_STR(T), \
+ DEVICE_CPU, T, \
ZerosLikeVariantPrimitiveType<T>);
// No zeros_like registered for std::complex<> or Eigen::half objects yet.
@@ -226,24 +222,22 @@ REGISTER_VARIANT_ZEROS_LIKE_TYPE(bool);
// Special casing BinaryOpFn per op and per device.
UnaryVariantOpRegistry::VariantBinaryOpFn*
UnaryVariantOpRegistry::GetBinaryOpFn(VariantBinaryOp op, StringPiece device,
- StringPiece type_name) {
- auto found = binary_op_fns.find({op, device, type_name});
+ const TypeIndex& type_index) {
+ auto found = binary_op_fns.find({op, device, type_index});
if (found == binary_op_fns.end()) return nullptr;
return &found->second;
}
void UnaryVariantOpRegistry::RegisterBinaryOpFn(
- VariantBinaryOp op, const string& device, const string& type_name,
+ VariantBinaryOp op, const string& device, const TypeIndex& type_index,
const VariantBinaryOpFn& add_fn) {
- CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantBinaryOp";
- VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_name);
+ VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_index);
CHECK_EQ(existing, nullptr)
- << "Unary VariantBinaryOpFn for type_name: " << type_name
+ << "Unary VariantBinaryOpFn for type_index: "
+ << port::MaybeAbiDemangle(type_index.name())
<< " already registered for device type: " << device;
binary_op_fns.insert(std::pair<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn>(
- {op, GetPersistentStringPiece(device),
- GetPersistentStringPiece(type_name)},
- add_fn));
+ {op, GetPersistentStringPiece(device), type_index}, add_fn));
}
namespace {
@@ -257,8 +251,7 @@ Status AddVariantPrimitiveType(OpKernelContext* ctx, const T& a, const T& b,
#define REGISTER_VARIANT_ADD_TYPE(T) \
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, \
- T, TF_STR(T), \
- AddVariantPrimitiveType<T>);
+ T, AddVariantPrimitiveType<T>);
// No add registered for std::complex<> or Eigen::half objects yet.
REGISTER_VARIANT_ADD_TYPE(int);
diff --git a/tensorflow/core/framework/variant_op_registry.h b/tensorflow/core/framework/variant_op_registry.h
index e6a2665a56..7eb37e859f 100644
--- a/tensorflow/core/framework/variant_op_registry.h
+++ b/tensorflow/core/framework/variant_op_registry.h
@@ -22,10 +22,14 @@ limitations under the License.
#define EIGEN_USE_THREADS
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/type_index.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/platform/abi.h"
namespace tensorflow {
@@ -90,10 +94,11 @@ class UnaryVariantOpRegistry {
AsyncVariantDeviceCopyFn;
// Add a shape lookup function to the registry.
- void RegisterShapeFn(const string& type_name, const VariantShapeFn& shape_fn);
+ void RegisterShapeFn(const TypeIndex& type_index,
+ const VariantShapeFn& shape_fn);
- // Returns nullptr if no shape function was found for the given TypeName.
- VariantShapeFn* GetShapeFn(StringPiece type_name);
+ // Returns nullptr if no shape function was found for the given TypeIndex.
+ VariantShapeFn* GetShapeFn(const TypeIndex& type_index);
// Add a decode function to the registry.
void RegisterDecodeFn(const string& type_name,
@@ -104,33 +109,33 @@ class UnaryVariantOpRegistry {
// Add a copy-to-GPU function to the registry.
void RegisterDeviceCopyFn(const VariantDeviceCopyDirection direction,
- const string& type_name,
+ const TypeIndex& type_index,
const AsyncVariantDeviceCopyFn& device_copy_fn);
// Returns nullptr if no copy function was found for the given
// TypeName and direction.
AsyncVariantDeviceCopyFn* GetDeviceCopyFn(
- const VariantDeviceCopyDirection direction, StringPiece type_name);
+ const VariantDeviceCopyDirection direction, const TypeIndex& type_index);
// Add a unary op function to the registry.
void RegisterUnaryOpFn(VariantUnaryOp op, const string& device,
- const string& type_name,
+ const TypeIndex& type_index,
const VariantUnaryOpFn& unary_op_fn);
// Returns nullptr if no unary op function was found for the given
// op, device, and TypeName.
VariantUnaryOpFn* GetUnaryOpFn(VariantUnaryOp op, StringPiece device,
- StringPiece type_name);
+ const TypeIndex& type_index);
// Add a binary op function to the registry.
void RegisterBinaryOpFn(VariantBinaryOp op, const string& device,
- const string& type_name,
+ const TypeIndex& type_index,
const VariantBinaryOpFn& add_fn);
// Returns nullptr if no binary op function was found for the given
// op, device and TypeName.
VariantBinaryOpFn* GetBinaryOpFn(VariantBinaryOp op, StringPiece device,
- StringPiece type_name);
+ const TypeIndex& type_index);
// Get a pointer to a global UnaryVariantOpRegistry object
static UnaryVariantOpRegistry* Global();
@@ -145,24 +150,26 @@ class UnaryVariantOpRegistry {
static std::unordered_set<string>* PersistentStringStorage();
private:
- std::unordered_map<StringPiece, VariantShapeFn, StringPieceHasher> shape_fns;
- std::unordered_map<StringPiece, VariantDecodeFn, StringPieceHasher>
- decode_fns;
+ struct TypeIndexHash {
+ std::size_t operator()(const TypeIndex& x) const { return x.hash_code(); }
+ };
+
+ gtl::FlatMap<TypeIndex, VariantShapeFn, TypeIndexHash> shape_fns;
+ gtl::FlatMap<StringPiece, VariantDecodeFn, StringPieceHasher> decode_fns;
// Map std::pair<Direction, type_name> to function.
struct PairHash {
template <typename Direction>
- std::size_t operator()(const std::pair<Direction, StringPiece>& x) const {
+ std::size_t operator()(const std::pair<Direction, TypeIndex>& x) const {
// The hash of an enum is just its value as a std::size_t.
std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
- ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x)));
+ ret = Hash64Combine(ret, std::get<1>(x).hash_code());
return ret;
}
- StringPieceHasher sp_hasher_;
};
- std::unordered_map<std::pair<VariantDeviceCopyDirection, StringPiece>,
- AsyncVariantDeviceCopyFn, PairHash>
+ gtl::FlatMap<std::pair<VariantDeviceCopyDirection, TypeIndex>,
+ AsyncVariantDeviceCopyFn, PairHash>
device_copy_fns;
// Map std::tuple<Op, device, type_name> to function.
@@ -172,10 +179,11 @@ class UnaryVariantOpRegistry {
// and references therein
template <typename Op>
struct FuncTuple {
- FuncTuple(const Op& op, const StringPiece& dev, const StringPiece& tname)
- : op_type_(op), device_(dev), typename_(tname){};
+ FuncTuple(const Op& op, const StringPiece& dev, const TypeIndex& type_index)
+ : op_type_(op), device_(dev), type_index_(type_index) {}
Op op_type_;
- StringPiece device_, typename_;
+ StringPiece device_;
+ TypeIndex type_index_;
};
// friend declaration for operator==
// needed for clang
@@ -184,11 +192,11 @@ class UnaryVariantOpRegistry {
struct TupleHash {
template <typename Op>
std::size_t operator()(
- const std::tuple<Op, StringPiece, StringPiece>& x) const {
+ const std::tuple<Op, StringPiece, TypeIndex>& x) const {
// The hash of an enum is just its value as a std::size_t.
std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x)));
- ret = Hash64Combine(ret, sp_hasher_(std::get<2>(x)));
+ ret = Hash64Combine(ret, std::get<2>(x).hash_code());
return ret;
}
@@ -197,14 +205,14 @@ class UnaryVariantOpRegistry {
// The hash of an enum is just its value as a std::size_t.
std::size_t ret = static_cast<std::size_t>(x.op_type_);
ret = Hash64Combine(ret, sp_hasher_(x.device_));
- ret = Hash64Combine(ret, sp_hasher_(x.typename_));
+ ret = Hash64Combine(ret, x.type_index_.hash_code());
return ret;
}
StringPieceHasher sp_hasher_;
};
- std::unordered_map<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn, TupleHash>
+ gtl::FlatMap<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn, TupleHash>
unary_op_fns;
- std::unordered_map<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn, TupleHash>
+ gtl::FlatMap<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn, TupleHash>
binary_op_fns;
// Find or insert a string into a persistent string storage
@@ -225,7 +233,7 @@ template <typename Op>
inline bool operator==(const UnaryVariantOpRegistry::FuncTuple<Op>& lhs,
const UnaryVariantOpRegistry::FuncTuple<Op>& rhs) {
return (lhs.op_type_ == rhs.op_type_) && (lhs.device_ == rhs.device_) &&
- (lhs.typename_ == rhs.typename_);
+ (lhs.type_index_ == rhs.type_index_);
}
// Gets a TensorShape from a Tensor containing a scalar Variant.
// Returns an Internal error if the Variant does not have a registered shape
@@ -276,7 +284,7 @@ Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v,
Variant* v_out) {
const string& device = DeviceName<Device>::value;
UnaryVariantOpRegistry::VariantUnaryOpFn* unary_op_fn =
- UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeName());
+ UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeId());
if (unary_op_fn == nullptr) {
return errors::Internal(
"No unary variant unary_op function found for unary variant op enum: ",
@@ -297,15 +305,15 @@ Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v,
template <typename Device>
Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op,
const Variant& a, const Variant& b, Variant* out) {
- if (a.TypeName() != b.TypeName()) {
+ if (a.TypeId() != b.TypeId()) {
return errors::Internal(
"BianryOpVariants: Variants a and b have different "
- "type names: '",
+ "type ids. Type names: '",
a.TypeName(), "' vs. '", b.TypeName(), "'");
}
const string& device = DeviceName<Device>::value;
UnaryVariantOpRegistry::VariantBinaryOpFn* binary_op_fn =
- UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeName());
+ UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeId());
if (binary_op_fn == nullptr) {
return errors::Internal(
"No unary variant binary_op function found for binary variant op "
@@ -323,16 +331,18 @@ class UnaryVariantShapeRegistration {
public:
typedef std::function<Status(const T& t, TensorShape*)> LocalVariantShapeFn;
- UnaryVariantShapeRegistration(const string& type_name,
+ UnaryVariantShapeRegistration(const TypeIndex& type_index,
const LocalVariantShapeFn& shape_fn) {
+ const string type_index_name = port::MaybeAbiDemangle(type_index.name());
UnaryVariantOpRegistry::Global()->RegisterShapeFn(
- type_name,
- [type_name, shape_fn](const Variant& v, TensorShape* s) -> Status {
+ type_index,
+ [type_index_name, shape_fn](const Variant& v,
+ TensorShape* s) -> Status {
const T* t = v.get<T>();
if (t == nullptr) {
return errors::Internal(
- "VariantShapeFn: Could not access object, type_name: ",
- type_name);
+ "VariantShapeFn: Could not access object, type_index: ",
+ type_index_name);
}
return shape_fn(*t, s);
});
@@ -355,11 +365,11 @@ class UnaryVariantDecodeRegistration {
return false;
}
Variant decoded = T();
- VariantTensorData data(*t);
- if (!decoded.Decode(data)) {
+ VariantTensorData data(std::move(*t));
+ if (!decoded.Decode(std::move(data))) {
return false;
}
- *v = std::move(decoded);
+ std::swap(decoded, *v);
return true;
});
}
@@ -372,11 +382,12 @@ class UnaryVariantDeviceCopyRegistration {
UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn)>
LocalVariantDeviceCopyFn;
UnaryVariantDeviceCopyRegistration(
- const VariantDeviceCopyDirection direction, const string& type_name,
+ const VariantDeviceCopyDirection direction, const TypeIndex& type_index,
const LocalVariantDeviceCopyFn& device_copy_fn) {
+ const string type_index_name = port::MaybeAbiDemangle(type_index.name());
UnaryVariantOpRegistry::Global()->RegisterDeviceCopyFn(
- direction, type_name,
- [type_name, device_copy_fn](
+ direction, type_index,
+ [type_index_name, device_copy_fn](
const Variant& from, Variant* to,
UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn
device_copy_tensor_fn) -> Status {
@@ -384,8 +395,8 @@ class UnaryVariantDeviceCopyRegistration {
*to = T();
if (from.get<T>() == nullptr) {
return errors::Internal(
- "VariantCopyToGPUFn: Could not access object, type_name: ",
- type_name);
+ "VariantCopyToGPUFn: Could not access object, type_index: ",
+ type_index_name);
}
const T& t = *from.get<T>();
T* t_out = to->get<T>();
@@ -401,18 +412,19 @@ class UnaryVariantUnaryOpRegistration {
public:
UnaryVariantUnaryOpRegistration(VariantUnaryOp op, const string& device,
- const string& type_name,
+ const TypeIndex& type_index,
const LocalVariantUnaryOpFn& unary_op_fn) {
+ const string type_index_name = port::MaybeAbiDemangle(type_index.name());
UnaryVariantOpRegistry::Global()->RegisterUnaryOpFn(
- op, device, type_name,
- [type_name, unary_op_fn](OpKernelContext* ctx, const Variant& v,
- Variant* v_out) -> Status {
+ op, device, type_index,
+ [type_index_name, unary_op_fn](OpKernelContext* ctx, const Variant& v,
+ Variant* v_out) -> Status {
DCHECK_NE(v_out, nullptr);
*v_out = T();
if (v.get<T>() == nullptr) {
return errors::Internal(
- "VariantUnaryOpFn: Could not access object, type_name: ",
- type_name);
+ "VariantUnaryOpFn: Could not access object, type_index: ",
+ type_index_name);
}
const T& t = *v.get<T>();
T* t_out = v_out->get<T>();
@@ -429,23 +441,25 @@ class UnaryVariantBinaryOpRegistration {
public:
UnaryVariantBinaryOpRegistration(VariantBinaryOp op, const string& device,
- const string& type_name,
+ const TypeIndex& type_index,
const LocalVariantBinaryOpFn& binary_op_fn) {
+ const string type_index_name = port::MaybeAbiDemangle(type_index.name());
UnaryVariantOpRegistry::Global()->RegisterBinaryOpFn(
- op, device, type_name,
- [type_name, binary_op_fn](OpKernelContext* ctx, const Variant& a,
- const Variant& b, Variant* out) -> Status {
+ op, device, type_index,
+ [type_index_name, binary_op_fn](OpKernelContext* ctx, const Variant& a,
+ const Variant& b,
+ Variant* out) -> Status {
DCHECK_NE(out, nullptr);
*out = T();
if (a.get<T>() == nullptr) {
return errors::Internal(
- "VariantBinaryOpFn: Could not access object 'a', type_name: ",
- type_name);
+ "VariantBinaryOpFn: Could not access object 'a', type_index: ",
+ type_index_name);
}
if (b.get<T>() == nullptr) {
return errors::Internal(
- "VariantBinaryOpFn: Could not access object 'b', type_name: ",
- type_name);
+ "VariantBinaryOpFn: Could not access object 'b', type_index: ",
+ type_index_name);
}
const T& t_a = *a.get<T>();
const T& t_b = *b.get<T>();
@@ -459,19 +473,19 @@ class UnaryVariantBinaryOpRegistration {
// Register a unary shape variant function with the signature:
// Status ShapeFn(const T& t, TensorShape* s);
-// to Variants having TypeName type_name.
-#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, type_name, shape_function) \
- REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(__COUNTER__, T, type_name, \
- shape_function)
+// to Variants having TypeIndex type_index.
+#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, shape_function) \
+ REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER( \
+ __COUNTER__, T, MakeTypeIndex<T>(), shape_function)
-#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(ctr, T, type_name, \
- shape_function) \
- REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_name, shape_function)
+#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(ctr, T, type_index, \
+ shape_function) \
+ REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_index, shape_function)
-#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_name, \
+#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_index, \
shape_function) \
static variant_op_registry_fn_registration::UnaryVariantShapeRegistration<T> \
- register_unary_variant_op_shape_registration_fn_##ctr(type_name, \
+ register_unary_variant_op_shape_registration_fn_##ctr(type_index, \
shape_function)
// Register a unary decode variant function for the given type.
@@ -519,63 +533,63 @@ class UnaryVariantBinaryOpRegistration {
// ****** NOTE ******
// FOR INTERNAL USE ONLY. IF YOU USE THIS WE MAY BREAK YOUR CODE.
// ****** NOTE ******
-#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
- T, direction, type_name, device_copy_fn) \
- INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \
- __COUNTER__, T, direction, type_name, device_copy_fn)
+#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(T, direction, \
+ device_copy_fn) \
+ INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \
+ __COUNTER__, T, direction, MakeTypeIndex<T>(), device_copy_fn)
#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \
- ctr, T, direction, type_name, device_copy_fn) \
+ ctr, T, direction, type_index, device_copy_fn) \
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \
- ctr, T, direction, type_name, device_copy_fn)
+ ctr, T, direction, type_index, device_copy_fn)
-#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \
- ctr, T, direction, type_name, device_copy_fn) \
- static variant_op_registry_fn_registration:: \
- UnaryVariantDeviceCopyRegistration<T> \
- register_unary_variant_op_device_copy_fn_##ctr(direction, type_name, \
- device_copy_fn)
+#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \
+ ctr, T, direction, type_index, device_copy_fn) \
+ static variant_op_registry_fn_registration:: \
+ UnaryVariantDeviceCopyRegistration<T> \
+ register_unary_variant_op_device_copy_fn_##ctr( \
+ direction, type_index, device_copy_fn)
// Register a unary unary_op variant function with the signature:
// Status UnaryOpFn(OpKernelContext* ctx, const T& t, T* t_out);
-// to Variants having TypeName type_name, for device string device,
+// to Variants having TypeIndex type_index, for device string device,
// for UnaryVariantOp enum op.
-#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T, type_name, \
- unary_op_function) \
- REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \
- __COUNTER__, op, device, T, type_name, unary_op_function)
+#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T, \
+ unary_op_function) \
+ REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \
+ __COUNTER__, op, device, T, MakeTypeIndex<T>(), unary_op_function)
-#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \
- ctr, op, device, T, type_name, unary_op_function) \
- REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, type_name, \
- unary_op_function)
+#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \
+ ctr, op, device, T, type_index, unary_op_function) \
+ REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, \
+ type_index, unary_op_function)
#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ( \
- ctr, op, device, T, type_name, unary_op_function) \
+ ctr, op, device, T, type_index, unary_op_function) \
static variant_op_registry_fn_registration::UnaryVariantUnaryOpRegistration< \
T> \
- register_unary_variant_op_decoder_fn_##ctr(op, device, type_name, \
+ register_unary_variant_op_decoder_fn_##ctr(op, device, type_index, \
unary_op_function)
// Register a binary_op variant function with the signature:
// Status BinaryOpFn(OpKernelContext* ctx, const T& a, const T& b, T* out);
-// to Variants having TypeName type_name, for device string device,
+// to Variants having TypeIndex type_index, for device string device,
// for BinaryVariantOp enum OP.
-#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T, type_name, \
- binary_op_function) \
- REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \
- __COUNTER__, op, device, T, type_name, binary_op_function)
+#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T, \
+ binary_op_function) \
+ REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \
+ __COUNTER__, op, device, T, MakeTypeIndex<T>(), binary_op_function)
#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \
- ctr, op, device, T, type_name, binary_op_function) \
+ ctr, op, device, T, type_index, binary_op_function) \
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \
- ctr, op, device, T, type_name, binary_op_function)
+ ctr, op, device, T, type_index, binary_op_function)
-#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \
- ctr, op, device, T, type_name, binary_op_function) \
- static variant_op_registry_fn_registration:: \
- UnaryVariantBinaryOpRegistration<T> \
- register_unary_variant_op_decoder_fn_##ctr(op, device, type_name, \
+#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \
+ ctr, op, device, T, type_index, binary_op_function) \
+ static variant_op_registry_fn_registration:: \
+ UnaryVariantBinaryOpRegistration<T> \
+ register_unary_variant_op_decoder_fn_##ctr(op, device, type_index, \
binary_op_function)
} // end namespace tensorflow
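[Editor's note, not part of the patch] A minimal sketch of the re-keyed registration, assuming a hypothetical `MyValue` type and `ShapeForMyValue` function; it parallels the updated registrations in the test file below. Registration and lookup are now keyed by `MakeTypeIndex<T>()` rather than a caller-supplied type-name string.

    #include "tensorflow/core/framework/tensor_shape.h"
    #include "tensorflow/core/framework/variant_op_registry.h"
    #include "tensorflow/core/lib/core/status.h"
    #include "tensorflow/core/platform/logging.h"

    namespace tensorflow {

    struct MyValue {};  // illustrative type, not part of the patch

    Status ShapeForMyValue(const MyValue& value, TensorShape* shape) {
      *shape = TensorShape();  // scalar
      return Status::OK();
    }

    // No type-name string argument anymore; the key is the type itself.
    REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(MyValue, ShapeForMyValue);

    void LookupExample() {
      auto* fn =
          UnaryVariantOpRegistry::Global()->GetShapeFn(MakeTypeIndex<MyValue>());
      CHECK(fn != nullptr);
    }

    }  // namespace tensorflow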
diff --git a/tensorflow/core/framework/variant_op_registry_test.cc b/tensorflow/core/framework/variant_op_registry_test.cc
index 7055e62c0e..b2443e8676 100644
--- a/tensorflow/core/framework/variant_op_registry_test.cc
+++ b/tensorflow/core/framework/variant_op_registry_test.cc
@@ -89,41 +89,37 @@ struct VariantValue {
int value;
};
-REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, "TEST VariantValue",
- VariantValue::ShapeFn);
+REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, VariantValue::ShapeFn);
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(VariantValue, "TEST VariantValue");
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(
VariantValue, VariantDeviceCopyDirection::HOST_TO_DEVICE,
- "TEST VariantValue", VariantValue::CPUToGPUCopyFn);
+ VariantValue::CPUToGPUCopyFn);
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
DEVICE_CPU, VariantValue,
- "TEST VariantValue",
VariantValue::CPUZerosLikeFn);
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
DEVICE_GPU, VariantValue,
- "TEST VariantValue",
VariantValue::GPUZerosLikeFn);
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
- VariantValue, "TEST VariantValue",
- VariantValue::CPUAddFn);
+ VariantValue, VariantValue::CPUAddFn);
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_GPU,
- VariantValue, "TEST VariantValue",
- VariantValue::GPUAddFn);
+ VariantValue, VariantValue::GPUAddFn);
} // namespace
TEST(VariantOpShapeRegistryTest, TestBasic) {
- EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetShapeFn("YOU SHALL NOT PASS"),
+ class Blah {};
+ EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetShapeFn(MakeTypeIndex<Blah>()),
nullptr);
- auto* shape_fn =
- UnaryVariantOpRegistry::Global()->GetShapeFn("TEST VariantValue");
+ auto* shape_fn = UnaryVariantOpRegistry::Global()->GetShapeFn(
+ MakeTypeIndex<VariantValue>());
EXPECT_NE(shape_fn, nullptr);
TensorShape shape;
@@ -142,10 +138,11 @@ TEST(VariantOpShapeRegistryTest, TestBasic) {
TEST(VariantOpShapeRegistryTest, TestDuplicate) {
UnaryVariantOpRegistry registry;
UnaryVariantOpRegistry::VariantShapeFn f;
- string kTypeName = "fjfjfj";
- registry.RegisterShapeFn(kTypeName, f);
- EXPECT_DEATH(registry.RegisterShapeFn(kTypeName, f),
- "fjfjfj already registered");
+ class FjFjFj {};
+ const auto kTypeIndex = MakeTypeIndex<FjFjFj>();
+ registry.RegisterShapeFn(kTypeIndex, f);
+ EXPECT_DEATH(registry.RegisterShapeFn(kTypeIndex, f),
+ "FjFjFj already registered");
}
TEST(VariantOpDecodeRegistryTest, TestBasic) {
@@ -180,13 +177,14 @@ TEST(VariantOpDecodeRegistryTest, TestDuplicate) {
TEST(VariantOpCopyToGPURegistryTest, TestBasic) {
// No registered copy fn for GPU<->GPU.
- EXPECT_EQ(
- UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(
- VariantDeviceCopyDirection::DEVICE_TO_DEVICE, "TEST VariantValue"),
- nullptr);
+ EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(
+ VariantDeviceCopyDirection::DEVICE_TO_DEVICE,
+ MakeTypeIndex<VariantValue>()),
+ nullptr);
auto* copy_to_gpu_fn = UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(
- VariantDeviceCopyDirection::HOST_TO_DEVICE, "TEST VariantValue");
+ VariantDeviceCopyDirection::HOST_TO_DEVICE,
+ MakeTypeIndex<VariantValue>());
EXPECT_NE(copy_to_gpu_fn, nullptr);
VariantValue vv{true /* early_exit */};
@@ -208,17 +206,19 @@ TEST(VariantOpCopyToGPURegistryTest, TestBasic) {
TEST(VariantOpCopyToGPURegistryTest, TestDuplicate) {
UnaryVariantOpRegistry registry;
UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn f;
- string kTypeName = "fjfjfj";
+ class FjFjFj {};
+ const auto kTypeIndex = MakeTypeIndex<FjFjFj>();
registry.RegisterDeviceCopyFn(VariantDeviceCopyDirection::HOST_TO_DEVICE,
- kTypeName, f);
+ kTypeIndex, f);
EXPECT_DEATH(registry.RegisterDeviceCopyFn(
- VariantDeviceCopyDirection::HOST_TO_DEVICE, kTypeName, f),
- "fjfjfj already registered");
+ VariantDeviceCopyDirection::HOST_TO_DEVICE, kTypeIndex, f),
+ "FjFjFj already registered");
}
TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) {
+ class Blah {};
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn(
- ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"),
+ ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, MakeTypeIndex<Blah>()),
nullptr);
VariantValue vv_early_exit{true /* early_exit */, 0 /* value */};
@@ -242,8 +242,9 @@ TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) {
#if GOOGLE_CUDA
TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) {
+ class Blah {};
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn(
- ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, "YOU SHALL NOT PASS"),
+ ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, MakeTypeIndex<Blah>()),
nullptr);
VariantValue vv_early_exit{true /* early_exit */, 0 /* value */};
@@ -269,25 +270,26 @@ TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) {
TEST(VariantOpUnaryOpRegistryTest, TestDuplicate) {
UnaryVariantOpRegistry registry;
UnaryVariantOpRegistry::VariantUnaryOpFn f;
- string kTypeName = "fjfjfj";
+ class FjFjFj {};
+ const auto kTypeIndex = MakeTypeIndex<FjFjFj>();
- registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, kTypeName,
- f);
+ registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU,
+ kTypeIndex, f);
EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP,
- DEVICE_CPU, kTypeName, f),
- "fjfjfj already registered");
+ DEVICE_CPU, kTypeIndex, f),
+ "FjFjFj already registered");
- registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, kTypeName,
- f);
+ registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU,
+ kTypeIndex, f);
EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP,
- DEVICE_GPU, kTypeName, f),
- "fjfjfj already registered");
+ DEVICE_GPU, kTypeIndex, f),
+ "FjFjFj already registered");
}
TEST(VariantOpAddRegistryTest, TestBasicCPU) {
- return;
+ class Blah {};
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn(
- ADD_VARIANT_BINARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"),
+ ADD_VARIANT_BINARY_OP, DEVICE_CPU, MakeTypeIndex<Blah>()),
nullptr);
VariantValue vv_early_exit{true /* early_exit */, 3 /* value */};
@@ -312,8 +314,9 @@ TEST(VariantOpAddRegistryTest, TestBasicCPU) {
#if GOOGLE_CUDA
TEST(VariantOpAddRegistryTest, TestBasicGPU) {
+ class Blah {};
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn(
- ADD_VARIANT_BINARY_OP, DEVICE_GPU, "YOU SHALL NOT PASS"),
+ ADD_VARIANT_BINARY_OP, DEVICE_GPU, MakeTypeIndex<Blah>()),
nullptr);
VariantValue vv_early_exit{true /* early_exit */, 3 /* value */};
@@ -340,17 +343,18 @@ TEST(VariantOpAddRegistryTest, TestBasicGPU) {
TEST(VariantOpAddRegistryTest, TestDuplicate) {
UnaryVariantOpRegistry registry;
UnaryVariantOpRegistry::VariantBinaryOpFn f;
- string kTypeName = "fjfjfj";
+ class FjFjFj {};
+ const auto kTypeIndex = MakeTypeIndex<FjFjFj>();
- registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, kTypeName, f);
+ registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, kTypeIndex, f);
EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
- kTypeName, f),
- "fjfjfj already registered");
+ kTypeIndex, f),
+ "FjFjFj already registered");
- registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, kTypeName, f);
+ registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, kTypeIndex, f);
EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU,
- kTypeName, f),
- "fjfjfj already registered");
+ kTypeIndex, f),
+ "FjFjFj already registered");
}
} // namespace tensorflow
diff --git a/tensorflow/core/framework/variant_tensor_data.cc b/tensorflow/core/framework/variant_tensor_data.cc
index 99712dc114..3e67e4a864 100644
--- a/tensorflow/core/framework/variant_tensor_data.cc
+++ b/tensorflow/core/framework/variant_tensor_data.cc
@@ -22,8 +22,8 @@ namespace tensorflow {
VariantTensorData::VariantTensorData() {}
-VariantTensorData::VariantTensorData(const VariantTensorDataProto& proto) {
- FromProto(proto);
+VariantTensorData::VariantTensorData(VariantTensorDataProto proto) {
+ FromProto(std::move(proto));
}
VariantTensorData::~VariantTensorData() {}
@@ -52,7 +52,19 @@ void VariantTensorData::ToProto(VariantTensorDataProto* proto) const {
}
}
-bool VariantTensorData::FromProto(const VariantTensorDataProto& proto) {
+bool VariantTensorData::FromProto(VariantTensorDataProto proto) {
+ // TODO(ebrevdo): Do this lazily.
+ set_type_name(proto.type_name());
+ set_metadata(proto.metadata());
+ for (const auto& tensor : proto.tensors()) {
+ Tensor tmp;
+ if (!tmp.FromProto(tensor)) return false;
+ tensors_.push_back(tmp);
+ }
+ return true;
+}
+
+bool VariantTensorData::FromConstProto(const VariantTensorDataProto& proto) {
set_type_name(proto.type_name());
set_metadata(proto.metadata());
for (const auto& tensor : proto.tensors()) {
@@ -75,10 +87,10 @@ bool VariantTensorData::SerializeToString(string* buf) {
return proto.SerializeToString(buf);
}
-bool VariantTensorData::ParseFromString(const string& s) {
+bool VariantTensorData::ParseFromString(string s) {
VariantTensorDataProto proto;
const bool status = proto.ParseFromString(s);
- if (status) FromProto(proto);
+ if (status) FromProto(std::move(proto));
return status;
}
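
The signature change above means FromProto() consumes its argument, while FromConstProto() keeps the copying behavior for callers that still need the proto afterwards. Below is a minimal sketch of the two paths; MoveVsCopySketch is a hypothetical helper and the proto header location is an assumption.

#include <utility>

#include "tensorflow/core/framework/tensor.pb.h"  // VariantTensorDataProto (assumed location)
#include "tensorflow/core/framework/variant_tensor_data.h"

namespace tensorflow {
namespace {

void MoveVsCopySketch(const VariantTensorData& src) {
  VariantTensorDataProto proto;
  src.ToProto(&proto);

  // Copying path: `proto` stays valid afterwards.
  VariantTensorData copied;
  copied.FromConstProto(proto);

  // Moving path: the by-value overload can consume `proto` once the caller is
  // done with it, avoiding a deep copy of the contained tensors.
  VariantTensorData moved;
  moved.FromProto(std::move(proto));
}

}  // namespace
}  // namespace tensorflow
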
diff --git a/tensorflow/core/framework/variant_tensor_data.h b/tensorflow/core/framework/variant_tensor_data.h
index 7500e77d43..8a240ee1e3 100644
--- a/tensorflow/core/framework/variant_tensor_data.h
+++ b/tensorflow/core/framework/variant_tensor_data.h
@@ -19,13 +19,13 @@ limitations under the License.
#include <algorithm>
#include <vector>
+#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
class VariantTensorDataProto;
-class Tensor;
// The serialization format for Variant objects. Objects with references to
// other Tensors can simply store those tensors in the `tensors` field, and
@@ -38,7 +38,7 @@ class Tensor;
class VariantTensorData {
public:
VariantTensorData();
- VariantTensorData(const VariantTensorDataProto& proto);
+ VariantTensorData(VariantTensorDataProto proto);
~VariantTensorData();
// Name of the type of objects being serialized.
@@ -68,12 +68,14 @@ class VariantTensorData {
// Conversion to and from VariantTensorDataProto
void ToProto(VariantTensorDataProto* proto) const;
- bool FromProto(const VariantTensorDataProto& proto);
+ // This allows optimizations via std::move.
+ bool FromProto(VariantTensorDataProto proto);
+ bool FromConstProto(const VariantTensorDataProto& proto);
// Serialization via VariantTensorDataProto
string SerializeAsString() const;
bool SerializeToString(string* buf);
- bool ParseFromString(const string& s);
+ bool ParseFromString(string s);
string DebugString() const;
diff --git a/tensorflow/core/framework/variant_test.cc b/tensorflow/core/framework/variant_test.cc
index eef5c47d15..08d09de7b8 100644
--- a/tensorflow/core/framework/variant_test.cc
+++ b/tensorflow/core/framework/variant_test.cc
@@ -144,8 +144,8 @@ TEST(VariantTest, TypeMismatch) {
struct TensorList {
void Encode(VariantTensorData* data) const { data->tensors_ = vec; }
- bool Decode(const VariantTensorData& data) {
- vec = data.tensors_;
+ bool Decode(VariantTensorData data) {
+ vec = std::move(data.tensors_);
return true;
}
@@ -186,7 +186,7 @@ TEST(VariantTest, TensorListTest) {
x.Encode(&serialized);
Variant y = TensorList();
- y.Decode(serialized);
+ y.Decode(std::move(serialized));
const TensorList& decoded_vec = *y.get<TensorList>();
for (int i = 0; i < 4; ++i) {
@@ -204,15 +204,6 @@ TEST(VariantTest, TensorListTest) {
EXPECT_EQ(y_unknown.DebugString(),
strings::StrCat(
"Variant<type: TensorList value: ", data.DebugString(), ">"));
-
- TensorList unknown_decoded_vec;
- EXPECT_TRUE(y_unknown.MaybeDecodeAndCopy(&unknown_decoded_vec));
- for (int i = 0; i < 4; ++i) {
- EXPECT_EQ(unknown_decoded_vec.vec[i].flat<int>()(0), i);
- }
- for (int i = 0; i < 4; ++i) {
- EXPECT_EQ(unknown_decoded_vec.vec[i + 4].flat<float>()(0), 2 * i);
- }
}
TEST(VariantTest, VariantArray) {
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index ee10194142..7399613f6a 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -1042,12 +1042,12 @@ Status GraphConstructor::Convert() {
}
if (processed < node_defs_.size()) {
- LOG(WARNING) << "IN " << __func__ << (node_defs_.size() - processed)
+ LOG(WARNING) << "IN " << __func__ << " " << (node_defs_.size() - processed)
<< " NODES IN A CYCLE";
for (int64 i = 0; i < node_defs_.size(); i++) {
if (pending_count_[i] != 0) {
LOG(WARNING) << "PENDING: " << SummarizeNodeDef(*node_defs_[i])
- << "WITH PENDING COUNT = " << pending_count_[i];
+ << " WITH PENDING COUNT = " << pending_count_[i];
}
}
return errors::InvalidArgument(node_defs_.size() - processed,
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 2e644fe987..f5b0105862 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index d24e7e8ee4..56c8339d57 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -260,13 +260,13 @@ typename DisjointSet<Handle>::Rep* DisjointSet<Handle>::Find(Handle value) {
}
bool IsEnqueue(const NodeDef& n) {
- return (n.op().find("Enqueue") != std::string::npos &&
- n.op().find("EnqueueMany") == std::string::npos);
+ return (n.op().find("Enqueue") != string::npos &&
+ n.op().find("EnqueueMany") == string::npos);
}
bool IsDequeue(const NodeDef& n) {
- return (n.op().find("Dequeue") != std::string::npos &&
- n.op().find("DequeueMany") == std::string::npos);
+ return (n.op().find("Dequeue") != string::npos &&
+ n.op().find("DequeueMany") == string::npos);
}
bool HasAnyUnknownDimensions(const TensorShapeProto& proto) {
@@ -345,6 +345,56 @@ void VerboseLogUnknownDimensionSources(
}
}
+bool IsShapeFullyDefinedIntegerVectorOrScalar(
+ InferenceContext* ic, const ShapeHandle& shape,
+ const ShapeHandle& tensor_as_shape, const DataType& dtype) {
+ if (!ic->FullyDefined(shape) || ic->Rank(shape) > 1 ||
+ !ic->FullyDefined(tensor_as_shape) ||
+ (dtype != DT_INT32 && dtype != DT_INT64)) {
+ return false;
+ }
+ return true;
+}
+
+// The returned tensor's shape is like `shape`, and its values and dtype come
+// from `tensor_as_shape` and `dtype`.
+TensorProto MakeTensorProtoFromShape(InferenceContext* ic,
+ const ShapeHandle& shape,
+ const ShapeHandle& tensor_as_shape,
+ const DataType& dtype) {
+ TensorProto tensor_proto;
+ tensor_proto.set_dtype(dtype);
+ auto* shape_proto = tensor_proto.mutable_tensor_shape();
+ if (ic->Rank(shape) == 1) {
+ shape_proto->add_dim()->set_size(ic->Rank(tensor_as_shape));
+ }
+ // For a scalar tensor, tensor_shape field will be left empty; no dim.
+ for (int i = 0; i < ic->Rank(tensor_as_shape); i++) {
+ int64 value = ic->Value(ic->Dim(tensor_as_shape, i));
+ if (dtype == DT_INT32) {
+ tensor_proto.add_int_val(value);
+ } else {
+ tensor_proto.add_int64_val(value);
+ }
+ }
+ return tensor_proto;
+}
+
+// Returns a Const NodeDef with shape = `shape`, values = `tensor_as_shape`,
+// and dtype = `dtype`.
+NodeDef MakeConstNodeDefFromShape(InferenceContext* ic,
+ const ShapeHandle& shape,
+ const ShapeHandle& tensor_as_shape,
+ const DataType& dtype) {
+ NodeDef const_node;
+ const_node.set_name("const_from_shape");
+ const_node.set_op("Const");
+ auto* attr = const_node.mutable_attr();
+ (*attr)["dtype"].set_type(dtype);
+ auto* tensor = (*attr)["value"].mutable_tensor();
+ *tensor = MakeTensorProtoFromShape(ic, shape, tensor_as_shape, dtype);
+ return const_node;
+}
} // namespace
// Queue of nodes to process. Nodes can be enqueued in any order, but will be
@@ -494,14 +544,26 @@ class SymbolicShapeRefiner {
// Replace input Placeholders with Consts, if values are known. Note that
// we don't check exceptions here as it's done in the above loop.
+ auto* ctx = GetNodeContext(function_node);
+ auto* ic = ctx->inference_context.get();
for (int i = grappler_function_item.inputs().size() - 1; i >= 0; --i) {
const string& input = function_node->input(i);
const string& node_name = NodeName(input);
NodeDef* input_node = graph_.GetNode(node_name);
- // TODO(dyoon): also use Const when output_tensors_as_shape is available.
if (IsConstant(*input_node)) {
TF_CHECK_OK(
ReplaceInputWithConst(*input_node, i, &grappler_function_item));
+ } else if (ic->input_tensors_as_shapes().size() > i &&
+ IsShapeFullyDefinedIntegerVectorOrScalar(
+ ic, ic->input(i), ic->input_tensors_as_shapes()[i],
+ ctx->input_types[i])) {
+ // We have fully defined input_tensors_as_shapes for this input; use it
+ // as a const input to the function node.
+ NodeDef const_input_node = MakeConstNodeDefFromShape(
+ ic, ic->input(i), ic->input_tensors_as_shapes()[i],
+ ctx->input_types[i]);
+ TF_CHECK_OK(ReplaceInputWithConst(const_input_node, i,
+ &grappler_function_item));
}
}
@@ -510,8 +572,8 @@ class SymbolicShapeRefiner {
TF_RETURN_IF_ERROR(gp.InferStatically(true));
// Add return nodes for output shapes.
- auto ic = GetContext(function_node);
int output = 0;
+ ctx->output_tensors_as_shapes.resize(grappler_function_item.output_size());
for (auto const& out_arg : grappler_function_item.outputs()) {
if (out_arg.output_tensors.size() > 1) {
// TODO(jmdecker): Handle case of multiple output tensors
@@ -544,6 +606,14 @@ class SymbolicShapeRefiner {
ShapeHandle out;
TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &out));
ic->set_output(output, out);
+ if (outprop.has_value()) {
+ // Forward tensor value to output_tensors_as_shape.
+ Tensor tensor;
+ if (tensor.FromProto(outprop.value())) {
+ MaybeSetTensorValueToShape(ic, tensor,
+ &ctx->output_tensors_as_shapes[output]);
+ }
+ }
output++;
}
@@ -586,21 +656,9 @@ class SymbolicShapeRefiner {
if (const_values[dst_input].FromProto(
input->attr().at("value").tensor())) {
input_tensors[dst_input] = &const_values[dst_input];
- // Integer tensors of rank one can also be interpreted as a shape
- // provided all their values are >= -1.
- if (const_values[dst_input].dims() == 1 &&
- (const_values[dst_input].dtype() == DT_INT32 ||
- const_values[dst_input].dtype() == DT_INT64)) {
- ShapeHandle tensor_shape = inference_context->Vector(
- const_values[dst_input].NumElements());
- ShapeHandle shp;
- if (inference_context
- ->MakeShapeFromTensor(input_tensors[dst_input],
- tensor_shape, &shp)
- .ok()) {
- input_tensors_as_shapes[dst_input] = shp;
- }
- }
+ MaybeSetTensorValueToShape(inference_context,
+ const_values[dst_input],
+ &input_tensors_as_shapes[dst_input]);
}
} else if (IsRank(*input)) {
if (c->inference_context->RankKnown(c->inference_context->input(0))) {
@@ -968,13 +1026,25 @@ class SymbolicShapeRefiner {
: t->scalar<int64>()();
dims.push_back(size < 0 ? ic->UnknownDim() : ic->MakeDim(size));
} else {
- dims.push_back(ic->UnknownDim());
+ // No tensor value is available; fall back to input_tensors_as_shapes if
+ // possible.
+ const ShapeHandle& shape_handle = ic->input_tensors_as_shapes()[i];
+ if (ic->RankKnown(shape_handle) && ic->Rank(shape_handle) >= 1 &&
+ ic->ValueKnown(ic->Dim(shape_handle, 0))) {
+ dims.push_back(ic->Dim(shape_handle, 0));
+ } else {
+ dims.push_back(ic->UnknownDim());
+ }
}
}
if (valid) {
c->output_tensors_as_shapes.resize(1);
c->output_tensors_as_shapes[0] = ic->MakeShape(dims);
}
+ } else if (IsIdentity(node)) {
+ // Pass input_tensors_as_shapes to output_tensors_as_shapes.
+ c->output_tensors_as_shapes.resize(1);
+ c->output_tensors_as_shapes[0] = ic->input_tensors_as_shapes()[0];
} else if (IsSlice(node)) {
ShapeHandle input = ic->input_tensors_as_shapes()[0];
bool valid = ic->RankKnown(input);
@@ -1079,6 +1149,46 @@ class SymbolicShapeRefiner {
}
private:
+ bool IsIntegerVector(const Tensor& tensor) {
+ if (tensor.dims() == 1 &&
+ (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64)) {
+ return true;
+ }
+ return false;
+ }
+
+ bool IsIntegerScalar(const Tensor& tensor) {
+ if (tensor.dims() == 0 &&
+ (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64) &&
+ tensor.NumElements() == 1) {
+ return true;
+ }
+ return false;
+ }
+
+ void MaybeSetTensorValueToShape(InferenceContext* ic, const Tensor& tensor,
+ ShapeHandle* tensors_as_shapes) {
+ // Integer tensors of rank one can also be interpreted as a shape
+ // provided all their values are >= -1.
+ if (IsIntegerVector(tensor)) {
+ ShapeHandle tensor_shape = ic->Vector(tensor.NumElements());
+ ShapeHandle shp;
+ // Note that MakeShapeFromTensor filters out invalid values (e.g., < -1).
+ if (ic->MakeShapeFromTensor(&tensor, tensor_shape, &shp).ok()) {
+ *tensors_as_shapes = shp;
+ }
+ } else if (IsIntegerScalar(tensor)) {
+ // Scalar constant.
+ int64 value = tensor.dtype() == DT_INT32 ? tensor.flat<int32>()(0)
+ : tensor.flat<int64>()(0);
+ // Ideally values could be < -1, but MakeDim() fails for any value < -1;
+ // this is a limitation of using ShapeHandle as a means to pass values.
+ if (value >= -1) {
+ *tensors_as_shapes = ic->MakeShape({ic->MakeDim(value)});
+ }
+ }
+ }
+
const GraphView& graph_;
int graph_def_version_;
std::unordered_map<const NodeDef*, NodeContext> node_to_context_;
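
For reference, the vector case of MaybeSetTensorValueToShape() above boils down to the shape-inference calls below. This is only a sketch, assuming the caller owns a live InferenceContext; TensorToShapeSketch is a hypothetical name.

#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.h"

namespace tensorflow {
namespace {

// Interprets a DT_INT32/DT_INT64 vector tensor as a shape, as the refiner does
// above; returns false if the tensor holds values MakeShapeFromTensor rejects.
bool TensorToShapeSketch(shape_inference::InferenceContext* ic,
                         const Tensor& tensor,
                         shape_inference::ShapeHandle* out) {
  shape_inference::ShapeHandle tensor_shape = ic->Vector(tensor.NumElements());
  return ic->MakeShapeFromTensor(&tensor, tensor_shape, out).ok();
}

}  // namespace
}  // namespace tensorflow
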
@@ -1554,6 +1664,8 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) {
continue;
}
+ auto* ic = ctx->inference_context.get();
+
// Fill input properties.
{
auto& input_properties = input_properties_[node.name()];
@@ -1561,19 +1673,26 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) {
// Should always be empty, node names in graph are supposed to be unique.
CHECK_EQ(input_properties.size(), 0);
- input_properties.resize(ctx->inference_context->num_inputs());
+ input_properties.resize(ic->num_inputs());
GraphView::InputPort input(&node, -1);
- for (int i = 0; i < ctx->inference_context->num_inputs(); ++i) {
- shape_manager.AsTensorProperties(ctx->inference_context->input(i),
- ctx->input_types[i],
+ for (int i = 0; i < ic->num_inputs(); ++i) {
+ shape_manager.AsTensorProperties(ic->input(i), ctx->input_types[i],
&input_properties[i]);
input.port_id = i;
GraphView::OutputPort fanin = graph_view.GetRegularFanin(input);
- if (!IsConstant(*fanin.node)) {
- continue;
+ // Export tensor value (either const tensor or input_tensors_as_shapes)
+ // to input_properties.value.
+ if (IsConstant(*fanin.node)) {
+ const TensorProto& raw_val = fanin.node->attr().at("value").tensor();
+ *input_properties[i].mutable_value() = raw_val;
+ } else if (ic->input_tensors_as_shapes().size() > i &&
+ IsShapeFullyDefinedIntegerVectorOrScalar(
+ ic, ic->input(i), ic->input_tensors_as_shapes()[i],
+ ctx->input_types[i])) {
+ *input_properties[i].mutable_value() = MakeTensorProtoFromShape(
+ ic, ic->input(i), ic->input_tensors_as_shapes()[i],
+ ctx->input_types[i]);
}
- const TensorProto& raw_val = fanin.node->attr().at("value").tensor();
- *input_properties[i].mutable_value() = raw_val;
}
}
@@ -1584,11 +1703,23 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) {
// Should always be empty, node names in graph are supposed to be unique.
CHECK_EQ(output_properties.size(), 0);
- output_properties.resize(ctx->inference_context->num_outputs());
- for (int i = 0; i < ctx->inference_context->num_outputs(); ++i) {
- shape_manager.AsTensorProperties(ctx->inference_context->output(i),
- ctx->output_types[i],
+ output_properties.resize(ic->num_outputs());
+ for (int i = 0; i < ic->num_outputs(); ++i) {
+ shape_manager.AsTensorProperties(ic->output(i), ctx->output_types[i],
&output_properties[i]);
+ // Export tensor value (either const tensor or input_tensors_as_shapes)
+ // to output_properties.value.
+ if (IsConstant(node)) {
+ const TensorProto& raw_val = node.attr().at("value").tensor();
+ *output_properties[i].mutable_value() = raw_val;
+ } else if (ctx->output_tensors_as_shapes.size() > i &&
+ IsShapeFullyDefinedIntegerVectorOrScalar(
+ ic, ic->output(i), ctx->output_tensors_as_shapes[i],
+ ctx->output_types[i])) {
+ *output_properties[i].mutable_value() = MakeTensorProtoFromShape(
+ ic, ic->output(i), ctx->output_tensors_as_shapes[i],
+ ctx->output_types[i]);
+ }
}
}
}
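
To make the exported value format concrete: for a fully defined 1-D tensors-as-shape, MakeTensorProtoFromShape() above produces an integer TensorProto like the hand-built one below. This is a standalone sketch; the {5, 7} values just mirror the TensorAsShapesPropagation test that follows.

#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.pb.h"

namespace tensorflow {
namespace {

TensorProto Int32VectorValueSketch() {
  TensorProto tensor_proto;
  tensor_proto.set_dtype(DT_INT32);
  // Rank-1 shape with two elements, matching a tensors-as-shape of {5, 7}.
  tensor_proto.mutable_tensor_shape()->add_dim()->set_size(2);
  tensor_proto.add_int_val(5);
  tensor_proto.add_int_val(7);
  return tensor_proto;
}

}  // namespace
}  // namespace tensorflow
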
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index 3ec68a4e59..362092a6cf 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -44,6 +44,30 @@ class GraphPropertiesTest : public ::testing::Test {
// Provision a single machine with 3 cpu cores
cluster_.reset(new SingleMachine(5 * 60, 3, 0));
TF_CHECK_OK(cluster_->Provision());
+
+ // This function is simply
+ // out = Fill(shape, value), but
+ // Fill requires the values of the shape input, not just its shape, to infer
+ // the output shape.
+ auto f = FunctionDefHelper::Create(
+ // Name
+ "MyFillFunc",
+ // Inputs
+ {"shape: int32", "value: float"},
+ // Outputs
+ {"out: float"},
+ // Attrs
+ {},
+ // Nodes
+ {
+ {{"a"},
+ "Fill",
+ {"shape", "value"},
+ {{"T", DataType::DT_FLOAT}, {"index_type", DataType::DT_INT32}}},
+ },
+ // Returns
+ {{"out", "a:output:0"}});
+ function_lib_.add_function()->Swap(&f);
}
void TearDown() override {
@@ -69,7 +93,29 @@ class GraphPropertiesTest : public ::testing::Test {
return s;
}
+ // Compares the values of an integer (DT_INT32 or DT_INT64) tensor against
+ // the expected values.
+ void ExpectTensorValues(const std::vector<int64>& expected,
+ const TensorProto& tensor_proto_to_compare) {
+ Tensor tensor;
+ EXPECT_TRUE(tensor.FromProto(tensor_proto_to_compare));
+ EXPECT_EQ(expected.size(), tensor.NumElements());
+ // We're interested only in integer tensors, as only shapes are exported as
+ // graph properties values.
+ CHECK(tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64);
+ if (tensor.dtype() == DT_INT32) {
+ for (int i = 0; i < tensor.NumElements(); i++) {
+ EXPECT_EQ(expected[i], tensor.flat<int32>()(i));
+ }
+ } else {
+ for (int i = 0; i < tensor.NumElements(); i++) {
+ EXPECT_EQ(expected[i], tensor.flat<int64>()(i));
+ }
+ }
+ }
+
std::unique_ptr<SingleMachine> cluster_;
+ FunctionDefLibrary function_lib_;
};
TEST_F(GraphPropertiesTest, StaticProperties) {
@@ -785,32 +831,138 @@ TEST_F(GraphPropertiesTest, InferRestoreOpShape_WithTwoNodesShareSameOutput) {
EXPECT_EQ("float: [128,256]", PropToString(prop));
}
+TEST_F(GraphPropertiesTest, TensorAsShapesPropagation) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), {5, 7}, {2});
+ Output a1 = ops::Identity(s.WithOpName("a1"), a);
+ Output b = ops::Const(s.WithOpName("b"), 99, {});
+ Output b1 = ops::Identity(s.WithOpName("b1"), b);
+ Output c = ops::Const(s.WithOpName("c"), 1, {4, 4, 4});
+ Output c1 = ops::Identity(s.WithOpName("c1"), c);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+
+ // Check output shapes.
+ EXPECT_EQ("int32: [2]", PropToString(properties.GetOutputProperties("a")[0]));
+ EXPECT_EQ("int32: [2]",
+ PropToString(properties.GetOutputProperties("a1")[0]));
+ EXPECT_EQ("int32: []", PropToString(properties.GetOutputProperties("b")[0]));
+ EXPECT_EQ("int32: []", PropToString(properties.GetOutputProperties("b1")[0]));
+ EXPECT_EQ("int32: [4,4,4]",
+ PropToString(properties.GetOutputProperties("c")[0]));
+ EXPECT_EQ("int32: [4,4,4]",
+ PropToString(properties.GetOutputProperties("c1")[0]));
+
+ // Check has_value.
+ EXPECT_TRUE(properties.GetOutputProperties("a")[0].has_value());
+ EXPECT_TRUE(properties.GetInputProperties("a1")[0].has_value());
+ EXPECT_TRUE(properties.GetOutputProperties("a1")[0].has_value());
+ EXPECT_TRUE(properties.GetOutputProperties("b")[0].has_value());
+ EXPECT_TRUE(properties.GetInputProperties("b1")[0].has_value());
+ EXPECT_TRUE(properties.GetOutputProperties("b1")[0].has_value());
+ EXPECT_TRUE(properties.GetOutputProperties("c")[0].has_value());
+ EXPECT_TRUE(properties.GetInputProperties("c1")[0].has_value());
+ // Note that we propagate tensor values only for 1D vectors and scalars.
+ EXPECT_FALSE(properties.GetOutputProperties("c1")[0].has_value());
+
+ // Check values.
+ ExpectTensorValues({5, 7}, properties.GetOutputProperties("a")[0].value());
+ ExpectTensorValues({5, 7}, properties.GetInputProperties("a1")[0].value());
+ ExpectTensorValues({5, 7}, properties.GetOutputProperties("a1")[0].value());
+ ExpectTensorValues({99}, properties.GetOutputProperties("b")[0].value());
+ ExpectTensorValues({99}, properties.GetInputProperties("b1")[0].value());
+ ExpectTensorValues({99}, properties.GetOutputProperties("b1")[0].value());
+ std::vector<int64> c_values;
+ for (int i = 0; i < 4 * 4 * 4; i++) {
+ c_values.push_back(1);
+ }
+ ExpectTensorValues({c_values},
+ properties.GetOutputProperties("c")[0].value());
+ ExpectTensorValues({c_values},
+ properties.GetInputProperties("c1")[0].value());
+ // No output value for c1, as it is neither a 1D vector nor a scalar.
+}
+
+TEST_F(GraphPropertiesTest, IdentityPassingShape) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), 5, {2});
+ Output b = ops::Identity(s.WithOpName("b"), a);
+ Output c = ops::Const(s.WithOpName("const"), 0.1f, {});
+ // Fill needs not only b's shape but also the value of b to figure out its
+ // output shape; hence, the Identity op (b) should pass a's value along as
+ // output_tensors_as_shape.
+ Output d = ops::Fill(s.WithOpName("fill"), b, c);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+ const auto out_props = properties.GetOutputProperties("fill");
+ const OpInfo::TensorProperties out_prop0 = out_props[0];
+ EXPECT_EQ("float: [5,5]", PropToString(out_prop0));
+}
+
+TEST_F(GraphPropertiesTest, PackWithConstInput) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), 1, {});
+ Output b = ops::Const(s.WithOpName("b"), 2, {});
+ Output c = ops::Const(s.WithOpName("c"), 3, {});
+ Output d = ops::Const(s.WithOpName("d"), 4, {});
+ // Note ops::Stack instantiates Pack op.
+ Output e = ops::Stack(s.WithOpName("pack"), {a, b, c, d});
+ // e is a rank-1 tensor: shape = {4}, and its value is {1, 2, 3, 4}.
+ Output f = ops::Const(s.WithOpName("const"), 0.1f, {});
+ // Fill needs not only e's shape but also its value to figure out output
+ // shape.
+ Output g = ops::Fill(s.WithOpName("fill"), e, f);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+ const auto out_props = properties.GetOutputProperties("fill");
+ const OpInfo::TensorProperties out_prop0 = out_props[0];
+ EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
+}
+
+TEST_F(GraphPropertiesTest, PackWithIdentityInput) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ // Same as the PackWithConstInput test case, but a, b, c, and d are Identity
+ // ops from Const.
+ // If output_tensors_as_shape is not set for those Identity ops or the Pack op
+ // doesn't take input_tensors_as_shape, the Fill op's input has no value;
+ // hence, its output shape becomes unknown.
+ Output a0 = ops::Const(s.WithOpName("a0"), 1, {});
+ Output b0 = ops::Const(s.WithOpName("b0"), 2, {});
+ Output c0 = ops::Const(s.WithOpName("c0"), 3, {});
+ Output d0 = ops::Const(s.WithOpName("d0"), 4, {});
+ Output a = ops::Identity(s.WithOpName("a"), a0);
+ Output b = ops::Identity(s.WithOpName("b"), b0);
+ Output c = ops::Identity(s.WithOpName("c"), c0);
+ Output d = ops::Identity(s.WithOpName("d"), d0);
+ // Note ops::Stack instantiates Pack op.
+ Output e = ops::Stack(s.WithOpName("pack"), {a, b, c, d});
+ // e is a rank-1 tensor: shape = {4}, and its value is {1, 2, 3, 4}.
+ Output f = ops::Const(s.WithOpName("const"), 0.1f, {});
+ // Fill needs not only e's shape but also its value to figure out output
+ // shape.
+ Output g = ops::Fill(s.WithOpName("fill"), e, f);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+ const auto out_props = properties.GetOutputProperties("fill");
+ const OpInfo::TensorProperties out_prop0 = out_props[0];
+ EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
+}
+
TEST_F(GraphPropertiesTest, FunctionWithConstInput) {
- FunctionDefLibrary library;
- // This function is simply
- // out = Fill(shape, value), but
- // Fill requires values in the shape input, not just shape of it, to infer
- // output shape; hence, func
- *library.add_function() = FunctionDefHelper::Create(
- // Name
- "MyFillFunc",
- // Inputs
- {"shape: int32", "value: float"},
- // Outputs
- {"out: float"},
- // Attrs
- {},
- // Nodes
- {
- {{"a"},
- "Fill",
- {"shape", "value"},
- {{"T", DataType::DT_FLOAT}, {"index_type", DataType::DT_INT32}}},
- },
- // Returns
- {{"out", "a:output:0"}});
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- TF_CHECK_OK(s.graph()->AddFunctionLibrary(library));
+ TF_CHECK_OK(s.graph()->AddFunctionLibrary(function_lib_));
Output shape = ops::Const(s.WithOpName("shape"), {1, 2, 3, 4});
Output value = ops::Const(s.WithOpName("value"), 0.1f, {});
auto builder = tensorflow::NodeBuilder("MyFillFunc", "MyFillFunc",
@@ -827,13 +979,69 @@ TEST_F(GraphPropertiesTest, FunctionWithConstInput) {
TF_CHECK_OK(properties.InferStatically(false));
const auto out_props = properties.GetOutputProperties("MyFillFunc");
const OpInfo::TensorProperties out_prop0 = out_props[0];
- EXPECT_EQ(DT_FLOAT, out_prop0.dtype());
- EXPECT_FALSE(out_prop0.shape().unknown_rank());
- EXPECT_EQ(4, out_prop0.shape().dim_size());
- EXPECT_EQ(1, out_prop0.shape().dim(0).size());
- EXPECT_EQ(2, out_prop0.shape().dim(1).size());
- EXPECT_EQ(3, out_prop0.shape().dim(2).size());
- EXPECT_EQ(4, out_prop0.shape().dim(3).size());
+ EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
+}
+
+TEST_F(GraphPropertiesTest, FunctionWithIdentityOfConstInput) {
+ // Same as FunctionWithConstInput, but the function inputs are Identity of
+ // Const, so the tensors-as-shapes values, not const tensor values, should be
+ // used as the Const input to the function.
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ TF_CHECK_OK(s.graph()->AddFunctionLibrary(function_lib_));
+ Output shape_ = ops::Const(s.WithOpName("shape_"), {1, 2, 3, 4});
+ Output shape = ops::Identity(s.WithOpName("shape"), shape_);
+ Output value = ops::Const(s.WithOpName("value"), 0.1f, {});
+ auto builder = tensorflow::NodeBuilder("MyFillFunc", "MyFillFunc",
+ s.graph()->op_registry());
+ tensorflow::Node* func_op;
+ auto _shape = tensorflow::ops::AsNodeOut(s, shape);
+ auto _value = tensorflow::ops::AsNodeOut(s, value);
+ TF_CHECK_OK(
+ builder.Input(_shape).Input(_value).Finalize(s.graph(), &func_op));
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+ const auto out_props = properties.GetOutputProperties("MyFillFunc");
+ const OpInfo::TensorProperties out_prop0 = out_props[0];
+ EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
+}
+
+TEST_F(GraphPropertiesTest, FunctionReturnTensorValue) {
+ FunctionDefLibrary library;
+ *library.add_function() = FunctionDefHelper::Create(
+ "MyFunc", // Name
+ {"x: int32"}, // Inputs
+ {"out: int32"}, // Outputs
+ {}, // Attrs
+ {{{"a"}, "Identity", {"x"}, {{"T", DataType::DT_INT32}}}}, // Nodes
+ {{"out", "a:output:0"}}); // Returns
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ TF_CHECK_OK(s.graph()->AddFunctionLibrary(library));
+
+ // MyFunc takes a Const (shape) and passes it through Identity. Expect the
+ // function output to have the same shape and value (output_tensors_as_shape)
+ // as the input Const tensor.
+ Output shape = ops::Const(s.WithOpName("shape"), {5, 7}, {2});
+ auto _shape = tensorflow::ops::AsNodeOut(s, shape);
+ auto builder =
+ tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry());
+ tensorflow::Node* func_op;
+ TF_CHECK_OK(builder.Input(_shape).Finalize(s.graph(), &func_op));
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(true));
+ const auto out_props = properties.GetOutputProperties("MyFunc");
+ const OpInfo::TensorProperties out_prop0 = out_props[0];
+ EXPECT_EQ("int32: [2]", PropToString(out_prop0));
+ EXPECT_TRUE(out_prop0.has_value());
+ ExpectTensorValues({5, 7}, out_prop0.value());
+ ExpectTensorValues({5, 7},
+ properties.GetInputProperties("MyFunc")[0].value());
}
TEST_F(GraphPropertiesTest, FunctionWithScalarInput) {
@@ -907,18 +1115,10 @@ TEST_F(GraphPropertiesTest, SimpleFunctionStaticShapeInference) {
EXPECT_EQ(2, in_props.size());
const OpInfo::TensorProperties& in_prop = in_props[0];
- EXPECT_EQ(DT_FLOAT, in_prop.dtype());
- EXPECT_FALSE(in_prop.shape().unknown_rank());
- EXPECT_EQ(2, in_prop.shape().dim_size());
- EXPECT_EQ(1, in_prop.shape().dim(0).size());
- EXPECT_EQ(2, in_prop.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop));
const OpInfo::TensorProperties& in_prop1 = in_props[1];
- EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
- EXPECT_FALSE(in_prop1.shape().unknown_rank());
- EXPECT_EQ(2, in_prop1.shape().dim_size());
- EXPECT_EQ(1, in_prop1.shape().dim(0).size());
- EXPECT_EQ(2, in_prop1.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
TEST_F(GraphPropertiesTest, LargeFunctionStaticShapeInference) {
@@ -933,51 +1133,25 @@ TEST_F(GraphPropertiesTest, LargeFunctionStaticShapeInference) {
EXPECT_EQ(2, out_props.size());
const OpInfo::TensorProperties& out_prop0 = out_props[0];
- EXPECT_EQ(DT_FLOAT, out_prop0.dtype());
- EXPECT_EQ(4, out_prop0.shape().dim_size());
- EXPECT_EQ(128, out_prop0.shape().dim(0).size());
- EXPECT_EQ(112, out_prop0.shape().dim(1).size());
- EXPECT_EQ(112, out_prop0.shape().dim(2).size());
- EXPECT_EQ(64, out_prop0.shape().dim(3).size());
+ EXPECT_EQ("float: [128,112,112,64]", PropToString(out_prop0));
const OpInfo::TensorProperties& out_prop1 = out_props[1];
- EXPECT_EQ(DT_FLOAT, out_prop1.dtype());
- EXPECT_EQ(128, out_prop1.shape().dim(0).size());
- EXPECT_EQ(112, out_prop1.shape().dim(1).size());
- EXPECT_EQ(112, out_prop1.shape().dim(2).size());
- EXPECT_EQ(24, out_prop1.shape().dim(3).size());
+ EXPECT_EQ("float: [128,112,112,24]", PropToString(out_prop1));
const auto in_props = properties.GetInputProperties("y0");
EXPECT_EQ(4, in_props.size());
const OpInfo::TensorProperties& in_prop0 = in_props[0];
- EXPECT_EQ(DT_FLOAT, in_prop0.dtype());
- EXPECT_EQ(1, in_prop0.shape().dim_size());
- EXPECT_EQ(64, in_prop0.shape().dim(0).size());
+ EXPECT_EQ("float: [64]", PropToString(in_prop0));
const OpInfo::TensorProperties& in_prop1 = in_props[1];
- EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
- EXPECT_EQ(4, in_prop1.shape().dim_size());
- EXPECT_EQ(1, in_prop1.shape().dim(0).size());
- EXPECT_EQ(1, in_prop1.shape().dim(1).size());
- EXPECT_EQ(24, in_prop1.shape().dim(2).size());
- EXPECT_EQ(64, in_prop1.shape().dim(3).size());
+ EXPECT_EQ("float: [1,1,24,64]", PropToString(in_prop1));
const OpInfo::TensorProperties& in_prop2 = in_props[2];
- EXPECT_EQ(DT_FLOAT, in_prop2.dtype());
- EXPECT_EQ(4, in_prop2.shape().dim_size());
- EXPECT_EQ(128, in_prop2.shape().dim(0).size());
- EXPECT_EQ(224, in_prop2.shape().dim(1).size());
- EXPECT_EQ(224, in_prop2.shape().dim(2).size());
- EXPECT_EQ(3, in_prop2.shape().dim(3).size());
+ EXPECT_EQ("float: [128,224,224,3]", PropToString(in_prop2));
const OpInfo::TensorProperties& in_prop3 = in_props[3];
- EXPECT_EQ(DT_FLOAT, in_prop3.dtype());
- EXPECT_EQ(4, in_prop3.shape().dim_size());
- EXPECT_EQ(7, in_prop3.shape().dim(0).size());
- EXPECT_EQ(7, in_prop3.shape().dim(1).size());
- EXPECT_EQ(3, in_prop3.shape().dim(2).size());
- EXPECT_EQ(8, in_prop3.shape().dim(3).size());
+ EXPECT_EQ("float: [7,7,3,8]", PropToString(in_prop3));
}
TEST_F(GraphPropertiesTest, LargeFunctionWithMultipleOutputs) {
@@ -1037,18 +1211,10 @@ TEST_F(GraphPropertiesTest, FunctionWithErrorStaticShapeInference) {
EXPECT_EQ(2, in_props.size());
const OpInfo::TensorProperties& in_prop = in_props[0];
- EXPECT_EQ(DT_FLOAT, in_prop.dtype());
- EXPECT_FALSE(in_prop.shape().unknown_rank());
- EXPECT_EQ(2, in_prop.shape().dim_size());
- EXPECT_EQ(1, in_prop.shape().dim(0).size());
- EXPECT_EQ(2, in_prop.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop));
const OpInfo::TensorProperties& in_prop1 = in_props[1];
- EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
- EXPECT_FALSE(in_prop1.shape().unknown_rank());
- EXPECT_EQ(2, in_prop1.shape().dim_size());
- EXPECT_EQ(1, in_prop1.shape().dim(0).size());
- EXPECT_EQ(2, in_prop1.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
TEST_F(GraphPropertiesTest, FunctionSwitchStaticShapeInference) {
@@ -1073,27 +1239,16 @@ TEST_F(GraphPropertiesTest, FunctionSwitchStaticShapeInference) {
const auto out_props = properties.GetOutputProperties("MyAdd_MPaeanipb7o");
const OpInfo::TensorProperties& out_prop = out_props[0];
EXPECT_EQ(DT_FLOAT, out_prop.dtype());
- EXPECT_FALSE(out_prop.shape().unknown_rank());
- EXPECT_EQ(2, out_prop.shape().dim_size());
- EXPECT_EQ(1, out_prop.shape().dim(0).size());
- EXPECT_EQ(2, out_prop.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(out_prop));
const auto in_props = properties.GetInputProperties("MyAdd_MPaeanipb7o");
EXPECT_EQ(2, in_props.size());
const OpInfo::TensorProperties& in_prop = in_props[0];
- EXPECT_EQ(DT_FLOAT, in_prop.dtype());
- EXPECT_FALSE(in_prop.shape().unknown_rank());
- EXPECT_EQ(2, in_prop.shape().dim_size());
- EXPECT_EQ(1, in_prop.shape().dim(0).size());
- EXPECT_EQ(2, in_prop.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop));
const OpInfo::TensorProperties& in_prop1 = in_props[1];
- EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
- EXPECT_FALSE(in_prop1.shape().unknown_rank());
- EXPECT_EQ(2, in_prop1.shape().dim_size());
- EXPECT_EQ(1, in_prop1.shape().dim(0).size());
- EXPECT_EQ(2, in_prop1.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
TEST_F(GraphPropertiesTest, FunctionSwitch2StaticShapeInference) {
@@ -1117,28 +1272,16 @@ TEST_F(GraphPropertiesTest, FunctionSwitch2StaticShapeInference) {
TF_CHECK_OK(properties.InferStatically(false));
const auto out_props = properties.GetOutputProperties("MyAdd_MPaeanipb7o");
const OpInfo::TensorProperties& out_prop = out_props[0];
- EXPECT_EQ(DT_FLOAT, out_prop.dtype());
- EXPECT_FALSE(out_prop.shape().unknown_rank());
- EXPECT_EQ(2, out_prop.shape().dim_size());
- EXPECT_EQ(1, out_prop.shape().dim(0).size());
- EXPECT_EQ(2, out_prop.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(out_prop));
const auto in_props = properties.GetInputProperties("MyAdd_MPaeanipb7o");
EXPECT_EQ(2, in_props.size());
const OpInfo::TensorProperties& in_prop = in_props[0];
- EXPECT_EQ(DT_FLOAT, in_prop.dtype());
- EXPECT_FALSE(in_prop.shape().unknown_rank());
- EXPECT_EQ(2, in_prop.shape().dim_size());
- EXPECT_EQ(1, in_prop.shape().dim(0).size());
- EXPECT_EQ(2, in_prop.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop));
const OpInfo::TensorProperties& in_prop1 = in_props[1];
- EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
- EXPECT_FALSE(in_prop1.shape().unknown_rank());
- EXPECT_EQ(2, in_prop1.shape().dim_size());
- EXPECT_EQ(1, in_prop1.shape().dim(0).size());
- EXPECT_EQ(2, in_prop1.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
TEST_F(GraphPropertiesTest, FunctionSwitchShapesStaticShapeInference) {
@@ -1166,28 +1309,16 @@ TEST_F(GraphPropertiesTest, FunctionSwitchShapesStaticShapeInference) {
TF_CHECK_OK(properties.InferStatically(false));
const auto out_props = properties.GetOutputProperties("MyAdd_lEKAAnIwI5I");
const OpInfo::TensorProperties& out_prop = out_props[0];
- EXPECT_EQ(DT_FLOAT, out_prop.dtype());
- EXPECT_FALSE(out_prop.shape().unknown_rank());
- EXPECT_EQ(2, out_prop.shape().dim_size());
- EXPECT_EQ(1, out_prop.shape().dim(0).size());
- EXPECT_EQ(2, out_prop.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(out_prop));
const auto in_props = properties.GetInputProperties("MyAdd_lEKAAnIwI5I");
EXPECT_EQ(2, in_props.size());
const OpInfo::TensorProperties& in_prop = in_props[0];
- EXPECT_EQ(DT_FLOAT, in_prop.dtype());
- EXPECT_FALSE(in_prop.shape().unknown_rank());
- EXPECT_EQ(2, in_prop.shape().dim_size());
- EXPECT_EQ(1, in_prop.shape().dim(0).size());
- EXPECT_EQ(2, in_prop.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop));
const OpInfo::TensorProperties& in_prop1 = in_props[1];
- EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
- EXPECT_FALSE(in_prop1.shape().unknown_rank());
- EXPECT_EQ(2, in_prop1.shape().dim_size());
- EXPECT_EQ(1, in_prop1.shape().dim(0).size());
- EXPECT_EQ(3, in_prop1.shape().dim(1).size());
+ EXPECT_EQ("float: [1,3]", PropToString(in_prop1));
}
TEST_F(GraphPropertiesTest, SymbolicShapes) {
diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc
index aad00ce039..83434ea40f 100644
--- a/tensorflow/core/grappler/costs/utils.cc
+++ b/tensorflow/core/grappler/costs/utils.cc
@@ -127,7 +127,7 @@ static void ExtractExtraProperties(
// For filename input, the file size can also be useful.
if (op_def && i < op_def->input_arg_size() &&
- op_def->input_arg(i).name().find("filename") != std::string::npos) {
+ op_def->input_arg(i).name().find("filename") != string::npos) {
Tensor tensor;
if (!tensor.FromProto(t)) {
continue;
@@ -153,7 +153,7 @@ static void ExtractExtraProperties(
// When the input is a handle (e.g. look up table handle), the information
// in the op itself is not sufficient to predict the op memory.
if (op_def && i < op_def->input_arg_size() &&
- op_def->input_arg(i).name().find("handle") != std::string::npos) {
+ op_def->input_arg(i).name().find("handle") != string::npos) {
string new_key = strings::StrCat("parent_", i, "_op");
AttrValue attr;
attr.set_s(input_node->op());
@@ -320,8 +320,8 @@ void TensorSizeHistogram::Merge(const TensorSizeHistogram& src) {
buckets_.begin(), std::plus<uint64>());
}
-std::string TensorSizeHistogram::ToString() const {
- std::string r;
+string TensorSizeHistogram::ToString() const {
+ string r;
char buf[200];
snprintf(buf, sizeof(buf), "Count: %lld, Average: ", num_elem_);
r.append(buf);
diff --git a/tensorflow/core/grappler/costs/utils.h b/tensorflow/core/grappler/costs/utils.h
index d2c7c67666..5fd6717712 100644
--- a/tensorflow/core/grappler/costs/utils.h
+++ b/tensorflow/core/grappler/costs/utils.h
@@ -80,7 +80,7 @@ class TensorSizeHistogram {
uint64 Max() const { return max_; }
uint64 NumElem() const { return num_elem_; }
uint64 SumElem() const { return sum_elem_; }
- std::string ToString() const;
+ string ToString() const;
protected:
const int Index(const uint64 value) const;
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
index 02a379fca8..80889afc86 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
@@ -1999,13 +1999,13 @@ TEST_F(VirtualSchedulerTest, InterDeviceTransfer) {
// Helper lambda to extract port num from _Send and _Recv op name.
auto get_port_num = [](const string& name) -> int {
- if (name.find("bn_0") != std::string::npos) {
+ if (name.find("bn_0") != string::npos) {
return 0;
- } else if (name.find("bn_1") != std::string::npos) {
+ } else if (name.find("bn_1") != string::npos) {
return 1;
- } else if (name.find("bn_2") != std::string::npos) {
+ } else if (name.find("bn_2") != string::npos) {
return 2;
- } else if (name.find("bn_minus1") != std::string::npos) {
+ } else if (name.find("bn_minus1") != string::npos) {
return -1;
}
return -999;
diff --git a/tensorflow/core/grappler/inputs/utils.cc b/tensorflow/core/grappler/inputs/utils.cc
index 5029dff877..def9198a69 100644
--- a/tensorflow/core/grappler/inputs/utils.cc
+++ b/tensorflow/core/grappler/inputs/utils.cc
@@ -14,10 +14,11 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/inputs/utils.h"
-#include "tensorflow/core/platform/env.h"
#include <vector>
+#include "tensorflow/core/platform/env.h"
+
namespace tensorflow {
namespace grappler {
@@ -29,12 +30,12 @@ bool FilesExist(const std::set<string>& files) {
return FilesExist(std::vector<string>(files.begin(), files.end()), nullptr);
}
-bool FileExists(const std::string& file, Status* status) {
+bool FileExists(const string& file, Status* status) {
*status = Env::Default()->FileExists(file);
return status->ok();
}
-Status ReadGraphDefFromFile(const std::string& graph_def_pbtxt_path,
+Status ReadGraphDefFromFile(const string& graph_def_pbtxt_path,
GraphDef* result) {
Status status;
if (FileExists(graph_def_pbtxt_path, &status)) {
diff --git a/tensorflow/core/grappler/inputs/utils.h b/tensorflow/core/grappler/inputs/utils.h
index 627dd5359f..4b9cb0a9ad 100644
--- a/tensorflow/core/grappler/inputs/utils.h
+++ b/tensorflow/core/grappler/inputs/utils.h
@@ -29,9 +29,9 @@ bool FilesExist(const std::vector<string>& files,
std::vector<Status>* status = nullptr);
bool FilesExist(const std::set<string>& files);
-bool FileExists(const std::string& file, Status* status);
+bool FileExists(const string& file, Status* status);
-Status ReadGraphDefFromFile(const std::string& graph_def_pbtxt_path,
+Status ReadGraphDefFromFile(const string& graph_def_pbtxt_path,
GraphDef* result);
} // end namespace grappler
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index e78239bd43..3521669b63 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -491,7 +491,7 @@ bool IsFreeOfSideEffect(const NodeDef& node) {
}
}
// Queue ops modify the queue which is a side effect.
- if (node.op().find("Queue") != std::string::npos) {
+ if (node.op().find("Queue") != string::npos) {
return false;
}
return !ModifiesInputsInPlace(node);
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index a24004dc16..f094c151e6 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -846,3 +846,68 @@ tf_cc_test(
"//third_party/eigen3",
],
)
+
+cc_library(
+ name = "function_api_info",
+ srcs = ["function_api_info.cc"],
+ hdrs = ["function_api_info.h"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+tf_cc_test(
+ name = "function_api_info_test",
+ size = "small",
+ srcs = ["function_api_info_test.cc"],
+ deps = [
+ ":function_api_info",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "experimental_implementation_selector",
+ srcs = ["experimental_implementation_selector.cc"],
+ hdrs = ["experimental_implementation_selector.h"],
+ deps = [
+ ":custom_graph_optimizer",
+ ":custom_graph_optimizer_registry",
+ ":function_api_info",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/costs:graph_properties",
+ ],
+)
+
+tf_cc_test(
+ name = "experimental_implementation_selector_test",
+ size = "small",
+ srcs = ["experimental_implementation_selector_test.cc"],
+ deps = [
+ ":custom_graph_optimizer",
+ ":custom_graph_optimizer_registry",
+ ":experimental_implementation_selector",
+ ":function_api_info",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
+ "//tensorflow/core/grappler/utils:grappler_test",
+ ],
+)
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 39517edc06..bc838c6659 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -581,7 +581,7 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) {
const NodeDef* new_const = node_map.GetNode(optimized_const_name);
ASSERT_NE(new_const, nullptr);
EXPECT_EQ("^x", new_const->input(0));
- EXPECT_EQ(std::string("\0\0\0@", 4),
+ EXPECT_EQ(string("\0\0\0@", 4),
new_const->attr().at("value").tensor().tensor_content());
const NodeDef* new_mul = node_map.GetNode(optimized_mul_name);
@@ -625,7 +625,7 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimpleWithControlDep) {
const NodeDef* new_const = node_map.GetNode(optimized_const_name);
ASSERT_NE(new_const, nullptr);
EXPECT_EQ("^x", new_const->input(0));
- EXPECT_EQ(std::string("\0\0\0@", 4),
+ EXPECT_EQ(string("\0\0\0@", 4),
new_const->attr().at("value").tensor().tensor_content());
const NodeDef* new_mul = node_map.GetNode(optimized_mul_name);
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
index 5a7fe19265..d4ab444036 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
@@ -273,7 +273,7 @@ void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph,
string name = string(prefix);
int id = graph->node_size();
while (ContainsGraphNodeWithName(name, *graph)) {
- if (name.rfind("_generated") != std::string::npos &&
+ if (name.rfind("_generated") != string::npos &&
(name.rfind("_generated") == (name.size() - strlen("_generated")))) {
name.insert(name.rfind("_generated"), strings::StrCat("/_", id));
} else {
diff --git a/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc
new file mode 100644
index 0000000000..eeea269fb0
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc
@@ -0,0 +1,93 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/experimental_implementation_selector.h"
+
+#include <string>
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/grappler/costs/graph_properties.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/function_api_info.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+namespace grappler {
+
+REGISTER_GRAPH_OPTIMIZER(ExperimentalImplementationSelector);
+
+Status ExperimentalImplementationSelector::LoadFunctions(
+ const GraphDef& graph) {
+ lib_info_.reset(new FunctionLibraryApiInfo);
+ TF_RETURN_IF_ERROR(lib_info_->Init(graph.library()));
+ return Status::OK();
+}
+
+Status ExperimentalImplementationSelector::MaybeOptimizeFunctionCall(
+ NodeDef* node_def) const {
+ const FunctionApiInfo* info = lib_info_->GetApiInfo(node_def->op());
+ if (info == nullptr) {
+ // A regular op, or a function which has no interface.
+ return Status::OK();
+ }
+
+ string task, device;
+ if (!DeviceNameUtils::SplitDeviceName(node_def->device(), &task, &device)) {
+ return errors::Internal("Could not split device name:", node_def->device());
+ }
+ VLOG(2) << "Op " << node_def->name() << " runs on " << node_def->device()
+ << " = (" << task << ", " << device << ")";
+ DeviceNameUtils::ParsedName parsed_name;
+ DeviceNameUtils::ParseLocalName(device, &parsed_name);
+
+ string best_function_name;
+ lib_info_->GetBestImplementation(node_def->op(), parsed_name.type,
+ &best_function_name);
+ if (node_def->op() != best_function_name) {
+ // The current implementation is not the best; swap the op to the best one.
+ // There will be duplicates in the graph, and they will be pruned by another
+ // grappler plugin since no other node uses their outputs as inputs.
+ // TODO(scottzhu): Update tf.eager.defun to register functions without
+ // having to call them with input data. That will reduce the graph size and
+ // save the work of pruning them.
+ node_def->set_op(best_function_name);
+ }
+ return Status::OK();
+}
+
+Status ExperimentalImplementationSelector::SelectImplementation(
+ GraphDef* graph) const {
+ for (int k = 0; k < graph->node_size(); ++k)
+ TF_RETURN_IF_ERROR(MaybeOptimizeFunctionCall(graph->mutable_node(k)));
+
+ return Status::OK();
+}
+
+Status ExperimentalImplementationSelector::Optimize(Cluster* cluster,
+ const GrapplerItem& item,
+ GraphDef* optimized_graph) {
+ *optimized_graph = item.graph;
+ TF_RETURN_IF_ERROR(LoadFunctions(*optimized_graph));
+ return SelectImplementation(optimized_graph);
+}
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/experimental_implementation_selector.h b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.h
new file mode 100644
index 0000000000..82f7473a14
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.h
@@ -0,0 +1,115 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_EXPERIMENTAL_IMPLEMENTATION_SELECTOR_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_EXPERIMENTAL_IMPLEMENTATION_SELECTOR_H_
+
+#include <string>
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/grappler/costs/graph_properties.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/function_api_info.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// -- EXPERIMENTAL --
+// This transformation replaces function calls by the appropriate function
+// definition based on properties of the runtime system. For instance,
+// we may choose one implementation over another if we have a GPU with
+// enough memory available.
+//
+// It is a way for the programmer to specify alternative implementations
+// of the same functionality in the graph, and let TensorFlow pick the
+// most appropriate one at runtime.
+//
+// For instance, the python code might specify:
+// @Defun(tf.float32,
+// experimental_api_implements='plus_one',
+// experimental_api_preferred_device='GPU')
+// def plus_one_gpu(x): return x + 1.0
+//
+// @Defun(tf.float32,
+// experimental_api_implements='plus_one')
+// def plus_one_reference_implementation(x): return x + 1.0
+// input = tf.constant(2.0, dtype=tf.float32)
+//
+// z = plus_one_reference_implementation(input)
+// z = plus_one_gpu(input)
+// print(sess.run(z))
+//
+// At runtime, we will trim either `plus_one_gpu` or
+// `plus_one_reference_implementation` based on the availability of the GPU.
+//
+// Available annotations:
+// - experimental_api_implements(string): all functions mapping to the same
+// string can be interchanged. For now, all functions must have the same
+// signature and overloads are not allowed. Defuns within defuns are
+// allowed.
+// - experimental_api_preferred_device(string): sets which device is preferred.
+class ExperimentalImplementationSelector : public CustomGraphOptimizer {
+ public:
+ ExperimentalImplementationSelector() = default;
+ ~ExperimentalImplementationSelector() override = default;
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ return Status::OK();
+ }
+ string name() const override {
+ return "experimental_implementation_selector";
+ }
+
+ // This call is not thread-safe.
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) override;
+
+ // Does not take any feedback.
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimized_graph, double result) override {}
+
+ private:
+ Status LoadFunctions(const GraphDef& graph);
+ Status MaybeOptimizeFunctionCall(NodeDef* node_def) const;
+
+ // Finds all call sites for functions, then replaces each with the
+ // appropriate implementation.
+ // There are two ways of calling functions:
+ // 1. By specifying an op name as a function name, and
+ // 2. Via the functional interface, where the function name appears as an
+ // Attr.
+ //
+ // There may be multiple call sites for a given function. The function body
+ // may call into another function, so a function might have to be duplicated.
+ // For simplicity, we do not change function bodies. Also, we do not change
+ // gradients.
+ Status SelectImplementation(GraphDef* graph) const;
+
+ std::unique_ptr<FunctionLibraryApiInfo> lib_info_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ExperimentalImplementationSelector);
+};
+
+} // namespace grappler
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_EXPERIMENTAL_IMPLEMENTATION_SELECTOR_H_
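
Because the class is registered via REGISTER_GRAPH_OPTIMIZER, it can also be instantiated by name through the custom optimizer registry, which is how the tests below drive it. A minimal sketch follows; RunSelectorSketch is a hypothetical wrapper, not part of the change.

#include <memory>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
namespace grappler {
namespace {

Status RunSelectorSketch(const GrapplerItem& item, GraphDef* output) {
  std::unique_ptr<CustomGraphOptimizer> optimizer =
      CustomGraphOptimizerRegistry::CreateByNameOrNull(
          "ExperimentalImplementationSelector");
  if (optimizer == nullptr) {
    return errors::NotFound("Optimizer is not linked into this binary.");
  }
  TF_RETURN_IF_ERROR(optimizer->Init());
  // A null Cluster is accepted, as in the tests below.
  return optimizer->Optimize(/*cluster=*/nullptr, item, output);
}

}  // namespace
}  // namespace grappler
}  // namespace tensorflow
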
diff --git a/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc b/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc
new file mode 100644
index 0000000000..2368e577c2
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc
@@ -0,0 +1,139 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/grappler/optimizers/experimental_implementation_selector.h"
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/utils/grappler_test.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+constexpr char CpuDevice[] = "/device:CPU:0";
+constexpr char GpuDevice[] = "/device:GPU:0";
+
+class ExperimentalImplementationSelectorTest : public GrapplerTest {};
+
+TEST_F(ExperimentalImplementationSelectorTest, NoUpdate) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {CpuDevice});
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ std::unique_ptr<CustomGraphOptimizer> optimizer =
+ CustomGraphOptimizerRegistry::CreateByNameOrNull(
+ "ExperimentalImplementationSelector");
+ ASSERT_NE(nullptr, optimizer);
+ TF_ASSERT_OK(optimizer->Init());
+
+ GraphDef output;
+ const Status status = optimizer->Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ // This is a trivial graph so there is nothing to update.
+ EXPECT_EQ(item.graph.node_size(), output.node_size());
+}
+
+TEST_F(ExperimentalImplementationSelectorTest, SwapImplementation) {
+ using test::function::NDef;
+ auto cpu_def = test::function::XTimesTwo();
+ auto* func_attr = cpu_def.mutable_attr();
+ (*func_attr)["experimental_api_implements"].set_s("times_two");
+ (*func_attr)["experimental_api_preferred_device"].set_s("CPU");
+
+ auto gpu_def = test::function::XAddX();
+ auto* func2_attr = gpu_def.mutable_attr();
+ (*func2_attr)["experimental_api_implements"].set_s("times_two");
+ (*func2_attr)["experimental_api_preferred_device"].set_s("GPU");
+
+ ExperimentalImplementationSelector optimizer;
+ GraphDef output;
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, GpuDevice),
+ NDef("y1", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, GpuDevice),
+ NDef("z1", "Identity", {"y1"}, {{"T", DT_FLOAT}}, GpuDevice),
+ NDef("y2", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, CpuDevice),
+ NDef("z2", "Identity", {"y2"}, {{"T", DT_FLOAT}}, CpuDevice)},
+ // FunctionLib
+ {cpu_def, gpu_def});
+
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_EQ(output.node_size(), 5);
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "y1") {
+ // Make sure the implementation has been swapped to use the GPU version.
+ EXPECT_EQ("XAddX", node.op());
+ } else if (node.name() == "y2") {
+ // Make sure the implementation is not changed.
+ EXPECT_EQ("XTimesTwo", node.op());
+ }
+ }
+}
+
+TEST_F(ExperimentalImplementationSelectorTest, SwapImplementationEval) {
+ using test::function::NDef;
+ auto cpu_def = test::function::XTimesTwo();
+ auto* func_attr = cpu_def.mutable_attr();
+ (*func_attr)["experimental_api_implements"].set_s("random_boost");
+ (*func_attr)["experimental_api_preferred_device"].set_s("CPU");
+
+ auto gpu_def = test::function::XTimesFour();
+ auto* func2_attr = gpu_def.mutable_attr();
+ (*func2_attr)["experimental_api_implements"].set_s("random_boost");
+ (*func2_attr)["experimental_api_preferred_device"].set_s("GPU");
+
+ ExperimentalImplementationSelector optimizer;
+ GraphDef output;
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, CpuDevice),
+ NDef("y", "XTimesFour", {"x"}, {{"T", DT_FLOAT}}, CpuDevice),
+ NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, CpuDevice)},
+ // FunctionLib
+ {cpu_def, gpu_def});
+
+ const Tensor input = test::AsScalar<float>(1.0f);
+ item.fetch = {"z"};
+ item.feed.emplace_back("x", input);
+
+ const auto four_times_boosted_tensor = EvaluateFetchNodes(item);
+ test::ExpectTensorEqual<float>(four_times_boosted_tensor[0],
+ test::AsScalar<float>(4.0f));
+
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+ GrapplerItem optimized(item, std::move(output));
+ const auto twice_boosted_tensor = EvaluateFetchNodes(optimized);
+ test::ExpectTensorEqual<float>(twice_boosted_tensor[0],
+ test::AsScalar<float>(2.0f));
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/function_api_info.cc b/tensorflow/core/grappler/optimizers/function_api_info.cc
new file mode 100644
index 0000000000..798e0f6fd5
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/function_api_info.cc
@@ -0,0 +1,167 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/function_api_info.h"
+
+#include <string>
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace grappler {
+FunctionApiInfo::FunctionApiInfo() {}
+FunctionApiInfo::~FunctionApiInfo() {}
+
+Status FunctionApiInfo::Init(const FunctionDef& function_def) {
+ for (const auto& attr : function_def.attr()) {
+ if (attr.first == "experimental_api_preferred_device") {
+ preferred_device_ = attr.second.s();
+ }
+ if (attr.first == "experimental_api_implements") {
+ interface_name_ = attr.second.s();
+ }
+ }
+ if (interface_name_.empty() && !preferred_device_.empty()) {
+ return errors::InvalidArgument(
+ "Function '", function_def.signature().name(),
+ "' has a preferred device, but does not implement an interface");
+ }
+ return Status::OK();
+}
+
+const string& FunctionApiInfo::preferred_device() const {
+ return preferred_device_;
+}
+
+const string& FunctionApiInfo::interface_name() const {
+ return interface_name_;
+}
+
+FunctionLibraryApiInfo::FunctionLibraryApiInfo() {}
+FunctionLibraryApiInfo::~FunctionLibraryApiInfo() {}
+
+namespace {
+bool IsSameSignature(const FunctionDef& f1, const FunctionDef& f2) {
+ if (f1.ret().size() != f2.ret().size()) return false;
+ const auto& sig1 = f1.signature();
+ const auto& sig2 = f2.signature();
+ // Functions have positional semantics, so we don't check for names.
+ if (sig1.input_arg_size() != sig2.input_arg_size()) return false;
+ for (int k = 0; k < sig1.input_arg_size(); ++k) {
+ const OpDef::ArgDef& arg1 = sig1.input_arg(k);
+ const OpDef::ArgDef& arg2 = sig2.input_arg(k);
+ if (arg1.type() != arg2.type()) return false;
+ if (arg1.type_attr() != arg2.type_attr()) return false;
+ if (arg1.number_attr() != arg2.number_attr()) return false;
+ if (arg1.type_list_attr() != arg2.type_list_attr()) return false;
+ if (arg1.is_ref() != arg2.is_ref()) return false;
+ }
+ return true;
+}
+
+Status ValidateSignature(const string& interface_name,
+ const std::vector<const FunctionDef*>& equiv_funcs) {
+ if (equiv_funcs.size() < 2) return Status::OK();
+ for (size_t k = 1; k < equiv_funcs.size(); ++k) {
+ if (!IsSameSignature(*equiv_funcs[0], *equiv_funcs[k]))
+ return errors::InvalidArgument(
+ "Functions '", equiv_funcs[0]->signature().name(), "' and '",
+ equiv_funcs[k]->signature().name(), "' both implement '",
+ interface_name, "' but their signatures do not match.");
+ }
+ return Status::OK();
+}
+
+Status ValidateSignatures(
+ const std::unordered_map<string, std::vector<const FunctionDef*>>&
+ intf_to_func) {
+ for (const auto& item : intf_to_func)
+ TF_RETURN_IF_ERROR(ValidateSignature(item.first, item.second));
+ return Status::OK();
+}
+} // namespace
+
+Status FunctionLibraryApiInfo::Init(
+ const FunctionDefLibrary& function_library) {
+ std::unordered_map<string, std::vector<const FunctionDef*>> intf_to_func;
+ for (const auto& function : function_library.function()) {
+ std::unique_ptr<FunctionApiInfo> func_info(new FunctionApiInfo);
+ TF_RETURN_IF_ERROR(func_info->Init(function));
+ // Ignore the function if it does not implement any interface.
+ if (func_info->interface_name().empty()) continue;
+
+ const string& function_name = function.signature().name();
+ const string& interface_name = func_info->interface_name();
+ func_to_intf_[function_name] = interface_name;
+ intf_to_funcs_[interface_name].emplace_back(function_name);
+ intf_to_func[interface_name].emplace_back(&function);
+ func_info_[function_name] = std::move(func_info);
+ }
+ TF_RETURN_IF_ERROR(ValidateSignatures(intf_to_func));
+ return Status::OK();
+}
+
+void FunctionLibraryApiInfo::GetEquivalentImplementations(
+ const string& function_name, std::vector<string>* other_names) const {
+ const auto intf_it = func_to_intf_.find(function_name);
+ // The function does not implement any interface.
+ if (intf_it == func_to_intf_.end()) return;
+  CHECK(!intf_it->second.empty()) << "Function " << function_name
+                                  << " should implement at least one interface.";
+ const auto it = intf_to_funcs_.find(intf_it->second);
+ CHECK(it != intf_to_funcs_.end())
+ << "Function " << function_name << " maps to " << intf_it->second
+ << " but no reverse mapping was found";
+ CHECK_GE(it->second.size(), 1) << "Class " << it->first << " is empty";
+ other_names->reserve(it->second.size() - 1);
+ for (const auto& other_name : it->second) {
+ if (other_name == function_name) continue;
+ other_names->emplace_back(other_name);
+ }
+}
+
+void FunctionLibraryApiInfo::GetBestImplementation(
+ const string& function_name, const string& device,
+ string* best_func_name) const {
+ CHECK(best_func_name != nullptr);
+ const auto func_it = func_to_intf_.find(function_name);
+ if (func_it == func_to_intf_.end()) return;
+
+ const auto it = intf_to_funcs_.find(func_it->second);
+ // No function found for the given interface.
+ if (it == intf_to_funcs_.end()) return;
+ for (const auto& func_name : it->second) {
+ const auto func_api_info = func_info_.find(func_name)->second.get();
+ if (func_api_info->preferred_device() == device) {
+ best_func_name->assign(func_name);
+ return;
+ }
+ }
+  // No function with a matching preferred device was found; choose the first
+  // one among all the available functions.
+ best_func_name->assign(it->second.front());
+}
+
+const FunctionApiInfo* FunctionLibraryApiInfo::GetApiInfo(
+ const string& function_name) const {
+ const auto it = func_info_.find(function_name);
+ if (it == func_info_.end()) return nullptr;
+ return it->second.get();
+}
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/function_api_info.h b/tensorflow/core/grappler/optimizers/function_api_info.h
new file mode 100644
index 0000000000..412687c58c
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/function_api_info.h
@@ -0,0 +1,80 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_FUNCTION_API_INFO_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_FUNCTION_API_INFO_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace grappler {
+class FunctionApiInfo {
+ public:
+ FunctionApiInfo();
+ virtual ~FunctionApiInfo();
+
+ Status Init(const FunctionDef& function_def);
+
+ const string& interface_name() const;
+ const string& preferred_device() const;
+
+ private:
+ string interface_name_;
+ string preferred_device_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(FunctionApiInfo);
+};
+
+// A collection of information about a function and the interface it
+// implements. An interface is a well-defined math operation, e.g.
+// I1 = 2 * x + y. Multiple functions may implement the same interface with
+// different behavior depending on hardware conditions and limits, e.g.
+// F1 = math_ops.add(math_ops.add(x, x), y), or
+// F2 = math_ops.add(math_ops.matmul(x, 2), y).
+class FunctionLibraryApiInfo {
+ public:
+ FunctionLibraryApiInfo();
+ virtual ~FunctionLibraryApiInfo();
+  // Populates the internal fields for the functions within function_library.
+ Status Init(const FunctionDefLibrary& function_library);
+
+ void GetEquivalentImplementations(const string& function_name,
+ std::vector<string>* other_names) const;
+
+ void GetBestImplementation(const string& function_name, const string& device,
+ string* best_func_name) const;
+
+ const FunctionApiInfo* GetApiInfo(const string& function_name) const;
+
+ private:
+  // Maps a function name to its details.
+  std::unordered_map<string, std::unique_ptr<FunctionApiInfo>> func_info_;
+  // Maps a function name to the name of the interface it implements.
+  std::unordered_map<string, string> func_to_intf_;
+  // Maps an interface name to the names of the functions implementing it.
+  std::unordered_map<string, std::vector<string>> intf_to_funcs_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryApiInfo);
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_FUNCTION_API_INFO_H_
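A minimal usage sketch, assuming `library` is a FunctionDefLibrary whose
FunctionDefs carry the experimental_api_implements and
experimental_api_preferred_device attributes (as the tests below do), and that
this fragment lives inside a Status-returning function:

  FunctionLibraryApiInfo api_info;
  TF_RETURN_IF_ERROR(api_info.Init(library));

  // Ask for the implementation preferred on GPU; if no GPU variant exists,
  // the first registered implementation of the interface is returned.
  string best;
  api_info.GetBestImplementation("XTimesTwo", "GPU", &best);

  // List the other functions implementing the same interface.
  std::vector<string> others;
  api_info.GetEquivalentImplementations("XTimesTwo", &others);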
diff --git a/tensorflow/core/grappler/optimizers/function_api_info_test.cc b/tensorflow/core/grappler/optimizers/function_api_info_test.cc
new file mode 100644
index 0000000000..582890d3e3
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/function_api_info_test.cc
@@ -0,0 +1,160 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/function_api_info.h"
+
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+void SetArg(const string& name, const string& type_name,
+ OpDef::ArgDef* arg_def) {
+ arg_def->set_name(name);
+ arg_def->set_type_attr(type_name);
+}
+
+typedef std::pair<string, string> ArgSpec; // name, type.
+
+void SetArgs(const std::vector<ArgSpec>& args_spec, OpDef* sig) {
+ for (const auto& arg_spec : args_spec)
+ SetArg(arg_spec.first, arg_spec.second, sig->add_input_arg());
+ SetArg("output", "float32", sig->add_output_arg());
+}
+
+void PopulateFunction(const string& name, const string& api_interface_name,
+ const string& preferred_device,
+ const std::vector<ArgSpec>& input_args,
+ FunctionDef* func_def) {
+ OpDef* sig = func_def->mutable_signature();
+ sig->set_name(name);
+
+ SetArgs(input_args, sig);
+
+ if (!api_interface_name.empty() || !preferred_device.empty()) {
+ auto* func_attr = func_def->mutable_attr();
+ if (!api_interface_name.empty())
+ (*func_attr)["experimental_api_implements"].set_s(api_interface_name);
+ if (!preferred_device.empty())
+ (*func_attr)["experimental_api_preferred_device"].set_s(preferred_device);
+ }
+}
+
+void PopulateSampleLibrary(const bool mismatch_args,
+ FunctionDefLibrary* func_lib) {
+ const std::vector<ArgSpec> func_args{{"in1", "float32"}, {"in2", "int32"}};
+ const std::vector<ArgSpec> func_wrong_args{{"in1", "int32"},
+ {"in2", "int32"}};
+ PopulateFunction("DoStuffCpu", "DoStuff", "CPU", func_args,
+ func_lib->add_function());
+ PopulateFunction("DoStuffGpu", "DoStuff", "GPU",
+ mismatch_args ? func_wrong_args : func_args,
+ func_lib->add_function());
+ PopulateFunction("DoThings", "DoThings", "", func_args,
+ func_lib->add_function());
+ PopulateFunction("OneOff", "", "", func_args, func_lib->add_function());
+ PopulateFunction("AnotherOneOff", "", "", func_args,
+ func_lib->add_function());
+}
+
+bool CheckEquivImpl(const FunctionLibraryApiInfo& lib_api_info,
+ const string& func_name,
+ const std::vector<string>& expected_other) {
+ std::vector<string> other_impl;
+ lib_api_info.GetEquivalentImplementations(func_name, &other_impl);
+ const std::unordered_set<string> actual(other_impl.begin(), other_impl.end());
+ const std::unordered_set<string> expected(expected_other.begin(),
+ expected_other.end());
+ return actual == expected;
+}
+
+bool CheckGetBestImpl(const FunctionLibraryApiInfo& lib_api_info,
+ const string& function_name, const string& device,
+ const string& expected_function_name) {
+ string best_function_name;
+ lib_api_info.GetBestImplementation(function_name, device,
+ &best_function_name);
+
+ return best_function_name == expected_function_name;
+}
+
+string GetInterfaceName(const FunctionLibraryApiInfo& lib_api_info,
+ const string& func_name) {
+ auto* info = lib_api_info.GetApiInfo(func_name);
+ CHECK_NOTNULL(info);
+ return info->interface_name();
+}
+
+string GetPreferredDevice(const FunctionLibraryApiInfo& lib_api_info,
+ const string& func_name) {
+ auto* info = lib_api_info.GetApiInfo(func_name);
+ CHECK_NOTNULL(info);
+ return info->preferred_device();
+}
+
+TEST(FunctionApiInfoTest, ParseTags) {
+ FunctionDefLibrary func_lib;
+ PopulateSampleLibrary(/* mismatch_args */ false, &func_lib);
+ FunctionLibraryApiInfo lib_api_info;
+ TF_ASSERT_OK(lib_api_info.Init(func_lib));
+ EXPECT_TRUE(CheckEquivImpl(lib_api_info, "DoStuffCpu", {"DoStuffGpu"}));
+ EXPECT_TRUE(CheckEquivImpl(lib_api_info, "DoStuffGpu", {"DoStuffCpu"}));
+ EXPECT_TRUE(CheckEquivImpl(lib_api_info, "Undefined", {}));
+ EXPECT_TRUE(CheckEquivImpl(lib_api_info, "OneOff", {}));
+ EXPECT_TRUE(CheckEquivImpl(lib_api_info, "AnotherOneOff", {}));
+ EXPECT_TRUE(CheckEquivImpl(lib_api_info, "DoThings", {}));
+
+ EXPECT_EQ("DoStuff", GetInterfaceName(lib_api_info, "DoStuffCpu"));
+ EXPECT_EQ("DoStuff", GetInterfaceName(lib_api_info, "DoStuffGpu"));
+ EXPECT_EQ("DoThings", GetInterfaceName(lib_api_info, "DoThings"));
+
+ EXPECT_EQ("CPU", GetPreferredDevice(lib_api_info, "DoStuffCpu"));
+ EXPECT_EQ("GPU", GetPreferredDevice(lib_api_info, "DoStuffGpu"));
+ EXPECT_EQ("", GetPreferredDevice(lib_api_info, "DoThings"));
+
+ EXPECT_TRUE(
+ CheckGetBestImpl(lib_api_info, "DoStuffCpu", "CPU", "DoStuffCpu"));
+ EXPECT_TRUE(
+ CheckGetBestImpl(lib_api_info, "DoStuffCpu", "GPU", "DoStuffGpu"));
+ EXPECT_TRUE(
+ CheckGetBestImpl(lib_api_info, "DoStuffGpu", "CPU", "DoStuffCpu"));
+ EXPECT_TRUE(
+ CheckGetBestImpl(lib_api_info, "DoStuffGpu", "GPU", "DoStuffGpu"));
+
+ EXPECT_TRUE(CheckGetBestImpl(lib_api_info, "DoThings", "GPU", "DoThings"));
+  // No TPU implementation is available, so choose the first available one,
+  // which is the CPU version.
+ EXPECT_TRUE(
+ CheckGetBestImpl(lib_api_info, "DoStuffGpu", "TPU", "DoStuffCpu"));
+}
+
+TEST(FunctionApiInfoTest, MismatchedArguments) {
+ FunctionDefLibrary func_lib;
+ PopulateSampleLibrary(/* mismatch_args */ true, &func_lib);
+ FunctionLibraryApiInfo lib_api_info;
+ const Status ret = lib_api_info.Init(func_lib);
+ EXPECT_FALSE(ret.ok());
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index a5fd33d28b..7ed4a67333 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -72,6 +72,16 @@ bool IsRunOnceOptimizer(const string& name) {
name == "loop_optimizer";
}
+// Checks whether the GraphDef contains nodes that indicate TPU execution.
+bool IsTPUGraphDef(const GraphDef& def) {
+ for (auto node : def.node()) {
+ if (node.op() == "TPUCompile" || node.op() == "TPUPartitionedCall") {
+ return true;
+ }
+ }
+ return false;
+}
+
} // namespace
#define MK_OPT(NAME, VALUE) \
@@ -331,10 +341,25 @@ Status MetaOptimizer::RunOptimizer(
Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
+ LOG(INFO) << "Starting optimization for grappler item: " << item.id;
optimization_results_.clear();
// 1. Optimize main graph
TF_RETURN_IF_ERROR(OptimizeGraph(cluster, item, optimized_graph));
+ VLOG(1) << "Optimized main graph.";
+
+ // Skip optimizing functions if this is a TPU graph. Currently, Grappler
+  // passes do not handle TPU functions correctly in a variety of ways (note
+  // that due to the pre-placement TPU graph rewriting passes, the TPU-related
+  // ops are encapsulated away into functions). For example, TPU graphs contain
+  // a TPUReplicateMetadata node that carries relevant TPU metadata, which
+  // Grappler passes could prune away. Grappler passes could also cause issues
+ // around shape inference. Since the desired and existing behavior is to not
+ // optimize TPU functions with Grappler, this check preserves that.
+ if (IsTPUGraphDef(*optimized_graph)) {
+ VLOG(2) << "Skipping optimizing funcs for TPU graphs";
+ return Status::OK();
+ }
// 2. Optimize function library
FunctionLibraryDefinition flib(OpRegistry::Global(),
@@ -398,7 +423,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
}
}
- VLOG(3) << "Optimized " << optimized_funcs.size()
+ VLOG(1) << "Optimized " << optimized_funcs.size()
<< " functions: " << str_util::Join(optimized_funcs, ", ");
return Status::OK();
diff --git a/tensorflow/core/grappler/utils/scc.h b/tensorflow/core/grappler/utils/scc.h
index 4fb7aab647..ceb9f5dbf2 100644
--- a/tensorflow/core/grappler/utils/scc.h
+++ b/tensorflow/core/grappler/utils/scc.h
@@ -24,15 +24,16 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-// Compute modified strongly connected components:
+// Computes modified strongly connected components:
// All nodes that are not part of a loop are assigned the special -1 id
// All nodes that are part of at least one loop are assigned a positive
// component id: if 2 nodes v and w are reachable from one another (i.e. if they
// belong to the same scc), they'll be assigned the same id, otherwise they'll
-// be assigned distinct ids. Returns the number of distinct ids.
+// be assigned distinct ids. *num_components is set to the number of distinct
+// ids.
void StronglyConnectedComponents(
const GraphDef& graph, std::unordered_map<const NodeDef*, int>* components,
- int* num_ids);
+ int* num_components);
// Returns the number of individual loops present in the graph, and populate the
// 'loops' argument with the collection of loops (denoted by their loop ids) a
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 972fb9efa9..94d3ab4467 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -4504,6 +4504,25 @@ tf_kernel_library(
deps = STRING_DEPS,
)
+tf_cc_test(
+ name = "substr_op_test",
+ size = "small",
+ srcs = ["substr_op_test.cc"],
+ deps = [
+ ":substr_op",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/kernels:ops_testutil",
+ "//tensorflow/core/kernels:ops_util",
+ ],
+)
+
tf_kernel_library(
name = "as_string_op",
prefix = "as_string_op",
@@ -5184,6 +5203,7 @@ filegroup(
"fifo_queue.cc",
"fifo_queue_op.cc",
"fused_batch_norm_op.cc",
+ "listdiff_op.cc",
"population_count_op.cc",
"population_count_op.h",
"winograd_transform.h",
diff --git a/tensorflow/core/kernels/boosted_trees/BUILD b/tensorflow/core/kernels/boosted_trees/BUILD
index 4910021c63..4e8bfa02fc 100644
--- a/tensorflow/core/kernels/boosted_trees/BUILD
+++ b/tensorflow/core/kernels/boosted_trees/BUILD
@@ -15,7 +15,9 @@ load(
tf_proto_library(
name = "boosted_trees_proto",
- srcs = ["boosted_trees.proto"],
+ srcs = [
+ "boosted_trees.proto",
+ ],
cc_api_version = 2,
visibility = ["//visibility:public"],
)
@@ -87,9 +89,21 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "quantile_ops",
+ srcs = ["quantile_ops.cc"],
+ deps = [
+ "//tensorflow/core:boosted_trees_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/kernels/boosted_trees/quantiles:weighted_quantiles",
+ ],
+)
+
+tf_kernel_library(
name = "boosted_trees_ops",
deps = [
":prediction_ops",
+ ":quantile_ops",
":resource_ops",
":stats_ops",
":training_ops",
diff --git a/tensorflow/core/kernels/boosted_trees/quantile_ops.cc b/tensorflow/core/kernels/boosted_trees/quantile_ops.cc
new file mode 100644
index 0000000000..d1840941c1
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/quantile_ops.cc
@@ -0,0 +1,453 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include <algorithm>
+#include <iterator>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h"
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h"
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+
+const char* const kExampleWeightsName = "example_weights";
+const char* const kMaxElementsName = "max_elements";
+const char* const kGenerateQuantiles = "generate_quantiles";
+const char* const kNumBucketsName = "num_buckets";
+const char* const kEpsilonName = "epsilon";
+const char* const kBucketBoundariesName = "bucket_boundaries";
+const char* const kBucketsName = "buckets";
+const char* const kSummariesName = "summaries";
+const char* const kNumStreamsName = "num_streams";
+const char* const kNumFeaturesName = "num_features";
+const char* const kFloatFeaturesName = "float_values";
+const char* const kResourceHandleName = "quantile_stream_resource_handle";
+
+using QuantileStreamResource = BoostedTreesQuantileStreamResource;
+using QuantileStream =
+ boosted_trees::quantiles::WeightedQuantilesStream<float, float>;
+using QuantileSummary =
+ boosted_trees::quantiles::WeightedQuantilesSummary<float, float>;
+using QuantileSummaryEntry =
+ boosted_trees::quantiles::WeightedQuantilesSummary<float,
+ float>::SummaryEntry;
+
+// Generates bucket boundaries on a finalized QuantileStream.
+std::vector<float> GenerateBoundaries(const QuantileStream& stream,
+ const int64 num_boundaries) {
+ std::vector<float> boundaries = stream.GenerateBoundaries(num_boundaries);
+
+ // Uniquify elements as we may get dupes.
+ auto end_it = std::unique(boundaries.begin(), boundaries.end());
+ boundaries.resize(std::distance(boundaries.begin(), end_it));
+ return boundaries;
+}
+
+// Generates quantiles on a finalized QuantileStream.
+std::vector<float> GenerateQuantiles(const QuantileStream& stream,
+ const int64 num_quantiles) {
+  // Do not de-dup boundaries. Exactly num_quantiles boundary values
+  // will be returned.
+ std::vector<float> boundaries = stream.GenerateQuantiles(num_quantiles - 1);
+ CHECK_EQ(boundaries.size(), num_quantiles);
+ return boundaries;
+}
+
+std::vector<float> GetBuckets(const int32 feature,
+ const OpInputList& buckets_list) {
+ const auto& buckets = buckets_list[feature].flat<float>();
+ std::vector<float> buckets_vector(buckets.data(),
+ buckets.data() + buckets.size());
+ return buckets_vector;
+}
+
+REGISTER_RESOURCE_HANDLE_KERNEL(BoostedTreesQuantileStreamResource);
+
+REGISTER_KERNEL_BUILDER(
+ Name("IsBoostedTreesQuantileStreamResourceInitialized").Device(DEVICE_CPU),
+ IsResourceInitialized<BoostedTreesQuantileStreamResource>);
+
+class BoostedTreesCreateQuantileStreamResourceOp : public OpKernel {
+ public:
+ explicit BoostedTreesCreateQuantileStreamResourceOp(
+ OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr(kMaxElementsName, &max_elements_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+    // Only create one if one does not already exist. Report status for all
+    // other exceptions. If one already exists, it unrefs the new one.
+    // An epsilon value of zero could cause performance issues and is
+    // therefore disallowed.
+ const Tensor* epsilon_t;
+ OP_REQUIRES_OK(context, context->input(kEpsilonName, &epsilon_t));
+ float epsilon = epsilon_t->scalar<float>()();
+ OP_REQUIRES(
+ context, epsilon > 0,
+ errors::InvalidArgument("An epsilon value of zero is not allowed."));
+
+ const Tensor* num_streams_t;
+ OP_REQUIRES_OK(context, context->input(kNumStreamsName, &num_streams_t));
+ int64 num_streams = num_streams_t->scalar<int64>()();
+
+ auto result =
+ new QuantileStreamResource(epsilon, max_elements_, num_streams);
+ auto status = CreateResource(context, HandleFromInput(context, 0), result);
+ if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
+ OP_REQUIRES(context, false, status);
+ }
+ }
+
+ private:
+ // An upper bound on the number of entries that the summaries might have
+ // for a feature.
+ int64 max_elements_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesCreateQuantileStreamResource").Device(DEVICE_CPU),
+ BoostedTreesCreateQuantileStreamResourceOp);
+
+class BoostedTreesMakeQuantileSummariesOp : public OpKernel {
+ public:
+ explicit BoostedTreesMakeQuantileSummariesOp(
+ OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr(kNumFeaturesName, &num_features_));
+ }
+
+ void Compute(OpKernelContext* const context) override {
+    // Read the float features list.
+ OpInputList float_features_list;
+ OP_REQUIRES_OK(
+ context, context->input_list(kFloatFeaturesName, &float_features_list));
+
+ // Parse example weights and get batch size.
+ const Tensor* example_weights_t;
+ OP_REQUIRES_OK(context,
+ context->input(kExampleWeightsName, &example_weights_t));
+ auto example_weights = example_weights_t->flat<float>();
+ const int64 batch_size = example_weights.size();
+ const Tensor* epsilon_t;
+ OP_REQUIRES_OK(context, context->input(kEpsilonName, &epsilon_t));
+ float epsilon = epsilon_t->scalar<float>()();
+
+ OpOutputList summaries_output_list;
+ OP_REQUIRES_OK(
+ context, context->output_list(kSummariesName, &summaries_output_list));
+
+ auto do_quantile_summary_gen = [&](const int64 begin, const int64 end) {
+ // Iterating features.
+ for (int64 index = begin; index < end; index++) {
+ const auto feature_values = float_features_list[index].flat<float>();
+ QuantileStream stream(epsilon, batch_size + 1);
+ // Run quantile summary generation.
+ for (int64 j = 0; j < batch_size; j++) {
+ stream.PushEntry(feature_values(j), example_weights(j));
+ }
+ stream.Finalize();
+ const auto summary_entry_list = stream.GetFinalSummary().GetEntryList();
+ Tensor* output_t;
+ OP_REQUIRES_OK(
+ context,
+ summaries_output_list.allocate(
+ index,
+ TensorShape({static_cast<int64>(summary_entry_list.size()), 4}),
+ &output_t));
+ auto output = output_t->matrix<float>();
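+        // Each output row is one summary entry laid out as
+        // (value, weight, min_rank, max_rank).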
+ for (auto row = 0; row < summary_entry_list.size(); row++) {
+ const auto& entry = summary_entry_list[row];
+ output(row, 0) = entry.value;
+ output(row, 1) = entry.weight;
+ output(row, 2) = entry.min_rank;
+ output(row, 3) = entry.max_rank;
+ }
+ }
+ };
+ // TODO(tanzheny): comment on the magic number.
+ const int64 kCostPerUnit = 500 * batch_size;
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *context->device()->tensorflow_cpu_worker_threads();
+ Shard(worker_threads.num_threads, worker_threads.workers, num_features_,
+ kCostPerUnit, do_quantile_summary_gen);
+ }
+
+ private:
+ int64 num_features_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesMakeQuantileSummaries").Device(DEVICE_CPU),
+ BoostedTreesMakeQuantileSummariesOp);
+
+class BoostedTreesQuantileStreamResourceAddSummariesOp : public OpKernel {
+ public:
+ explicit BoostedTreesQuantileStreamResourceAddSummariesOp(
+ OpKernelConstruction* const context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ ResourceHandle handle;
+ OP_REQUIRES_OK(context,
+ HandleFromInput(context, kResourceHandleName, &handle));
+ QuantileStreamResource* stream_resource;
+ // Create a reference to the underlying resource using the handle.
+ OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
+ // Remove the reference at the end of this scope.
+ mutex_lock l(*stream_resource->mutex());
+ core::ScopedUnref unref_me(stream_resource);
+
+ OpInputList summaries_list;
+ OP_REQUIRES_OK(context,
+ context->input_list(kSummariesName, &summaries_list));
+ int32 num_streams = stream_resource->num_streams();
+ CHECK_EQ(static_cast<int>(num_streams), summaries_list.size());
+
+ auto do_quantile_add_summary = [&](const int64 begin, const int64 end) {
+ // Iterating all features.
+ for (int64 feature_idx = begin; feature_idx < end; ++feature_idx) {
+ const Tensor& summaries = summaries_list[feature_idx];
+ const auto summary_values = summaries.matrix<float>();
+ const auto& tensor_shape = summaries.shape();
+ const int64 entries_size = tensor_shape.dim_size(0);
+ CHECK_EQ(tensor_shape.dim_size(1), 4);
+ std::vector<QuantileSummaryEntry> summary_entries;
+ summary_entries.reserve(entries_size);
+ for (int64 i = 0; i < entries_size; i++) {
+ float value = summary_values(i, 0);
+ float weight = summary_values(i, 1);
+ float min_rank = summary_values(i, 2);
+ float max_rank = summary_values(i, 3);
+ QuantileSummaryEntry entry(value, weight, min_rank, max_rank);
+ summary_entries.push_back(entry);
+ }
+ stream_resource->stream(feature_idx)->PushSummary(summary_entries);
+ }
+ };
+
+ // TODO(tanzheny): comment on the magic number.
+ const int64 kCostPerUnit = 500 * num_streams;
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *context->device()->tensorflow_cpu_worker_threads();
+ Shard(worker_threads.num_threads, worker_threads.workers, num_streams,
+ kCostPerUnit, do_quantile_add_summary);
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesQuantileStreamResourceAddSummaries").Device(DEVICE_CPU),
+ BoostedTreesQuantileStreamResourceAddSummariesOp);
+
+class BoostedTreesQuantileStreamResourceFlushOp : public OpKernel {
+ public:
+ explicit BoostedTreesQuantileStreamResourceFlushOp(
+ OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context,
+ context->GetAttr(kGenerateQuantiles, &generate_quantiles_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ ResourceHandle handle;
+ OP_REQUIRES_OK(context,
+ HandleFromInput(context, kResourceHandleName, &handle));
+ QuantileStreamResource* stream_resource;
+ // Create a reference to the underlying resource using the handle.
+ OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
+ // Remove the reference at the end of this scope.
+ mutex_lock l(*stream_resource->mutex());
+ core::ScopedUnref unref_me(stream_resource);
+
+ const Tensor* num_buckets_t;
+ OP_REQUIRES_OK(context, context->input(kNumBucketsName, &num_buckets_t));
+ const int64 num_buckets = num_buckets_t->scalar<int64>()();
+ const int64 num_streams = stream_resource->num_streams();
+
+ auto do_quantile_flush = [&](const int64 begin, const int64 end) {
+ // Iterating over all streams.
+ for (int64 stream_idx = begin; stream_idx < end; ++stream_idx) {
+ QuantileStream* stream = stream_resource->stream(stream_idx);
+ stream->Finalize();
+ stream_resource->set_boundaries(
+ generate_quantiles_ ? GenerateQuantiles(*stream, num_buckets)
+ : GenerateBoundaries(*stream, num_buckets),
+ stream_idx);
+ }
+ };
+
+ // TODO(tanzheny): comment on the magic number.
+ const int64 kCostPerUnit = 500 * num_streams;
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *context->device()->tensorflow_cpu_worker_threads();
+ Shard(worker_threads.num_threads, worker_threads.workers, num_streams,
+ kCostPerUnit, do_quantile_flush);
+
+ stream_resource->set_buckets_ready(true);
+ }
+
+ private:
+ bool generate_quantiles_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesQuantileStreamResourceFlush").Device(DEVICE_CPU),
+ BoostedTreesQuantileStreamResourceFlushOp);
+
+class BoostedTreesQuantileStreamResourceGetBucketBoundariesOp
+ : public OpKernel {
+ public:
+ explicit BoostedTreesQuantileStreamResourceGetBucketBoundariesOp(
+ OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr(kNumFeaturesName, &num_features_));
+ }
+
+ void Compute(OpKernelContext* const context) override {
+ ResourceHandle handle;
+ OP_REQUIRES_OK(context,
+ HandleFromInput(context, kResourceHandleName, &handle));
+ QuantileStreamResource* stream_resource;
+ // Create a reference to the underlying resource using the handle.
+ OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
+ // Remove the reference at the end of this scope.
+ mutex_lock l(*stream_resource->mutex());
+ core::ScopedUnref unref_me(stream_resource);
+
+ const int64 num_streams = stream_resource->num_streams();
+ CHECK_EQ(num_features_, num_streams);
+ OpOutputList bucket_boundaries_list;
+ OP_REQUIRES_OK(context, context->output_list(kBucketBoundariesName,
+ &bucket_boundaries_list));
+
+ auto do_quantile_get_buckets = [&](const int64 begin, const int64 end) {
+ // Iterating over all streams.
+ for (int64 stream_idx = begin; stream_idx < end; stream_idx++) {
+ const auto& boundaries = stream_resource->boundaries(stream_idx);
+ Tensor* bucket_boundaries_t = nullptr;
+ OP_REQUIRES_OK(context,
+ bucket_boundaries_list.allocate(
+ stream_idx, {static_cast<int64>(boundaries.size())},
+ &bucket_boundaries_t));
+ auto* quantiles_flat = bucket_boundaries_t->flat<float>().data();
+ memcpy(quantiles_flat, boundaries.data(),
+ sizeof(float) * boundaries.size());
+ }
+ };
+
+ // TODO(tanzheny): comment on the magic number.
+ const int64 kCostPerUnit = 500 * num_streams;
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *context->device()->tensorflow_cpu_worker_threads();
+ Shard(worker_threads.num_threads, worker_threads.workers, num_streams,
+ kCostPerUnit, do_quantile_get_buckets);
+ }
+
+ private:
+ int64 num_features_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesQuantileStreamResourceGetBucketBoundaries")
+ .Device(DEVICE_CPU),
+ BoostedTreesQuantileStreamResourceGetBucketBoundariesOp);
+
+// Given the calculated quantile thresholds and the input data, this operation
+// converts the input features into buckets (categorical values), depending on
+// which quantile they fall into.
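+// For example, with bucket boundaries {0.5, 1.5, 2.5}, a value of 0.2 falls in
+// bucket 0, 1.0 falls in bucket 1, and 3.0 (beyond the last boundary) falls in
+// bucket 2.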
+class BoostedTreesBucketizeOp : public OpKernel {
+ public:
+ explicit BoostedTreesBucketizeOp(OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr(kNumFeaturesName, &num_features_));
+ }
+
+ void Compute(OpKernelContext* const context) override {
+    // Read the float features list.
+ OpInputList float_features_list;
+ OP_REQUIRES_OK(
+ context, context->input_list(kFloatFeaturesName, &float_features_list));
+ OpInputList bucket_boundaries_list;
+ OP_REQUIRES_OK(context, context->input_list(kBucketBoundariesName,
+ &bucket_boundaries_list));
+ OP_REQUIRES(context,
+ tensorflow::TensorShapeUtils::IsVector(
+ bucket_boundaries_list[0].shape()),
+ errors::InvalidArgument(
+ strings::Printf("Buckets should be flat vectors.")));
+ OpOutputList buckets_list;
+ OP_REQUIRES_OK(context, context->output_list(kBucketsName, &buckets_list));
+
+ auto do_quantile_get_quantiles = [&](const int64 begin, const int64 end) {
+ // Iterating over all resources
+ for (int64 feature_idx = begin; feature_idx < end; feature_idx++) {
+ const Tensor& values_tensor = float_features_list[feature_idx];
+ const int64 num_values = values_tensor.dim_size(0);
+
+ Tensor* output_t = nullptr;
+ OP_REQUIRES_OK(
+ context, buckets_list.allocate(
+ feature_idx, TensorShape({num_values, 1}), &output_t));
+ auto output = output_t->matrix<int32>();
+
+ const std::vector<float>& bucket_boundaries_vector =
+ GetBuckets(feature_idx, bucket_boundaries_list);
+ CHECK(!bucket_boundaries_vector.empty())
+ << "Got empty buckets for feature " << feature_idx;
+ auto flat_values = values_tensor.flat<float>();
+ for (int64 instance = 0; instance < num_values; instance++) {
+ const float value = flat_values(instance);
+ auto bucket_iter =
+ std::lower_bound(bucket_boundaries_vector.begin(),
+ bucket_boundaries_vector.end(), value);
+ if (bucket_iter == bucket_boundaries_vector.end()) {
+ --bucket_iter;
+ }
+ const int32 bucket = static_cast<int32>(
+ bucket_iter - bucket_boundaries_vector.begin());
+ // Bucket id.
+ output(instance, 0) = bucket;
+ }
+ }
+ };
+
+ // TODO(tanzheny): comment on the magic number.
+ const int64 kCostPerUnit = 500 * num_features_;
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *context->device()->tensorflow_cpu_worker_threads();
+ Shard(worker_threads.num_threads, worker_threads.workers, num_features_,
+ kCostPerUnit, do_quantile_get_quantiles);
+ }
+
+ private:
+ int64 num_features_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("BoostedTreesBucketize").Device(DEVICE_CPU),
+ BoostedTreesBucketizeOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/BUILD b/tensorflow/core/kernels/boosted_trees/quantiles/BUILD
index 3163c63949..12d9473776 100644
--- a/tensorflow/core/kernels/boosted_trees/quantiles/BUILD
+++ b/tensorflow/core/kernels/boosted_trees/quantiles/BUILD
@@ -1,5 +1,5 @@
# Description:
-# This directory contains common utilities used in boosted_trees.
+# This directory contains common quantile utilities used in boosted_trees.
package(
default_visibility = ["//tensorflow:internal"],
)
@@ -16,6 +16,7 @@ cc_library(
name = "weighted_quantiles",
srcs = [],
hdrs = [
+ "quantile_stream_resource.h",
"weighted_quantiles_buffer.h",
"weighted_quantiles_stream.h",
"weighted_quantiles_summary.h",
@@ -23,6 +24,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:framework_headers_lib",
+ "//third_party/eigen3",
],
)
diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h b/tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h
new file mode 100644
index 0000000000..1c31724272
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h
@@ -0,0 +1,96 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_QUANTILE_STREAM_RESOURCE_H_
+#define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_QUANTILE_STREAM_RESOURCE_H_
+
+#include <vector>
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+using QuantileStream =
+ boosted_trees::quantiles::WeightedQuantilesStream<float, float>;
+
+// Quantile Stream Resource for a list of streams sharing the same number of
+// quantiles, maximum elements, and epsilon.
+class BoostedTreesQuantileStreamResource : public ResourceBase {
+ public:
+ BoostedTreesQuantileStreamResource(const float epsilon,
+ const int64 max_elements,
+ const int64 num_streams)
+ : are_buckets_ready_(false),
+ epsilon_(epsilon),
+ num_streams_(num_streams),
+ max_elements_(max_elements) {
+ streams_.reserve(num_streams_);
+ boundaries_.reserve(num_streams_);
+ for (int64 idx = 0; idx < num_streams; ++idx) {
+ streams_.push_back(QuantileStream(epsilon, max_elements));
+ boundaries_.push_back(std::vector<float>());
+ }
+ }
+
+ string DebugString() override { return "QuantileStreamResource"; }
+
+ tensorflow::mutex* mutex() { return &mu_; }
+
+ QuantileStream* stream(const int64 index) { return &streams_[index]; }
+
+ const std::vector<float>& boundaries(const int64 index) {
+ return boundaries_[index];
+ }
+
+ void set_boundaries(const std::vector<float>& boundaries, const int64 index) {
+ boundaries_[index] = boundaries;
+ }
+
+ float epsilon() const { return epsilon_; }
+ int64 num_streams() const { return num_streams_; }
+
+ bool are_buckets_ready() const { return are_buckets_ready_; }
+ void set_buckets_ready(const bool are_buckets_ready) {
+ are_buckets_ready_ = are_buckets_ready;
+ }
+
+ private:
+ ~BoostedTreesQuantileStreamResource() override {}
+
+ // Mutex for the whole resource.
+ tensorflow::mutex mu_;
+
+ // Quantile streams.
+ std::vector<QuantileStream> streams_;
+
+ // Stores the boundaries. Same size as streams_.
+ std::vector<std::vector<float>> boundaries_;
+
+  // Whether boundaries have been created. Boundaries are empty until
+  // set_boundaries is called.
+ bool are_buckets_ready_;
+
+ const float epsilon_;
+ const int64 num_streams_;
+ // An upper-bound for the number of elements.
+ int64 max_elements_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(BoostedTreesQuantileStreamResource);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_QUANTILE_STREAM_RESOURCE_H_
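A minimal sketch of driving the resource directly, with illustrative constructor
arguments; the kernels in quantile_ops.cc instead create it via CreateResource
and look it up through a resource handle:

  auto* resource = new BoostedTreesQuantileStreamResource(
      /*epsilon=*/0.01f, /*max_elements=*/1 << 16, /*num_streams=*/1);
  {
    mutex_lock l(*resource->mutex());
    QuantileStream* stream = resource->stream(0);
    stream->PushEntry(/*value=*/1.0f, /*weight=*/1.0f);
    stream->Finalize();
    resource->set_boundaries(
        stream->GenerateBoundaries(/*num_boundaries=*/10), /*index=*/0);
    resource->set_buckets_ready(true);
  }
  resource->Unref();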
diff --git a/tensorflow/core/kernels/conv_3d.h b/tensorflow/core/kernels/conv_3d.h
index 02e3655ad1..b819c6f910 100644
--- a/tensorflow/core/kernels/conv_3d.h
+++ b/tensorflow/core/kernels/conv_3d.h
@@ -19,6 +19,7 @@ limitations under the License.
#define TENSORFLOW_CORE_KERNELS_CONV_3D_H_
#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h"
#include "tensorflow/core/kernels/eigen_cuboid_convolution.h"
namespace tensorflow {
@@ -28,6 +29,14 @@ namespace functor {
template <typename Device, typename T>
struct CuboidConvolution;
+// Backward input pass for the cuboid convolution.
+template <typename Device, typename T>
+struct CuboidConvolutionBackwardInput;
+
+// Backward filter pass for the cuboid convolution.
+template <typename Device, typename T>
+struct CuboidConvolutionBackwardFilter;
+
typedef Eigen::ThreadPoolDevice CPUDevice;
template <typename T>
@@ -42,6 +51,40 @@ struct CuboidConvolution<CPUDevice, T> {
}
};
+template <typename T>
+struct CuboidConvolutionBackwardInput<CPUDevice, T> {
+ void operator()(const CPUDevice& d,
+ typename TTypes<T, 5>::Tensor input_backward,
+ typename TTypes<T, 5>::ConstTensor filter,
+ typename TTypes<T, 5>::ConstTensor output_backward,
+ int stride_planes, int stride_rows, int stride_cols) {
+ // Need to swap the order of plane/row/col strides when calling Eigen.
+ input_backward.device(d) = Eigen::CuboidConvolutionBackwardInput(
+ filter, output_backward,
+ input_backward.dimension(3), // input_planes
+ input_backward.dimension(2), // input_rows
+ input_backward.dimension(1), // input_cols
+ stride_cols, stride_rows, stride_planes);
+ }
+};
+
+template <typename T>
+struct CuboidConvolutionBackwardFilter<CPUDevice, T> {
+ void operator()(const CPUDevice& d,
+ typename TTypes<T, 5>::Tensor filter_backward,
+ typename TTypes<T, 5>::ConstTensor input,
+ typename TTypes<T, 5>::ConstTensor output_backward,
+ int stride_planes, int stride_rows, int stride_cols) {
+ // Need to swap the order of plane/row/col strides when calling Eigen.
+ filter_backward.device(d) = Eigen::CuboidConvolutionBackwardKernel(
+ input, output_backward,
+ filter_backward.dimension(2), // kernel_planes
+ filter_backward.dimension(1), // kernel_rows
+ filter_backward.dimension(0), // kernel_cols
+ stride_cols, stride_rows, stride_planes);
+ }
+};
+
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/conv_grad_ops.cc b/tensorflow/core/kernels/conv_grad_ops.cc
index fc0a2f123f..507720c998 100644
--- a/tensorflow/core/kernels/conv_grad_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_ops.cc
@@ -41,6 +41,17 @@ limitations under the License.
namespace tensorflow {
+// Compute padding for the given spatial dimension.
+int ConvBackpropDimensions::SpatialPadding(const Padding& padding,
+ int dim) const {
+ return (padding == VALID)
+ ? 0
+ : std::max<int>(
+ 0, static_cast<int>((output_size(dim) - 1) * stride(dim) +
+ (filter_size(dim) - 1) * dilation(dim) +
+ 1 - input_size(dim)));
+}
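+// For example, with SAME padding, output_size(dim) = 3, stride(dim) = 2,
+// filter_size(dim) = 3, dilation(dim) = 1 and input_size(dim) = 5, this
+// returns max(0, (3 - 1) * 2 + (3 - 1) * 1 + 1 - 5) = 2 elements of total
+// padding for that dimension.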
+
// The V2 version computes windowed output size with arbitrary dilation_rate,
// while the original version only handles the cases where dilation_rates equal
// to 1.
diff --git a/tensorflow/core/kernels/conv_grad_ops.h b/tensorflow/core/kernels/conv_grad_ops.h
index 535586d53a..9551959463 100644
--- a/tensorflow/core/kernels/conv_grad_ops.h
+++ b/tensorflow/core/kernels/conv_grad_ops.h
@@ -234,6 +234,16 @@ struct ConvBackpropDimensions {
// Input and output feature depth.
int64 in_depth, out_depth;
+
+ // Convenience access methods for spatial dimensions properties.
+ int64 input_size(int dim) const { return spatial_dims[dim].input_size; }
+ int64 filter_size(int dim) const { return spatial_dims[dim].filter_size; }
+ int64 output_size(int dim) const { return spatial_dims[dim].output_size; }
+ int64 stride(int dim) const { return spatial_dims[dim].stride; }
+ int64 dilation(int dim) const { return spatial_dims[dim].dilation; }
+
+ // Compute padding for the given spatial dimension.
+ int SpatialPadding(const Padding& padding, int dim) const;
};
// Common code between implementations of Conv?DBackpropInput and
diff --git a/tensorflow/core/kernels/conv_grad_ops_3d.cc b/tensorflow/core/kernels/conv_grad_ops_3d.cc
index 15f1bf9aba..d26b86c712 100644
--- a/tensorflow/core/kernels/conv_grad_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_grad_ops_3d.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_2d.h"
+#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -32,111 +33,130 @@ limitations under the License.
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
+#include "tensorflow/core/util/work_sharder.h"
#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
using stream_executor::dnn::DimIndex;
#endif
+namespace {
+
+// TODO(ezhulenev): Split this file into conv_grad_filter_ops_3d.cc and
+// conv_grad_input_ops_3d.cc.
+
+// TODO(ezhulenev): Generalize Col2im and Im2col for 2-d and 3-d kernels.
+
+// "Depth" is already used for the channel dimension, so for the third spatial
+// dimension in this file we use "plane", although in NDHWC layout it's
+// indicated with a "D".
+
+// Returns in 'im_data' (assumed to be zero-initialized) the image patch in
+// storage order (planes, height, width, depth), constructed from patches in
+// 'col_data', which is required to be in storage order (out_planes *
+// out_height * out_width, filter_planes, filter_height, filter_width,
+// in_depth).
+//
+// Based on 2-dimensional implementation written by Yangqing Jia (jiayq).
+template <typename T>
+void Col2im(const T* col_data, const int depth, const int planes,
+ const int height, const int width, const int filter_p,
+ const int filter_h, const int filter_w, const int pad_pt,
+ const int pad_t, const int pad_l, const int pad_pb, const int pad_b,
+ const int pad_r, const int stride_p, const int stride_h,
+ const int stride_w, T* im_data) {
+ const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1;
+ const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
+ const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
+ int p_pad = -pad_pt;
+ for (int p = 0; p < planes_col; ++p) {
+ int h_pad = -pad_t;
+ for (int h = 0; h < height_col; ++h) {
+ int w_pad = -pad_l;
+ for (int w = 0; w < width_col; ++w) {
+ T* im_patch_data =
+ im_data + (p_pad * height * width + h_pad * width + w_pad) * depth;
+ for (int ip = p_pad; ip < p_pad + filter_p; ++ip) {
+ for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
+ for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
+ if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 &&
+ iw < width) {
+ for (int i = 0; i < depth; ++i) {
+ im_patch_data[i] += col_data[i];
+ }
+ }
+ im_patch_data += depth;
+ col_data += depth;
+ }
+ // Jump over remaining number of depth.
+ im_patch_data += depth * (width - filter_w);
+ }
+ // Jump over remaining number of (depth * width).
+ im_patch_data += (depth * width) * (height - filter_h);
+ }
+ w_pad += stride_w;
+ }
+ h_pad += stride_h;
+ }
+ p_pad += stride_p;
+ }
+}
+
+// Returns in 'col_data' the image patches in storage order (planes, height,
+// width, depth) extracted from the image at 'input_data', which is required to
+// be in storage order (batch, planes, height, width, depth).
+//
+// Based on 2-dimensional implementation written by Yangqing Jia (jiayq).
+template <typename T>
+void Im2col(const T* input_data, const int depth, const int planes,
+ const int height, const int width, const int filter_p,
+ const int filter_h, const int filter_w, const int pad_pt,
+ const int pad_t, const int pad_l, const int pad_pb, const int pad_b,
+ const int pad_r, const int stride_p, const int stride_h,
+ const int stride_w, T* col_data) {
+ const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1;
+ const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
+ const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
+
+ int p_pad = -pad_pt;
+ for (int p = 0; p < planes_col; ++p) {
+ int h_pad = -pad_t;
+ for (int h = 0; h < height_col; ++h) {
+ int w_pad = -pad_l;
+ for (int w = 0; w < width_col; ++w) {
+ for (int ip = p_pad; ip < p_pad + filter_p; ++ip) {
+ for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
+ for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
+ if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 &&
+ iw < width) {
+ memcpy(col_data,
+ input_data +
+ (ip * height * width + ih * width + iw) * depth,
+ sizeof(T) * depth);
+ } else {
+ // This should be simply padded with zero.
+ memset(col_data, 0, sizeof(T) * depth);
+ }
+ col_data += depth;
+ }
+ }
+ }
+ w_pad += stride_w;
+ }
+ h_pad += stride_h;
+ }
+ p_pad += stride_p;
+ }
+}
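+// Note: in both Col2im and Im2col above, 'col_data' holds
+// planes_col * height_col * width_col * filter_p * filter_h * filter_w * depth
+// values, where planes_col, height_col and width_col are the output spatial
+// sizes computed at the top of each function.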
+
+} // namespace
+
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-// TODO(mjanusz): Get rid of the macro and return shapes directly.
-#define EXTRACT_AND_VERIFY_DIMENSIONS(label) \
- const Tensor& out_backprop = context->input(2); \
- OP_REQUIRES( \
- context, input_shape.dims() == 5, \
- errors::InvalidArgument(label, ": input must be 5-dimensional")); \
- OP_REQUIRES( \
- context, filter_shape.dims() == 5, \
- errors::InvalidArgument(label, ": filter must be 5-dimensional")); \
- OP_REQUIRES( \
- context, out_backprop.dims() == 5, \
- errors::InvalidArgument(label, ": out_backprop must be 5-dimensional")); \
- const int64 batch = input_shape.dim_size(0); \
- OP_REQUIRES( \
- context, batch == out_backprop.dim_size(0), \
- errors::InvalidArgument( \
- label, ": input and out_backprop must have the same batch size")); \
- const std::array<int64, 3> input_size = { \
- {GetTensorDim(input_shape, data_format_, '0'), \
- GetTensorDim(input_shape, data_format_, '1'), \
- GetTensorDim(input_shape, data_format_, '2')}}; \
- const int64 in_depth = GetTensorDim(input_shape, data_format_, 'C'); \
- const std::array<int64, 3> filter_size = {{filter_shape.dim_size(0), \
- filter_shape.dim_size(1), \
- filter_shape.dim_size(2)}}; \
- const int64 output_cols = GetTensorDim(out_backprop, data_format_, '2'); \
- const int64 output_rows = GetTensorDim(out_backprop, data_format_, '1'); \
- const int64 output_planes = GetTensorDim(out_backprop, data_format_, '0'); \
- OP_REQUIRES(context, in_depth == filter_shape.dim_size(3), \
- errors::InvalidArgument( \
- label, ": input and filter must have the same depth")); \
- const int64 out_depth = filter_shape.dim_size(4); \
- OP_REQUIRES( \
- context, out_depth == GetTensorDim(out_backprop, data_format_, 'C'), \
- errors::InvalidArgument( \
- label, ": filter and out_backprop must have the same out_depth")); \
- const std::array<int64, 3> dilations = { \
- {GetTensorDim(dilation_, data_format_, '0'), \
- GetTensorDim(dilation_, data_format_, '1'), \
- GetTensorDim(dilation_, data_format_, '2')}}; \
- const std::array<int64, 3> strides = { \
- {GetTensorDim(stride_, data_format_, '0'), \
- GetTensorDim(stride_, data_format_, '1'), \
- GetTensorDim(stride_, data_format_, '2')}}; \
- std::array<int64, 3> out, padding; \
- OP_REQUIRES_OK( \
- context, Get3dOutputSizeV2(input_size, filter_size, dilations, strides, \
- padding_, &out, &padding)); \
- OP_REQUIRES(context, output_planes == out[0], \
- errors::InvalidArgument( \
- label, \
- ": Number of planes of out_backprop doesn't match " \
- "computed: actual = ", \
- output_planes, ", computed = ", out[0])); \
- OP_REQUIRES( \
- context, output_rows == out[1], \
- errors::InvalidArgument( \
- label, ": Number of rows of out_backprop doesn't match computed: ", \
- "actual = ", output_rows, ", computed = ", out[1])); \
- OP_REQUIRES( \
- context, output_cols == out[2], \
- errors::InvalidArgument( \
- label, ": Number of cols of out_backprop doesn't match computed: ", \
- "actual = ", output_cols, ", computed = ", out[2])); \
- const auto expanded_out_planes = (output_planes - 1) * strides[0] + 1; \
- const auto expanded_out_rows = (output_rows - 1) * strides[1] + 1; \
- const auto expanded_out_cols = (output_cols - 1) * strides[2] + 1; \
- const auto padded_out_planes = input_size[0] + filter_size[0] - 1; \
- const auto padded_out_rows = input_size[1] + filter_size[1] - 1; \
- const auto padded_out_cols = input_size[2] + filter_size[2] - 1; \
- const auto top_pad_planes = filter_size[0] - 1 - padding[0]; \
- const auto top_pad_rows = filter_size[1] - 1 - padding[1]; \
- const auto left_pad_cols = filter_size[2] - 1 - padding[2]; \
- const auto bottom_pad_planes = \
- padded_out_planes - expanded_out_planes - top_pad_planes; \
- const auto bottom_pad_rows = \
- padded_out_rows - expanded_out_rows - top_pad_rows; \
- const auto right_pad_cols = \
- padded_out_cols - expanded_out_cols - left_pad_cols; \
- VLOG(2) << "Conv3d: " << label \
- << ": expanded_out_planes = " << expanded_out_planes \
- << ": expanded_out_rows = " << expanded_out_rows \
- << ", expanded_out_cols = " << expanded_out_cols \
- << ", padded_out_planes = " << padded_out_planes \
- << ", padded_out_rows = " << padded_out_rows \
- << ", padded_out_cols = " << padded_out_cols \
- << ", top_pad_planes = " << top_pad_planes \
- << ", top_pad_rows = " << top_pad_rows \
- << ", left_pad_cols = " << left_pad_cols \
- << ", bottom_pad_planes = " << bottom_pad_planes \
- << ", bottom_pad_rows = " << bottom_pad_rows \
- << ", right_pad_cols = " << right_pad_cols
-
-// Backprop for input.
+// Backprop for input that offloads computation to
+// Eigen::CuboidConvolutionBackwardInput.
template <typename Device, class T>
class Conv3DBackpropInputOp : public OpKernel {
public:
@@ -192,6 +212,10 @@ class Conv3DBackpropInputOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& filter = context->input(1);
const TensorShape& filter_shape = filter.shape();
+
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
TensorShape input_shape;
if (takes_shape_) {
const Tensor& input_sizes = context->input(0);
@@ -200,51 +224,345 @@ class Conv3DBackpropInputOp : public OpKernel {
} else {
input_shape = context->input(0).shape();
}
- EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropInput");
- Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 5> pad_dims{
- {0, 0},
- {top_pad_planes, bottom_pad_planes},
- {top_pad_rows, bottom_pad_rows},
- {left_pad_cols, right_pad_cols},
- {0, 0}};
+
+ ConvBackpropDimensions dims;
+ OP_REQUIRES_OK(context, ConvBackpropComputeDimensions(
+ "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape,
+ stride_, padding_, data_format_, &dims));
+
Tensor* in_backprop;
OP_REQUIRES_OK(context,
context->allocate_output(0, input_shape, &in_backprop));
- // Fill out a padded out_backprop.
- TensorShape padded_out_shape({batch, padded_out_planes, padded_out_rows,
- padded_out_cols, out_depth});
- Tensor padded_output;
+ functor::CuboidConvolutionBackwardInput<Device, T>()(
+ context->eigen_device<Device>(),
+ in_backprop->tensor<T, 5>(), // input_backward
+ filter.tensor<T, 5>(), // filter
+ out_backprop.tensor<T, 5>(), // output_backward
+ static_cast<int>(dims.spatial_dims[0].stride), // stride_planes
+ static_cast<int>(dims.spatial_dims[1].stride), // stride_rows
+ static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols
+ }
+
+ private:
+ std::vector<int32> dilation_;
+ std::vector<int32> stride_;
+ Padding padding_;
+ TensorFormat data_format_;
+ bool takes_shape_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropInputOp);
+};
+
+// Custom backprop for input that explicitly does the work sharding and calls
+// Eigen only to multiply matrices.
+template <typename Device, class T>
+class Conv3DCustomBackpropInputOp : public OpKernel {
+ // Limit the maximum size of allocated temporary buffer to
+ // kMaxTempAllocationOverhead times the size of the input tensors (input,
+ // filter, out_backprop). If the size of the temporary buffer exceeds this
+ // limit, fall back on the Eigen implementation.
+ static constexpr int kMaxTempAllocationOverhead = 25;
+
+ public:
+ explicit Conv3DCustomBackpropInputOp(OpKernelConstruction* context)
+ : OpKernel(context),
+ data_format_(FORMAT_NHWC),
+ takes_shape_(type_string().find("V2") != std::string::npos) {
+ // data_format is only available in V2.
+ if (takes_shape_) {
+ string data_format;
+ OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+ OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES(
+ context, data_format_ == FORMAT_NHWC,
+ errors::InvalidArgument(
+ "Conv3DBackpropInputOpV2 only supports NDHWC on the CPU."));
+ }
+
+ OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
+ OP_REQUIRES(context, dilation_.size() == 5,
+ errors::InvalidArgument("Dilation rates field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES(context,
+ (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
+ GetTensorDim(dilation_, data_format_, 'N') == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilation rates in the batch and depth dimensions."));
+
+ // TODO(yangzihao): Add CPU version of dilated conv 3D.
+ OP_REQUIRES(context,
+ (GetTensorDim(dilation_, data_format_, '0') == 1 &&
+ GetTensorDim(dilation_, data_format_, '1') == 1 &&
+ GetTensorDim(dilation_, data_format_, '2') == 1),
+ errors::InvalidArgument(
+ "Current CPU implementation does not yet support "
+ "dilation rates larger than 1."));
+
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 5,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES(
+ context,
+ (GetTensorDim(stride_, data_format_, 'C') == 1 &&
+ GetTensorDim(stride_, data_format_, 'N') == 1),
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& filter = context->input(1);
+ const TensorShape& filter_shape = filter.shape();
+
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
+ TensorShape input_shape;
+ if (takes_shape_) {
+ const Tensor& input_sizes = context->input(0);
+ // MakeShape is able to handle both DT_INT32 and DT_INT64 for input_sizes.
+ OP_REQUIRES_OK(context, MakeShape(input_sizes, &input_shape));
+ } else {
+ input_shape = context->input(0).shape();
+ }
+
+ ConvBackpropDimensions dims;
+ OP_REQUIRES_OK(context, ConvBackpropComputeDimensions(
+ "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape,
+ stride_, padding_, data_format_, &dims));
+
+ Tensor* in_backprop;
OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<T>::v(),
- padded_out_shape, &padded_output));
- Eigen::DSizes<Eigen::DenseIndex, 5> no_op_shuffle{0, 1, 2, 3, 4};
- Eigen::DSizes<Eigen::DenseIndex, 5> eigen_strides{1, strides[0], strides[1],
- strides[2], 1};
- functor::InflatePadAndShuffle<Device, T, 5, Eigen::DenseIndex>()(
- context->eigen_device<Device>(), out_backprop.tensor<T, 5>(),
- eigen_strides, pad_dims, no_op_shuffle, padded_output.tensor<T, 5>());
- const Tensor& padded_output_cref = padded_output;
-
- // Fill a new "reverted" filter. We need to transpose the in_depth and
- // out_depth for the filter and reverse the planes, rows and cols.
- TensorShape r_filter_shape(
- {filter_size[0], filter_size[1], filter_size[2], out_depth, in_depth});
- Tensor r_filter;
- OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
- r_filter_shape, &r_filter));
- Eigen::DSizes<Eigen::DenseIndex, 5> filter_order{0, 1, 2, 4, 3};
- Eigen::array<bool, 5> filter_rev_dims{true, true, true, false, false};
- functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()(
- context->eigen_device<Device>(), filter.tensor<T, 5>(), filter_order,
- filter_rev_dims, r_filter.tensor<T, 5>());
- const Tensor& r_filter_cref = r_filter;
-
- // Now we can call conv_3d directly.
- functor::CuboidConvolution<Device, T>()(
- context->eigen_device<Device>(), in_backprop->tensor<T, 5>(),
- padded_output_cref.tensor<T, 5>(), r_filter_cref.tensor<T, 5>(), 1, 1,
- 1, BrainPadding2EigenPadding(VALID));
+ context->allocate_output(0, input_shape, &in_backprop));
+
+ int64 top_pad_planes, bottom_pad_planes;
+ int64 top_pad_rows, bottom_pad_rows;
+ int64 left_pad_cols, right_pad_cols;
+
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[0].input_size,
+ dims.spatial_dims[0].filter_size,
+ dims.spatial_dims[0].stride, padding_,
+ &dims.spatial_dims[0].output_size,
+ &top_pad_planes, &bottom_pad_planes));
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[1].input_size,
+ dims.spatial_dims[1].filter_size,
+ dims.spatial_dims[1].stride, padding_,
+ &dims.spatial_dims[1].output_size,
+ &top_pad_rows, &bottom_pad_rows));
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[2].input_size,
+ dims.spatial_dims[2].filter_size,
+ dims.spatial_dims[2].stride, padding_,
+ &dims.spatial_dims[2].output_size,
+ &left_pad_cols, &right_pad_cols));
+
+ // TODO(ezhulenev): Extract work size and shard estimation to shared
+ // functions in conv_grad_ops, and update 2d convolution backprop.
+
+ // The total dimension size of each kernel.
+ const int64 filter_total_size =
+ dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size *
+ dims.spatial_dims[2].filter_size * dims.in_depth;
+
+ // The output image size is the spatial size of the output.
+ const int64 output_image_size = dims.spatial_dims[0].output_size *
+ dims.spatial_dims[1].output_size *
+ dims.spatial_dims[2].output_size;
+
+ const auto cache_sizes = Eigen::internal::CacheSizes();
+ const ptrdiff_t l3_cache_size = cache_sizes.m_l3;
+
+ // Use L3 cache size as target working set size.
+ const size_t target_working_set_size = l3_cache_size / sizeof(T);
+
+ // Calculate size of matrices involved in MatMul: C = A x B.
+ const int64 size_A = output_image_size * dims.out_depth;
+
+ const int64 size_B = filter_total_size * dims.out_depth;
+
+ const int64 size_C = output_image_size * filter_total_size;
+
+ const int64 work_unit_size = size_A + size_B + size_C;
+
+ auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
+
+ // Use a parallel tensor contraction if there is no batching.
+ //
+ // Unlike the Conv2D code, this version does not estimate the work size: in
+ // benchmarks there was no case where running a parallel contraction was
+ // beneficial compared to sharding and matmuls.
+ const bool use_parallel_contraction = dims.batch_size == 1;
+
+ const size_t shard_size =
+ use_parallel_contraction
+ ? 1
+ : (target_working_set_size + work_unit_size - 1) / work_unit_size;
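As a concrete illustration of the estimate above (all numbers below are hypothetical and not taken from this patch):

// Sketch only: 8 MiB of L3, float data, 1000 output voxels per image, a
// 270-element filter patch (e.g. 3x3x3x10) and 32 output channels.
constexpr long long kSketchTargetWorkingSet = (8LL << 20) / sizeof(float);  // 2,097,152
constexpr long long kSketchSizeA = 1000 * 32;                               //    32,000
constexpr long long kSketchSizeB = 270 * 32;                                //     8,640
constexpr long long kSketchSizeC = 1000 * 270;                              //   270,000
constexpr long long kSketchWorkUnit =
    kSketchSizeA + kSketchSizeB + kSketchSizeC;                             //   310,640
constexpr long long kSketchShardSize =
    (kSketchTargetWorkingSet + kSketchWorkUnit - 1) / kSketchWorkUnit;      // 7 images
static_assert(kSketchShardSize == 7, "ceil(2,097,152 / 310,640) == 7");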
+
+ // Total number of elements in all the tensors used by this kernel.
+ int64 total_tensor_elements = input_shape.num_elements() +
+ filter_shape.num_elements() +
+ out_backprop_shape.num_elements();
+
+ // Shape of the temporary workspace buffer.
+ TensorShape col_buffer_shape = {static_cast<int64>(shard_size),
+ static_cast<int64>(output_image_size),
+ static_cast<int64>(filter_total_size)};
+ int64 col_buffer_elements = col_buffer_shape.num_elements();
+
+ // If the temporary allocation overhead is too large, fall back on the Eigen
+ // implementation, which requires much less memory.
+ int64 col_buffer_overhead = col_buffer_elements / total_tensor_elements;
+ if (col_buffer_overhead > kMaxTempAllocationOverhead) {
+ VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropInputOp: "
+ "col_buffer_overhead="
+ << col_buffer_overhead;
+
+ functor::CuboidConvolutionBackwardInput<Device, T>()(
+ context->eigen_device<Device>(),
+ in_backprop->tensor<T, 5>(), // input_backward
+ filter.tensor<T, 5>(), // filter
+ out_backprop.tensor<T, 5>(), // output_backward
+ static_cast<int>(dims.spatial_dims[0].stride), // stride_planes
+ static_cast<int>(dims.spatial_dims[1].stride), // stride_rows
+ static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols
+
+ return;
+ }
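Continuing the hypothetical numbers from the sketch above: a shard of 7 images needs a col buffer of 7 * 1000 * 270 = 1,890,000 elements; if input, filter and out_backprop together hold only 70,000 elements, the overhead is 1,890,000 / 70,000 = 27 > kMaxTempAllocationOverhead (25), and this Eigen path would be taken instead of allocating the buffer.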
+
+ Tensor col_buffer;
+ OP_REQUIRES_OK(context,
+ context->allocate_temp(DataTypeToEnum<T>::value,
+ col_buffer_shape, &col_buffer));
+
+ // The input offset corresponding to a single input image.
+ const int64 input_offset = dims.spatial_dims[0].input_size *
+ dims.spatial_dims[1].input_size *
+ dims.spatial_dims[2].input_size * dims.in_depth;
+
+ // The output offset corresponding to a single output image.
+ const int64 output_offset =
+ dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size *
+ dims.spatial_dims[2].output_size * dims.out_depth;
+
+ const T* filter_data = filter.template flat<T>().data();
+ T* col_buffer_data = col_buffer.template flat<T>().data();
+ const T* out_backprop_data = out_backprop.template flat<T>().data();
+
+ auto in_backprop_flat = in_backprop->template flat<T>();
+ T* input_backprop_data = in_backprop_flat.data();
+ in_backprop_flat.device(context->eigen_device<Device>()) =
+ in_backprop_flat.constant(T(0));
+
+ if (use_parallel_contraction) {
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
+ Eigen::Unaligned>
+ TensorMap;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
+ Eigen::Unaligned>
+ ConstTensorMap;
+
+ // Initialize contraction dims (we need to transpose 'B' below).
+ Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
+ contract_dims[0].first = 1;
+ contract_dims[0].second = 1;
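+ // Shape note for the per-image contraction below: A is
+ // [output_image_size, out_depth] and B is [filter_total_size, out_depth];
+ // contracting dimension 1 of A with dimension 1 of B computes A * B^T and
+ // yields C of shape [output_image_size, filter_total_size], the column
+ // buffer that Col2im then scatters back into the input gradient.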
+
+ for (int image_id = 0; image_id < dims.batch_size; ++image_id) {
+ // Compute gradient into col_buffer.
+ TensorMap C(col_buffer_data, output_image_size, filter_total_size);
+
+ ConstTensorMap A(out_backprop_data + output_offset * image_id,
+ output_image_size, dims.out_depth);
+ ConstTensorMap B(filter_data, filter_total_size, dims.out_depth);
+
+ C.device(context->eigen_cpu_device()) = A.contract(B, contract_dims);
+
+ Col2im<T>(col_buffer_data, dims.in_depth,
+ // Input spatial dimensions.
+ dims.spatial_dims[0].input_size, // input planes
+ dims.spatial_dims[1].input_size, // input rows
+ dims.spatial_dims[2].input_size, // input cols
+ // Filter spatial dimensions.
+ dims.spatial_dims[0].filter_size, // filter planes
+ dims.spatial_dims[1].filter_size, // filter rows
+ dims.spatial_dims[2].filter_size, // filter cols
+ // Spatial padding.
+ top_pad_planes, top_pad_rows, left_pad_cols,
+ bottom_pad_planes, bottom_pad_rows, right_pad_cols,
+ // Spatial striding.
+ dims.spatial_dims[0].stride, // stride planes
+ dims.spatial_dims[1].stride, // stride rows
+ dims.spatial_dims[2].stride, // stride cols
+ input_backprop_data);
+
+ input_backprop_data += input_offset;
+ }
+ } else {
+ typedef Eigen::Map<
+ Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
+ MatrixMap;
+ typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic,
+ Eigen::RowMajor>>
+ ConstMatrixMap;
+
+ for (int image_id = 0; image_id < dims.batch_size;
+ image_id += shard_size) {
+ const int shard_limit =
+ std::min(static_cast<int>(shard_size),
+ static_cast<int>(dims.batch_size) - image_id);
+
+ auto shard = [&dims, &top_pad_planes, &top_pad_rows, &left_pad_cols,
+ &bottom_pad_planes, &bottom_pad_rows, &right_pad_cols,
+ &output_image_size, &filter_total_size,
+ &input_backprop_data, &col_buffer_data,
+ &out_backprop_data, &filter_data, &input_offset,
+ &output_offset, &size_C](int64 start, int64 limit) {
+ for (int shard_id = start; shard_id < limit; ++shard_id) {
+ T* im2col_buf = col_buffer_data + shard_id * size_C;
+ T* input_data = input_backprop_data + shard_id * input_offset;
+ const T* out_data = out_backprop_data + shard_id * output_offset;
+
+ // Compute gradient into 'im2col_buf'.
+ MatrixMap C(im2col_buf, output_image_size, filter_total_size);
+
+ ConstMatrixMap A(out_data, output_image_size, dims.out_depth);
+ ConstMatrixMap B(filter_data, filter_total_size, dims.out_depth);
+
+ C.noalias() = A * B.transpose();
+
+ Col2im<T>(im2col_buf, dims.in_depth,
+ // Input spatial dimensions.
+ dims.spatial_dims[0].input_size, // input planes
+ dims.spatial_dims[1].input_size, // input rows
+ dims.spatial_dims[2].input_size, // input cols
+ // Filter spatial dimensions.
+ dims.spatial_dims[0].filter_size, // filter planes
+ dims.spatial_dims[1].filter_size, // filter rows
+ dims.spatial_dims[2].filter_size, // filter cols
+ // Spatial padding.
+ top_pad_planes, top_pad_rows, left_pad_cols,
+ bottom_pad_planes, bottom_pad_rows, right_pad_cols,
+ // Spatial striding.
+ dims.spatial_dims[0].stride, // stride planes
+ dims.spatial_dims[1].stride, // stride rows
+ dims.spatial_dims[2].stride, // stride cols
+ input_data);
+ }
+ };
+ Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
+ work_unit_size, shard);
+
+ input_backprop_data += input_offset * shard_limit;
+ out_backprop_data += output_offset * shard_limit;
+ }
+ }
}
private:
@@ -253,21 +571,48 @@ class Conv3DBackpropInputOp : public OpKernel {
Padding padding_;
TensorFormat data_format_;
bool takes_shape_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropInputOp);
};
+// The custom backprop input kernel is 30% - 4x faster when compiled with AVX2
+// than the default Eigen implementation (at the cost of ~2x-8x peak memory usage).
+
#define REGISTER_CPU_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("Conv3DBackpropInput").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
- Conv3DBackpropInputOp<CPUDevice, T>); \
+ Conv3DCustomBackpropInputOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("Conv3DBackpropInputV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
- Conv3DBackpropInputOp<CPUDevice, T>);
+ Conv3DCustomBackpropInputOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput") \
+ .Device(DEVICE_CPU) \
+ .Label("custom") \
+ .TypeConstraint<T>("T"), \
+ Conv3DCustomBackpropInputOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \
+ .Device(DEVICE_CPU) \
+ .Label("custom") \
+ .TypeConstraint<T>("T"), \
+ Conv3DCustomBackpropInputOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput") \
+ .Device(DEVICE_CPU) \
+ .Label("eigen_tensor") \
+ .TypeConstraint<T>("T"), \
+ Conv3DBackpropInputOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \
+ .Device(DEVICE_CPU) \
+ .Label("eigen_tensor") \
+ .TypeConstraint<T>("T"), \
+ Conv3DBackpropInputOp<CPUDevice, T>);
+
TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL
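The Label(...) registrations above make each variant selectable per node: TensorFlow matches the label against the reserved "_kernel" attr of a NodeDef, and an empty label (the default) picks the unlabeled custom registration. A sketch of requesting the Eigen-tensor variant explicitly (the node name, dtypes and attrs are hypothetical, and the exact NodeDefBuilder calls are an assumption rather than anything in this patch):

#include <vector>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/lib/core/status.h"

tensorflow::NodeDef MakeEigenTensorBackpropInputNode() {
  tensorflow::NodeDef node;
  std::vector<tensorflow::int32> strides = {1, 1, 1, 1, 1};
  TF_CHECK_OK(tensorflow::NodeDefBuilder("conv3d_input_grad_sketch",
                                         "Conv3DBackpropInputV2")
                  .Input("input_sizes", 0, tensorflow::DT_INT32)
                  .Input("filter", 0, tensorflow::DT_FLOAT)
                  .Input("out_backprop", 0, tensorflow::DT_FLOAT)
                  .Attr("strides", strides)
                  .Attr("padding", "SAME")
                  .Attr("_kernel", "eigen_tensor")  // or "custom"
                  .Finalize(&node));
  return node;
}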
-// Backprop for filter.
+// Backprop for filter that offloads computation to
+// Eigen::CuboidConvolutionBackwardFilter.
template <typename Device, class T>
class Conv3DBackpropFilterOp : public OpKernel {
public:
@@ -323,8 +668,11 @@ class Conv3DBackpropFilterOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
const TensorShape& input_shape = input.shape();
- TensorShape filter_shape;
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
+ TensorShape filter_shape;
if (takes_shape_) {
const Tensor& filter_sizes = context->input(1);
OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
@@ -333,13 +681,13 @@ class Conv3DBackpropFilterOp : public OpKernel {
filter_shape = context->input(1).shape();
}
- EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropFilter");
- Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 5> pad_dims{
- {0, 0},
- {top_pad_planes, bottom_pad_planes},
- {top_pad_rows, bottom_pad_rows},
- {left_pad_cols, right_pad_cols},
- {0, 0}};
+ ConvBackpropDimensions dims;
+ OP_REQUIRES_OK(context,
+ ConvBackpropComputeDimensions(
+ "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape, stride_,
+ padding_, data_format_, &dims));
+
Tensor* filter_backprop;
OP_REQUIRES_OK(context,
context->allocate_output(0, filter_shape, &filter_backprop));
@@ -349,70 +697,292 @@ class Conv3DBackpropFilterOp : public OpKernel {
return;
}
- // For the backprop of the filter, we need to also transpose the
- // out_backprop.
- // The shape of backprop is
- // [batch, out_z, out_y, out_x, out_depth]
- // And we need to change it to
- // [out_depth, out_x, out_y, out_z, batch]
- Eigen::DSizes<Eigen::DenseIndex, 5> out_order{4, 1, 2, 3, 0};
- TensorShape padded_out_shape({out_depth, padded_out_planes, padded_out_rows,
- padded_out_cols, batch});
- Tensor padded_output;
+ functor::CuboidConvolutionBackwardFilter<Device, T>()(
+ context->eigen_device<Device>(),
+ filter_backprop->tensor<T, 5>(), // filter_backward
+ input.tensor<T, 5>(), // input
+ out_backprop.tensor<T, 5>(), // output_backward
+ static_cast<int>(dims.spatial_dims[0].stride), // stride_planes
+ static_cast<int>(dims.spatial_dims[1].stride), // stride_rows
+ static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols
+ }
+
+ private:
+ std::vector<int32> dilation_;
+ std::vector<int32> stride_;
+ Padding padding_;
+ TensorFormat data_format_;
+ bool takes_shape_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropFilterOp);
+};
+
+// Custom backprop for filter that explicitly does the work sharding and calls
+// Eigen only to multiply matrices.
+template <typename Device, class T>
+class Conv3DCustomBackpropFilterOp : public OpKernel {
+ // Limit the maximum size of allocated temporary buffer to
+ // kMaxTempAllocationOverhead times the size of the input tensors (input,
+ // filter, out_backprop). If the size of the temporary buffer exceeds this
+ // limit, fall back on the Eigen implementation.
+ static constexpr int kMaxTempAllocationOverhead = 25;
+
+ public:
+ explicit Conv3DCustomBackpropFilterOp(OpKernelConstruction* context)
+ : OpKernel(context),
+ data_format_(FORMAT_NHWC),
+ takes_shape_(type_string().find("V2") != std::string::npos) {
+ // data_format is only available in V2.
+ if (takes_shape_) {
+ string data_format;
+ OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+ OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES(
+ context, data_format_ == FORMAT_NHWC,
+ errors::InvalidArgument(
+ "Conv3DBackpropFilterOpV2 only supports NDHWC on the CPU."));
+ }
+
+ OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
+ OP_REQUIRES(context, dilation_.size() == 5,
+ errors::InvalidArgument("Dilation rates field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES(context,
+ (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
+ GetTensorDim(dilation_, data_format_, 'N') == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilation rates in the batch and depth dimensions."));
+
+ // TODO(yangzihao): Add CPU version of dilated conv 3D.
+ OP_REQUIRES(context,
+ (GetTensorDim(dilation_, data_format_, '0') == 1 &&
+ GetTensorDim(dilation_, data_format_, '1') == 1 &&
+ GetTensorDim(dilation_, data_format_, '2') == 1),
+ errors::InvalidArgument(
+ "Current CPU implementation does not yet support "
+ "dilation rates larger than 1."));
+
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 5,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES(
+ context,
+ (GetTensorDim(stride_, data_format_, 'C') == 1 &&
+ GetTensorDim(stride_, data_format_, 'N') == 1),
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ const TensorShape& input_shape = input.shape();
+
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
+ TensorShape filter_shape;
+ if (takes_shape_) {
+ const Tensor& filter_sizes = context->input(1);
+ OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
+ filter_sizes.vec<int32>(), &filter_shape));
+ } else {
+ filter_shape = context->input(1).shape();
+ }
+
+ ConvBackpropDimensions dims;
OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<T>::v(),
- padded_out_shape, &padded_output));
- Eigen::DSizes<Eigen::DenseIndex, 5> eigen_strides{1, strides[0], strides[1],
- strides[2], 1};
- functor::InflatePadAndShuffle<Device, T, 5, Eigen::DenseIndex>()(
- context->eigen_device<Device>(), out_backprop.tensor<T, 5>(),
- eigen_strides, pad_dims, out_order, padded_output.tensor<T, 5>());
- const Tensor& padded_output_cref = padded_output;
-
- // For the backprop of the filter, we need to transpose the input.
- // The shape of input is
- // [batch, in_z, in_y, in_x, in_depth]
- // And we need to change it to
- // [in_z, in_y, in_x, batch, in_depth]
- Eigen::DSizes<Eigen::DenseIndex, 5> in_order{1, 2, 3, 0, 4};
- TensorShape in_shuffle_shape(
- {input_size[0], input_size[1], input_size[2], batch, in_depth});
- Tensor in_shuffle;
+ ConvBackpropComputeDimensions(
+ "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape, stride_,
+ padding_, data_format_, &dims));
+
+ Tensor* filter_backprop;
OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<T>::v(),
- in_shuffle_shape, &in_shuffle));
- // No need for reversing this time.
- Eigen::array<bool, 5> no_reverse{false, false, false, false, false};
- functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()(
- context->eigen_device<Device>(), input.tensor<T, 5>(), in_order,
- no_reverse, in_shuffle.tensor<T, 5>());
- const Tensor& in_shuffle_cref = in_shuffle;
-
- // The output of the conv_3d would be
- // [out_depth, filter_size[2], filter_size[1], filter_size[0], in_depth]
- // and we need to shuffle it back to
- // [filter_size[2], filter_size[1], filter_size[0], in_depth, out_depth];
- // And we need to reverse the filter backprops.
- // So we need to allocate (sigh) yet another piece of memory to hold the
- // output.
- TensorShape filter_shuffle_shape(
- {out_depth, filter_size[0], filter_size[1], filter_size[2], in_depth});
- Tensor filter_shuffle;
- OP_REQUIRES_OK(
- context, context->allocate_temp(DataTypeToEnum<T>::v(),
- filter_shuffle_shape, &filter_shuffle));
- functor::CuboidConvolution<Device, T>()(
- context->eigen_device<Device>(), filter_shuffle.tensor<T, 5>(),
- padded_output_cref.tensor<T, 5>(), in_shuffle_cref.tensor<T, 5>(), 1, 1,
- 1, BrainPadding2EigenPadding(VALID));
-
- // Now copy the filter_backprop back to the destination.
- Eigen::DSizes<Eigen::DenseIndex, 5> filter_order{1, 2, 3, 4, 0};
- Eigen::array<bool, 5> filter_rev_dims{true, true, true, false, false};
- const Tensor& filter_shuffle_cref = filter_shuffle;
- functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()(
- context->eigen_device<Device>(), filter_shuffle_cref.tensor<T, 5>(),
- filter_order, filter_rev_dims, filter_backprop->tensor<T, 5>());
+ context->allocate_output(0, filter_shape, &filter_backprop));
+
+ if (input_shape.num_elements() == 0) {
+ filter_backprop->template flat<T>().setZero();
+ return;
+ }
+
+ int64 top_pad_planes, bottom_pad_planes;
+ int64 top_pad_rows, bottom_pad_rows;
+ int64 left_pad_cols, right_pad_cols;
+
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[0].input_size,
+ dims.spatial_dims[0].filter_size,
+ dims.spatial_dims[0].stride, padding_,
+ &dims.spatial_dims[0].output_size,
+ &top_pad_planes, &bottom_pad_planes));
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[1].input_size,
+ dims.spatial_dims[1].filter_size,
+ dims.spatial_dims[1].stride, padding_,
+ &dims.spatial_dims[1].output_size,
+ &top_pad_rows, &bottom_pad_rows));
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[2].input_size,
+ dims.spatial_dims[2].filter_size,
+ dims.spatial_dims[2].stride, padding_,
+ &dims.spatial_dims[2].output_size,
+ &left_pad_cols, &right_pad_cols));
+
+ // TODO(ezhulenev): Extract work size and shard estimation to shared
+ // functions in conv_grad_ops, and update 2d convolution backprop.
+
+ // The total dimension size of each kernel.
+ const int64 filter_total_size =
+ dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size *
+ dims.spatial_dims[2].filter_size * dims.in_depth;
+ // The output image size is the spatial size of the output.
+ const int64 output_image_size = dims.spatial_dims[0].output_size *
+ dims.spatial_dims[1].output_size *
+ dims.spatial_dims[2].output_size;
+
+ // Process the 'batch' images (volumes) in groups of 'shard_size' and feed
+ // each group into the parallel matmul. Calculate 'shard_size' by dividing
+ // the L3 cache size ('target_working_set_size') by the matmul size of an
+ // individual image ('work_unit_size').
+
+ const auto cache_sizes = Eigen::internal::CacheSizes();
+ const ptrdiff_t l3_cache_size = cache_sizes.m_l3;
+
+ // TODO(andydavis)
+ // *) Consider reducing 'target_working_set_size' if L3 is shared by
+ // other concurrently running tensorflow ops.
+ const size_t target_working_set_size = l3_cache_size / sizeof(T);
+
+ const int64 size_A = output_image_size * filter_total_size;
+
+ const int64 size_B = output_image_size * dims.out_depth;
+
+ const int64 size_C = filter_total_size * dims.out_depth;
+
+ const int64 work_unit_size = size_A + size_B + size_C;
+
+ const size_t shard_size =
+ (target_working_set_size + work_unit_size - 1) / work_unit_size;
+
+ // Total number of elements in all the tensors used by this kernel.
+ int64 total_tensor_elements = input_shape.num_elements() +
+ filter_shape.num_elements() +
+ out_backprop_shape.num_elements();
+
+ // Shape of the temporary workspace buffer.
+ TensorShape col_buffer_shape = {static_cast<int64>(shard_size),
+ static_cast<int64>(output_image_size),
+ static_cast<int64>(filter_total_size)};
+ int64 col_buffer_elements = col_buffer_shape.num_elements();
+
+ // If the temporary allocation overhead is too large, fall back on the Eigen
+ // implementation, which requires much less memory.
+ int64 col_buffer_overhead = col_buffer_elements / total_tensor_elements;
+ if (col_buffer_overhead > kMaxTempAllocationOverhead) {
+ VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropFilterOp: "
+ "col_buffer_overhead="
+ << col_buffer_overhead;
+
+ functor::CuboidConvolutionBackwardFilter<Device, T>()(
+ context->eigen_device<Device>(),
+ filter_backprop->tensor<T, 5>(), // filter_backward
+ input.tensor<T, 5>(), // input
+ out_backprop.tensor<T, 5>(), // output_backward
+ static_cast<int>(dims.spatial_dims[0].stride), // stride_planes
+ static_cast<int>(dims.spatial_dims[1].stride), // stride_rows
+ static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols
+
+ return;
+ }
+
+ Tensor col_buffer;
+ OP_REQUIRES_OK(context,
+ context->allocate_temp(DataTypeToEnum<T>::value,
+ col_buffer_shape, &col_buffer));
+
+ // The input offset corresponding to a single input image.
+ const int64 input_offset = dims.spatial_dims[0].input_size *
+ dims.spatial_dims[1].input_size *
+ dims.spatial_dims[2].input_size * dims.in_depth;
+ // The output offset corresponding to a single output image.
+ const int64 output_offset =
+ dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size *
+ dims.spatial_dims[2].output_size * dims.out_depth;
+
+ const T* input_data = input.template flat<T>().data();
+ T* col_buffer_data = col_buffer.template flat<T>().data();
+ const T* out_backprop_data = out_backprop.template flat<T>().data();
+ T* filter_backprop_data = filter_backprop->template flat<T>().data();
+
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
+ Eigen::Unaligned>
+ TensorMap;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
+ Eigen::Unaligned>
+ ConstTensorMap;
+
+ TensorMap C(filter_backprop_data, filter_total_size, dims.out_depth);
+ C.setZero();
+
+ // Initialize contraction dims (we need to transpose 'A' below).
+ Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
+ contract_dims[0].first = 0;
+ contract_dims[0].second = 0;
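+ // Shape note for the contraction inside the loop below: A is
+ // [output_image_size * shard_limit, filter_total_size] and B is
+ // [output_image_size * shard_limit, out_depth]; contracting dimension 0 of
+ // A with dimension 0 of B computes A^T * B and accumulates into C of shape
+ // [filter_total_size, out_depth], the flattened filter gradient.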
+
+ auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
+
+ for (int image_id = 0; image_id < dims.batch_size; image_id += shard_size) {
+ const int shard_limit =
+ std::min(static_cast<int>(shard_size),
+ static_cast<int>(dims.batch_size) - image_id);
+
+ auto shard = [&input_data, &col_buffer_data, &dims, &top_pad_planes,
+ &top_pad_rows, &left_pad_cols, &bottom_pad_planes,
+ &bottom_pad_rows, &right_pad_cols, &input_offset,
+ &size_A](int64 start, int64 limit) {
+ for (int shard_id = start; shard_id < limit; ++shard_id) {
+ const T* input_data_shard = input_data + shard_id * input_offset;
+ T* col_data_shard = col_buffer_data + shard_id * size_A;
+
+ // When we compute the gradient with respect to the filters, we need
+ // to do im2col to allow gemm-type computation.
+ Im2col<T>(input_data_shard, dims.in_depth,
+ // Input spatial dimensions.
+ dims.spatial_dims[0].input_size, // input planes
+ dims.spatial_dims[1].input_size, // input rows
+ dims.spatial_dims[2].input_size, // input cols
+ // Filter spatial dimensions.
+ dims.spatial_dims[0].filter_size, // filter planes
+ dims.spatial_dims[1].filter_size, // filter rows
+ dims.spatial_dims[2].filter_size, // filter cols
+ // Spatial padding.
+ top_pad_planes, top_pad_rows, left_pad_cols,
+ bottom_pad_planes, bottom_pad_rows, right_pad_cols,
+ // Spatial striding.
+ dims.spatial_dims[0].stride, // stride planes
+ dims.spatial_dims[1].stride, // stride rows
+ dims.spatial_dims[2].stride, // stride cols
+ col_data_shard);
+ }
+ };
+ Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
+ size_A, shard);
+
+ ConstTensorMap A(col_buffer_data, output_image_size * shard_limit,
+ filter_total_size);
+ ConstTensorMap B(out_backprop_data, output_image_size * shard_limit,
+ dims.out_depth);
+
+ // Gradient with respect to filter.
+ C.device(context->eigen_cpu_device()) += A.contract(B, contract_dims);
+
+ input_data += input_offset * shard_limit;
+ out_backprop_data += output_offset * shard_limit;
+ }
}
private:
@@ -421,21 +991,60 @@ class Conv3DBackpropFilterOp : public OpKernel {
Padding padding_;
TensorFormat data_format_;
bool takes_shape_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropFilterOp);
};
+// The custom backprop filter kernel is 30% - 4x faster when compiled with AVX2
+// than the default Eigen implementation (at the cost of ~2x-8x peak memory usage).
+
#define REGISTER_CPU_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
- Conv3DBackpropFilterOp<CPUDevice, T>); \
+ Conv3DCustomBackpropFilterOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T"), \
+ Conv3DCustomBackpropFilterOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter") \
+ .Device(DEVICE_CPU) \
+ .Label("custom") \
+ .TypeConstraint<T>("T"), \
+ Conv3DCustomBackpropFilterOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
+ .Device(DEVICE_CPU) \
+ .Label("custom") \
+ .TypeConstraint<T>("T"), \
+ Conv3DCustomBackpropFilterOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter") \
+ .Device(DEVICE_CPU) \
+ .Label("eigen_tensor") \
+ .TypeConstraint<T>("T"), \
+ Conv3DBackpropFilterOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
.Device(DEVICE_CPU) \
+ .Label("eigen_tensor") \
.TypeConstraint<T>("T"), \
Conv3DBackpropFilterOp<CPUDevice, T>);
-TF_CALL_half(REGISTER_CPU_KERNEL);
+
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL
+// WARNING: Eigen::half is not trivially copyable, so it can't be used in the
+// custom backprop filter kernel, which relies on memcpy and memset in Im2col.
+#define REGISTER_CPU_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+ Conv3DBackpropFilterOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T"), \
+ Conv3DBackpropFilterOp<CPUDevice, T>);
+
+TF_CALL_half(REGISTER_CPU_KERNEL);
+#undef REGISTER_CPU_KERNEL
+
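One way to make this requirement explicit would be a compile-time guard next to Im2col/Col2im; a sketch only (not part of this patch, and whether Eigen::half satisfies the trait depends on the Eigen version in use):

#include <type_traits>

// Sketch: instantiate next to the memcpy/memset-based helpers to document the
// element-type requirement at compile time.
template <typename T>
struct Im2colElementCheck {
  static_assert(std::is_trivially_copyable<T>::value,
                "Im2col/Col2im use memcpy/memset and require a trivially "
                "copyable element type");
};
// e.g. Im2colElementCheck<float> compiles; Im2colElementCheck<Eigen::half>
// may not, which is why half falls back to the Eigen-tensor kernels above.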
// GPU definitions of both ops.
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
@@ -523,6 +1132,10 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& filter = context->input(1);
const TensorShape& filter_shape = filter.shape();
+
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
TensorShape input_shape;
if (takes_shape_) {
const Tensor& input_sizes = context->input(0);
@@ -531,7 +1144,14 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
} else {
input_shape = context->input(0).shape();
}
- EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropInput");
+
+ ConvBackpropDimensions dims;
+ OP_REQUIRES_OK(context,
+ ConvBackpropComputeDimensionsV2(
+ "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape, dilation_,
+ stride_, padding_, data_format_, &dims));
+
Tensor* in_backprop;
OP_REQUIRES_OK(context,
context->allocate_output(0, input_shape, &in_backprop));
@@ -539,13 +1159,15 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
auto* stream = context->op_device_context()->stream();
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
- if (filter_size[0] == 1 && filter_size[1] == 1 && filter_size[2] == 1 &&
- dilation_[0] == 1 && dilation_[1] == 1 && dilation_[2] == 1 &&
- stride_[0] == 1 && stride_[1] == 1 && stride_[2] == 1 &&
+ if (dims.filter_size(0) == 1 && dims.filter_size(1) == 1 &&
+ dims.filter_size(2) == 1 && dims.dilation(0) == 1 &&
+ dims.dilation(1) == 1 && dims.dilation(2) == 1 && dims.stride(0) == 1 &&
+ dims.stride(1) == 1 && dims.stride(2) == 1 &&
data_format_ == FORMAT_NHWC) {
- const uint64 m = batch * input_size[0] * input_size[1] * input_size[2];
- const uint64 k = out_depth;
- const uint64 n = in_depth;
+ const uint64 m = dims.batch_size * dims.input_size(0) *
+ dims.input_size(1) * dims.input_size(2);
+ const uint64 k = dims.out_depth;
+ const uint64 n = dims.in_depth;
auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
out_backprop.template flat<T>().size());
@@ -567,13 +1189,14 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
", n=", n, ", k=", k));
}
return;
- } else if (filter_size[0] == input_size[0] &&
- filter_size[1] == input_size[1] &&
- filter_size[2] == input_size[2] && padding_ == Padding::VALID &&
- data_format_ == FORMAT_NHWC) {
- const uint64 m = batch;
- const uint64 k = out_depth;
- const uint64 n = input_size[0] * input_size[1] * input_size[2] * in_depth;
+ } else if (dims.filter_size(0) == dims.input_size(0) &&
+ dims.filter_size(1) == dims.input_size(1) &&
+ dims.filter_size(2) == dims.input_size(2) &&
+ padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
+ const uint64 m = dims.batch_size;
+ const uint64 k = dims.out_depth;
+ const uint64 n = dims.input_size(0) * dims.input_size(1) *
+ dims.input_size(2) * dims.in_depth;
auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
out_backprop.template flat<T>().size());
@@ -597,65 +1220,59 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
return;
}
- int padding_rows = 0, padding_cols = 0, padding_planes = 0;
-
- if (padding_ == Padding::SAME) {
- padding_planes = std::max<int>(
- 0, (output_planes - 1) * strides[0] + filter_size[0] - input_size[0]);
- padding_cols = std::max<int>(
- 0, (output_cols - 1) * strides[2] + filter_size[2] - input_size[2]);
- padding_rows = std::max<int>(
- 0, (output_rows - 1) * strides[1] + filter_size[1] - input_size[1]);
- }
+ int padding_planes = dims.SpatialPadding(padding_, 0);
+ int padding_rows = dims.SpatialPadding(padding_, 1);
+ int padding_cols = dims.SpatialPadding(padding_, 2);
+ const bool planes_odd = (padding_planes % 2 != 0);
const bool rows_odd = (padding_rows % 2 != 0);
const bool cols_odd = (padding_cols % 2 != 0);
- const bool planes_odd = (padding_planes % 2 != 0);
TensorShape compatible_input_shape;
if (rows_odd || cols_odd || planes_odd) {
// cuDNN only supports the same amount of padding on both sides.
compatible_input_shape = {
- batch,
- in_depth,
- input_size[0] + planes_odd,
- input_size[1] + rows_odd,
- input_size[2] + cols_odd,
+ dims.batch_size,
+ dims.in_depth,
+ dims.input_size(0) + planes_odd,
+ dims.input_size(1) + rows_odd,
+ dims.input_size(2) + cols_odd,
};
} else {
- compatible_input_shape = {batch, in_depth, input_size[0], input_size[1],
- input_size[2]};
+ compatible_input_shape = {dims.batch_size, dims.in_depth,
+ dims.input_size(0), dims.input_size(1),
+ dims.input_size(2)};
}
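+ // Worked example of this fixup (illustrative numbers): for SAME padding with
+ // input_size = 4, filter_size = 3 and stride = 2 the output size is 2, so
+ // the total padding is (2 - 1) * 2 + 3 - 4 = 1, which cannot be split
+ // evenly. The input is therefore grown to 5 (rows_odd == 1) and cuDNN is
+ // given padding 1 / 2 = 0 on each side, which matches padding (0, 1) on the
+ // original input.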
CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
<< "Negative paddings: (" << padding_rows << ", " << padding_cols
<< ", " << padding_planes << ")";
se::dnn::BatchDescriptor input_desc(3);
- input_desc.set_count(batch)
+ input_desc.set_count(dims.batch_size)
.set_spatial_dim(DimIndex::X, compatible_input_shape.dim_size(4))
.set_spatial_dim(DimIndex::Y, compatible_input_shape.dim_size(3))
.set_spatial_dim(DimIndex::Z, compatible_input_shape.dim_size(2))
- .set_feature_map_count(in_depth)
+ .set_feature_map_count(dims.in_depth)
.set_layout(se::dnn::DataLayout::kBatchDepthYX);
se::dnn::BatchDescriptor output_desc(3);
- output_desc.set_count(batch)
- .set_spatial_dim(DimIndex::X, output_cols)
- .set_spatial_dim(DimIndex::Y, output_rows)
- .set_spatial_dim(DimIndex::Z, output_planes)
- .set_feature_map_count(out_depth)
+ output_desc.set_count(dims.batch_size)
+ .set_spatial_dim(DimIndex::X, dims.output_size(2))
+ .set_spatial_dim(DimIndex::Y, dims.output_size(1))
+ .set_spatial_dim(DimIndex::Z, dims.output_size(0))
+ .set_feature_map_count(dims.out_depth)
.set_layout(se::dnn::DataLayout::kBatchDepthYX);
se::dnn::FilterDescriptor filter_desc(3);
- filter_desc.set_spatial_dim(DimIndex::X, filter_size[2])
- .set_spatial_dim(DimIndex::Y, filter_size[1])
- .set_spatial_dim(DimIndex::Z, filter_size[0])
- .set_input_feature_map_count(in_depth)
- .set_output_feature_map_count(out_depth);
+ filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
+ .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
+ .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
+ .set_input_feature_map_count(dims.in_depth)
+ .set_output_feature_map_count(dims.out_depth);
se::dnn::ConvolutionDescriptor conv_desc(3);
- conv_desc.set_dilation_rate(DimIndex::X, dilations[2])
- .set_dilation_rate(DimIndex::Y, dilations[1])
- .set_dilation_rate(DimIndex::Z, dilations[0])
- .set_filter_stride(DimIndex::X, strides[2])
- .set_filter_stride(DimIndex::Y, strides[1])
- .set_filter_stride(DimIndex::Z, strides[0])
+ conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
+ .set_dilation_rate(DimIndex::Y, dims.dilation(1))
+ .set_dilation_rate(DimIndex::Z, dims.dilation(0))
+ .set_filter_stride(DimIndex::X, dims.stride(2))
+ .set_filter_stride(DimIndex::Y, dims.stride(1))
+ .set_filter_stride(DimIndex::Z, dims.stride(0))
.set_zero_padding(DimIndex::X, padding_cols / 2)
.set_zero_padding(DimIndex::Y, padding_rows / 2)
.set_zero_padding(DimIndex::Z, padding_planes / 2);
@@ -664,10 +1281,11 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
Tensor transformed_filter;
OP_REQUIRES_OK(
context,
- context->allocate_temp(DataTypeToEnum<T>::value,
- TensorShape({out_depth, in_depth, filter_size[0],
- filter_size[1], filter_size[2]}),
- &transformed_filter));
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape({dims.out_depth, dims.in_depth, dims.filter_size(0),
+ dims.filter_size(1), dims.filter_size(2)}),
+ &transformed_filter));
functor::TransformFilter<GPUDevice, T, int, 5>()(
context->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 5>()),
To32Bit(transformed_filter.tensor<T, 5>()));
@@ -675,9 +1293,10 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
// Shape: batch, filters, z, y, x.
Tensor transformed_out_backprop;
if (data_format_ == FORMAT_NHWC) {
- TensorShape nchw_shape = {batch, out_depth, output_planes, output_rows,
- output_cols};
- if (out_depth > 1) {
+ TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
+ dims.output_size(0), dims.output_size(1),
+ dims.output_size(2)};
+ if (dims.out_depth > 1) {
OP_REQUIRES_OK(context, context->allocate_temp(
DataTypeToEnum<T>::value, nchw_shape,
&transformed_out_backprop));
@@ -713,14 +1332,14 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
const int device_id = stream->parent()->device_ordinal();
DataType dtype = context->input(0).dtype();
const ConvParameters conv_parameters = {
- batch,
- in_depth,
- {{input_size[0], input_size[1], input_size[2]}},
+ dims.batch_size,
+ dims.in_depth,
+ {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
FORMAT_NCHW,
- out_depth,
- {{filter_size[0], filter_size[1], filter_size[2]}},
- {{dilations[0], dilations[1], dilations[2]}},
- {{strides[0], strides[1], strides[2]}},
+ dims.out_depth,
+ {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
+ {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
+ {{dims.stride(0), dims.stride(1), dims.stride(2)}},
{{padding_planes, padding_rows, padding_cols}},
dtype,
device_id,
@@ -799,10 +1418,11 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
if (rows_odd || cols_odd || planes_odd) {
Tensor in_backprop_remove_padding;
OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<T>::value,
- {batch, in_depth, input_size[0],
- input_size[1], input_size[2]},
- &in_backprop_remove_padding));
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ {dims.batch_size, dims.in_depth, dims.input_size(0),
+ dims.input_size(1), dims.input_size(2)},
+ &in_backprop_remove_padding));
// Remove the padding for odd spatial dimensions.
functor::PadInput<GPUDevice, T, int, 5>()(
@@ -896,6 +1516,10 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
const TensorShape& input_shape = input.shape();
+
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
TensorShape filter_shape;
if (takes_shape_) {
const Tensor& filter_sizes = context->input(1);
@@ -905,7 +1529,12 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
filter_shape = context->input(1).shape();
}
- EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropFilter");
+ ConvBackpropDimensions dims;
+ OP_REQUIRES_OK(context,
+ ConvBackpropComputeDimensionsV2(
+ "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape, dilation_,
+ stride_, padding_, data_format_, &dims));
Tensor* filter_backprop;
OP_REQUIRES_OK(context,
@@ -914,13 +1543,15 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
auto* stream = context->op_device_context()->stream();
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
- if (filter_size[1] == 1 && filter_size[2] == 1 && filter_size[0] == 1 &&
- dilations[2] == 1 && dilations[1] == 1 && dilations[0] == 1 &&
- strides[2] == 1 && strides[1] == 1 && strides[0] == 1 &&
+ if (dims.filter_size(1) == 1 && dims.filter_size(2) == 1 &&
+ dims.filter_size(0) == 1 && dims.dilation(2) == 1 &&
+ dims.dilation(1) == 1 && dims.dilation(0) == 1 && dims.stride(2) == 1 &&
+ dims.stride(1) == 1 && dims.stride(0) == 1 &&
data_format_ == FORMAT_NHWC) {
- const uint64 m = in_depth;
- const uint64 k = batch * input_size[1] * input_size[2] * input_size[0];
- const uint64 n = out_depth;
+ const uint64 m = dims.in_depth;
+ const uint64 k = dims.batch_size * dims.input_size(1) *
+ dims.input_size(2) * dims.input_size(0);
+ const uint64 n = dims.out_depth;
// The shape of output backprop is
// [batch, out_z, out_y, out_x, out_depth]
@@ -951,13 +1582,14 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
", n=", n, ", k=", k));
}
return;
- } else if (filter_size[0] == input_size[0] &&
- filter_size[1] == input_size[1] &&
- filter_size[2] == input_size[2] && padding_ == Padding::VALID &&
- data_format_ == FORMAT_NHWC) {
- const uint64 m = input_size[0] * input_size[1] * input_size[2] * in_depth;
- const uint64 k = batch;
- const uint64 n = out_depth;
+ } else if (dims.filter_size(0) == dims.input_size(0) &&
+ dims.filter_size(1) == dims.input_size(1) &&
+ dims.filter_size(2) == dims.input_size(2) &&
+ padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
+ const uint64 m = dims.input_size(0) * dims.input_size(1) *
+ dims.input_size(2) * dims.in_depth;
+ const uint64 k = dims.batch_size;
+ const uint64 n = dims.out_depth;
auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
input.template flat<T>().size());
@@ -979,30 +1611,24 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
return;
}
- int padding_rows = 0, padding_cols = 0, padding_planes = 0;
-
- if (padding_ == Padding::SAME) {
- padding_planes = std::max<int>(
- 0, (output_planes - 1) * strides[0] + filter_size[0] - input_size[0]);
- padding_cols = std::max<int>(
- 0, (output_cols - 1) * strides[2] + filter_size[2] - input_size[2]);
- padding_rows = std::max<int>(
- 0, (output_rows - 1) * strides[1] + filter_size[1] - input_size[1]);
- }
- bool rows_odd = (padding_rows % 2 != 0);
- bool cols_odd = (padding_cols % 2 != 0);
- bool planes_odd = (padding_planes % 2 != 0);
+ int padding_planes = dims.SpatialPadding(padding_, 0);
+ int padding_rows = dims.SpatialPadding(padding_, 1);
+ int padding_cols = dims.SpatialPadding(padding_, 2);
+ const bool planes_odd = (padding_planes % 2 != 0);
+ const bool rows_odd = (padding_rows % 2 != 0);
+ const bool cols_odd = (padding_cols % 2 != 0);
Tensor compatible_input;
if (rows_odd || cols_odd || planes_odd) {
- OP_REQUIRES_OK(context, context->allocate_temp(
- DataTypeToEnum<T>::value,
- ShapeFromFormat(data_format_, batch,
- {{input_size[0] + planes_odd,
- input_size[1] + rows_odd,
- input_size[2] + cols_odd}},
- in_depth),
- &compatible_input));
+ OP_REQUIRES_OK(context,
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ ShapeFromFormat(data_format_, dims.batch_size,
+ {{dims.input_size(0) + planes_odd,
+ dims.input_size(1) + rows_odd,
+ dims.input_size(2) + cols_odd}},
+ dims.in_depth),
+ &compatible_input));
functor::PadInput<GPUDevice, T, int, 5>()(
context->template eigen_device<GPUDevice>(),
To32Bit(input.tensor<T, 5>()), {{0, 0, 0}},
@@ -1016,35 +1642,35 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
<< "Negative paddings: (" << padding_rows << ", " << padding_cols
<< ", " << padding_planes << ")";
se::dnn::BatchDescriptor input_desc(3);
- input_desc.set_count(batch)
+ input_desc.set_count(dims.batch_size)
.set_spatial_dim(DimIndex::X,
GetTensorDim(compatible_input, data_format_, '2'))
.set_spatial_dim(DimIndex::Y,
GetTensorDim(compatible_input, data_format_, '1'))
.set_spatial_dim(DimIndex::Z,
GetTensorDim(compatible_input, data_format_, '0'))
- .set_feature_map_count(in_depth)
+ .set_feature_map_count(dims.in_depth)
.set_layout(se::dnn::DataLayout::kBatchDepthYX);
se::dnn::BatchDescriptor output_desc(3);
- output_desc.set_count(batch)
- .set_spatial_dim(DimIndex::X, output_cols)
- .set_spatial_dim(DimIndex::Y, output_rows)
- .set_spatial_dim(DimIndex::Z, output_planes)
- .set_feature_map_count(out_depth)
+ output_desc.set_count(dims.batch_size)
+ .set_spatial_dim(DimIndex::X, dims.output_size(2))
+ .set_spatial_dim(DimIndex::Y, dims.output_size(1))
+ .set_spatial_dim(DimIndex::Z, dims.output_size(0))
+ .set_feature_map_count(dims.out_depth)
.set_layout(se::dnn::DataLayout::kBatchDepthYX);
se::dnn::FilterDescriptor filter_desc(3);
- filter_desc.set_spatial_dim(DimIndex::X, filter_size[2])
- .set_spatial_dim(DimIndex::Y, filter_size[1])
- .set_spatial_dim(DimIndex::Z, filter_size[0])
- .set_input_feature_map_count(in_depth)
- .set_output_feature_map_count(out_depth);
+ filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
+ .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
+ .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
+ .set_input_feature_map_count(dims.in_depth)
+ .set_output_feature_map_count(dims.out_depth);
se::dnn::ConvolutionDescriptor conv_desc(3);
- conv_desc.set_dilation_rate(DimIndex::X, dilations[2])
- .set_dilation_rate(DimIndex::Y, dilations[1])
- .set_dilation_rate(DimIndex::Z, dilations[0])
- .set_filter_stride(DimIndex::X, strides[2])
- .set_filter_stride(DimIndex::Y, strides[1])
- .set_filter_stride(DimIndex::Z, strides[0])
+ conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
+ .set_dilation_rate(DimIndex::Y, dims.dilation(1))
+ .set_dilation_rate(DimIndex::Z, dims.dilation(0))
+ .set_filter_stride(DimIndex::X, dims.stride(2))
+ .set_filter_stride(DimIndex::Y, dims.stride(1))
+ .set_filter_stride(DimIndex::Z, dims.stride(0))
.set_zero_padding(DimIndex::X, padding_cols / 2)
.set_zero_padding(DimIndex::Y, padding_rows / 2)
.set_zero_padding(DimIndex::Z, padding_planes / 2);
@@ -1052,19 +1678,21 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
Tensor pre_transformed_filter_backprop;
OP_REQUIRES_OK(
context,
- context->allocate_temp(DataTypeToEnum<T>::value,
- TensorShape({out_depth, in_depth, filter_size[0],
- filter_size[1], filter_size[2]}),
- &pre_transformed_filter_backprop));
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape({dims.out_depth, dims.in_depth, dims.filter_size(0),
+ dims.filter_size(1), dims.filter_size(2)}),
+ &pre_transformed_filter_backprop));
Tensor transformed_out_backprop;
if (data_format_ == FORMAT_NHWC) {
- TensorShape nchw_shape = {batch, out_depth, output_planes, output_rows,
- output_cols};
+ TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
+ dims.output_size(0), dims.output_size(1),
+ dims.output_size(2)};
OP_REQUIRES_OK(
context, context->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
&transformed_out_backprop));
- if (out_depth > 1) {
+ if (dims.out_depth > 1) {
functor::NHWCToNCHW<GPUDevice, T, 5>()(
context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(),
transformed_out_backprop.tensor<T, 5>());
@@ -1076,10 +1704,10 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
}
Tensor transformed_input;
if (data_format_ == FORMAT_NHWC) {
- TensorShape nchw_shape = {batch, in_depth, compatible_input.dim_size(1),
- compatible_input.dim_size(2),
- compatible_input.dim_size(3)};
- if (in_depth > 1) {
+ TensorShape nchw_shape = {
+ dims.batch_size, dims.in_depth, compatible_input.dim_size(1),
+ compatible_input.dim_size(2), compatible_input.dim_size(3)};
+ if (dims.in_depth > 1) {
OP_REQUIRES_OK(context,
context->allocate_temp(DataTypeToEnum<T>::value,
nchw_shape, &transformed_input));
@@ -1110,14 +1738,14 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
const int device_id = stream->parent()->device_ordinal();
DataType dtype = input.dtype();
const ConvParameters conv_parameters = {
- batch,
- in_depth,
- {{input_size[0], input_size[1], input_size[2]}},
+ dims.batch_size,
+ dims.in_depth,
+ {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
FORMAT_NCHW,
- out_depth,
- {{filter_size[0], filter_size[1], filter_size[2]}},
- {{dilations[0], dilations[1], dilations[2]}},
- {{strides[0], strides[1], strides[2]}},
+ dims.out_depth,
+ {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
+ {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
+ {{dims.stride(0), dims.stride(1), dims.stride(2)}},
{{padding_planes, padding_rows, padding_cols}},
dtype,
device_id,
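
The hunks above replace loose locals (batch, in_depth, out_depth, filter_size[i], strides[i], ...) with a single dims object so every descriptor reads its sizes from one place. A minimal standalone sketch of such a dimensions holder, with hypothetical names that merely mirror the accessors used above (this is not the actual TensorFlow struct):

#include <array>
#include <cstdint>

struct SpatialDim {
  int64_t input_size, filter_size, output_size, stride, dilation;
};

struct BackpropDims {
  int64_t batch_size, in_depth, out_depth;
  std::array<SpatialDim, 3> spatial;  // 0 = planes (Z), 1 = rows (Y), 2 = cols (X)

  int64_t input_size(int d) const { return spatial[d].input_size; }
  int64_t filter_size(int d) const { return spatial[d].filter_size; }
  int64_t output_size(int d) const { return spatial[d].output_size; }
  int64_t stride(int d) const { return spatial[d].stride; }
  int64_t dilation(int d) const { return spatial[d].dilation; }
};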
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index 3a1ac73f64..b3c359010d 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -675,6 +675,19 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "model_dataset_op",
+ srcs = ["model_dataset_op.cc"],
+ deps = [
+ ":dataset",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+tf_kernel_library(
name = "dataset_ops",
srcs = ["dataset_ops.cc"],
deps = [
@@ -708,6 +721,7 @@ tf_kernel_library(
":map_and_batch_dataset_op",
":map_dataset_op",
":map_defun_op",
+ ":model_dataset_op",
":optimize_dataset_op",
":optional_ops",
":padded_batch_dataset_op",
diff --git a/tensorflow/core/kernels/data/batch_dataset_op.cc b/tensorflow/core/kernels/data/batch_dataset_op.cc
index a25f78c6f1..887b8c8365 100644
--- a/tensorflow/core/kernels/data/batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/batch_dataset_op.cc
@@ -117,6 +117,7 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
+ SetMetadata(ctx, "batch_size", dataset()->batch_size_);
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
}
diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc
index 221b5ad835..34c6c86538 100644
--- a/tensorflow/core/kernels/data/cache_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc
@@ -69,7 +69,7 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
- new FileIterator({this, strings::StrCat(prefix, "::FileIterator")}));
+ new FileIterator({this, strings::StrCat(prefix, "::FileCache")}));
}
const DataTypeVector& output_dtypes() const override {
@@ -553,7 +553,7 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new MemoryIterator(
- {this, strings::StrCat(prefix, "::MemoryIterator")}, cache_));
+ {this, strings::StrCat(prefix, "::MemoryCache")}, cache_));
}
const DataTypeVector& output_dtypes() const override {
diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc
index ad2365b25b..31c8f5c0ea 100644
--- a/tensorflow/core/kernels/data/captured_function.cc
+++ b/tensorflow/core/kernels/data/captured_function.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <utility>
#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/random/random.h"
@@ -358,7 +359,8 @@ Status CapturedFunction::RunInstantiated(const std::vector<Tensor>& args,
void CapturedFunction::RunAsync(IteratorContext* ctx,
std::vector<Tensor>&& args,
std::vector<Tensor>* rets,
- FunctionLibraryRuntime::DoneCallback done) {
+ FunctionLibraryRuntime::DoneCallback done,
+ const string& prefix) {
// NOTE(mrry): This method does not transfer ownership of `ctx`, and it may
// be deleted before `done` is called. Take care not to capture `ctx` in any
// code that may execute asynchronously in this function.
@@ -391,23 +393,51 @@ void CapturedFunction::RunAsync(IteratorContext* ctx,
// will be required to plumb it through the `IteratorContext`.
auto c_mgr = new CancellationManager;
f_opts.cancellation_manager = c_mgr;
-
- tf_shared_lock l(mu_);
- ctx->lib()->Run(f_opts, handle, frame,
- std::bind(
- [rets, step_container, c_mgr, frame](
- FunctionLibraryRuntime::DoneCallback done,
- // Begin unbound arguments.
- Status s) {
- delete step_container;
- delete c_mgr;
- if (s.ok()) {
- s = frame->ConsumeRetvals(rets);
- }
- delete frame;
- done(s);
- },
- std::move(done), std::placeholders::_1));
+ StepStats* stats = nullptr;
+ StepStatsCollector* stats_collector = nullptr;
+ std::shared_ptr<model::Node> node;
+ if (ctx->model()) {
+ node = ctx->model()->LookupNode(prefix);
+ if (node) {
+ // TODO(b/114104975): Use something light-weight here.
+ stats = new StepStats();
+ stats_collector = new StepStatsCollector(stats);
+ }
+ }
+ f_opts.stats_collector = stats_collector;
+
+ auto callback = std::bind(
+ [rets, step_container, c_mgr, frame, stats, stats_collector, node](
+ FunctionLibraryRuntime::DoneCallback done,
+ // Begin unbound arguments.
+ Status s) {
+ delete step_container;
+ delete c_mgr;
+ if (s.ok()) {
+ s = frame->ConsumeRetvals(rets);
+ }
+ delete frame;
+ if (node) {
+ int64 delta = 0;
+ stats_collector->Finalize();
+ for (auto dev_stats : stats->dev_stats()) {
+ for (auto node_stats : dev_stats.node_stats()) {
+ delta += node_stats.all_end_rel_nanos();
+ }
+ }
+ delete stats_collector;
+ delete stats;
+ node->add_processing_time(delta);
+ node->start_work();
+ }
+ done(s);
+ if (node) {
+ node->stop_work();
+ }
+ },
+ std::move(done), std::placeholders::_1);
+
+ ctx->lib()->Run(f_opts, handle, frame, std::move(callback));
}
CapturedFunction::CapturedFunction(const NameAttrList& func,
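
The new callback above measures per-element processing time by attaching a StepStatsCollector to the function run and then summing per-node wall time across all devices. A self-contained sketch of just that aggregation step, assuming only the StepStats proto:

#include <cstdint>
#include "tensorflow/core/framework/step_stats.pb.h"

// Sum the recorded per-node execution time (in nanoseconds) over all devices.
int64_t TotalNodeNanos(const tensorflow::StepStats& stats) {
  int64_t delta = 0;
  for (const auto& dev_stats : stats.dev_stats()) {
    for (const auto& node_stats : dev_stats.node_stats()) {
      delta += node_stats.all_end_rel_nanos();
    }
  }
  return delta;
}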
diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h
index e44bc78b1c..8b420fa5db 100644
--- a/tensorflow/core/kernels/data/captured_function.h
+++ b/tensorflow/core/kernels/data/captured_function.h
@@ -104,7 +104,8 @@ class CapturedFunction {
// in order to be able to deallocate them as early as possible.
void RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args,
std::vector<Tensor>* rets,
- FunctionLibraryRuntime::DoneCallback done);
+ FunctionLibraryRuntime::DoneCallback done,
+ const string& prefix);
// Returns the named list of function arguments.
const NameAttrList& func() { return func_; }
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index fe6d705eab..30c6585ba2 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -403,12 +403,12 @@ class IteratorStateVariant {
}
string TypeName() const { return kIteratorVariantTypeName; }
void Encode(VariantTensorData* data) const { *data = *data_; }
- bool Decode(const VariantTensorData& data) {
+ bool Decode(VariantTensorData data) {
if (data.type_name() != TypeName()) {
return false;
}
std::unique_ptr<VariantTensorData> tensor_data(new VariantTensorData);
- *tensor_data = data;
+ std::swap(*tensor_data, data);
std::unique_ptr<VariantTensorDataReader> reader(
new VariantTensorDataReader(tensor_data.get()));
status_ = reader->status();
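
Decode now takes its VariantTensorData by value, so a caller holding an rvalue moves into the parameter and std::swap then transfers the buffers into the heap-allocated copy without a second deep copy. A minimal sketch of the sink-argument idiom with a stand-in Blob type:

#include <memory>
#include <string>
#include <utility>

struct Blob { std::string payload; };

class Holder {
 public:
  bool Decode(Blob data) {           // by value: a moving caller pays no deep copy
    auto owned = std::make_unique<Blob>();
    std::swap(*owned, data);         // steal the buffers instead of copying them
    data_ = std::move(owned);
    return true;
  }

 private:
  std::unique_ptr<Blob> data_;
};

// Usage: holder.Decode(std::move(blob));  // payload is moved end to end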
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index 27c89b3661..85e49355d3 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -204,6 +204,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
Status Initialize(IteratorContext* ctx) override {
+ SetMetadata(ctx, "batch_size", dataset()->batch_size_);
+ SetMetadata(ctx, "parallelism", dataset()->num_parallel_calls_);
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
return dataset()->captured_func_->Instantiate(ctx);
@@ -218,7 +220,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
EnsureRunnerThreadStarted(ctx);
while (batch_results_.empty() ||
batch_results_.front()->num_calls > 0) {
+ StopWork(ctx);
cond_var_.wait(l);
+ StartWork(ctx);
}
std::swap(result, batch_results_.front());
batch_results_.pop_front();
@@ -365,7 +369,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
ctx.get(), std::move(input_element), return_values.get(),
[this, ctx, result, return_values, offset](Status status) {
Callback(ctx, result, return_values, offset, status);
- });
+ },
+ prefix());
},
ctx, std::move(input_element)));
}
@@ -476,6 +481,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
LOCKS_EXCLUDED(mu_) {
std::vector<std::pair<std::shared_ptr<BatchResult>, int64>> new_calls;
new_calls.reserve(dataset()->num_parallel_calls_);
+ StartWork(ctx.get());
+ auto stop_cleanup =
+ gtl::MakeCleanup([this, &ctx]() { StopWork(ctx.get()); });
while (true) {
{
mutex_lock l(mu_);
@@ -484,7 +492,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
batch_results_.size() > MaxBatchResults() ||
(batch_results_.size() == MaxBatchResults() &&
call_counter_ % dataset()->batch_size_ == 0))) {
+ StopWork(ctx.get());
cond_var_.wait(l);
+ StartWork(ctx.get());
}
if (cancelled_) {
diff --git a/tensorflow/core/kernels/data/map_defun_op.cc b/tensorflow/core/kernels/data/map_defun_op.cc
index b87d61ee44..6657f2b2b3 100644
--- a/tensorflow/core/kernels/data/map_defun_op.cc
+++ b/tensorflow/core/kernels/data/map_defun_op.cc
@@ -81,119 +81,167 @@ class MapDefunOp : public AsyncOpKernel {
}
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
- int64 batch_size;
- OP_REQUIRES_OK_ASYNC(ctx, GetInputBatchSize(ctx, &batch_size), done);
+ ComputeOptions* compute_opts = nullptr;
- // Inputs
- auto* args = new std::vector<Tensor>;
- auto* arg_shapes = new std::vector<TensorShape>;
+ OP_REQUIRES_OK_ASYNC(ctx, SetupArgs(ctx, &compute_opts), done);
- // Create a copy because every `Compute` may have different output shapes.
- auto* output_shapes = new std::vector<PartialTensorShape>(output_shapes_);
- arg_shapes->reserve(ctx->num_inputs());
- args->reserve(ctx->num_inputs());
+ Status s = SetupOutputs(ctx, compute_opts);
+ if (!s.ok()) delete compute_opts;
+ OP_REQUIRES_OK_ASYNC(ctx, s, done);
- auto* mu = new mutex;
-
- for (size_t i = 0; i < ctx->num_inputs(); ++i) {
- args->push_back(ctx->input(i));
- arg_shapes->push_back(ctx->input(i).shape());
- arg_shapes->at(i).RemoveDim(0); // Remove the first batch dimension
- }
-
- // Outputs
- auto* output = new OpOutputList;
- OP_REQUIRES_OK_ASYNC(ctx, ctx->output_list("output", output), done);
-
- for (size_t i = 0; i < output_types().size(); ++i) {
- if (output_shapes_.at(i).IsFullyDefined()) {
- Tensor* out = nullptr;
- TensorShape output_shape;
- output_shapes_.at(i).AsTensorShape(&output_shape);
- output_shape.InsertDim(0, batch_size);
- OP_REQUIRES_OK_ASYNC(ctx, output->allocate(i, output_shape, &out),
- done);
- }
- }
-
- SetRunOptions(ctx, &opts_, false);
+ FunctionLibraryRuntime::Options opts;
+ SetRunOptions(ctx, &opts, false);
// Run loop
StatusCallback callback = std::bind(
- [](OpKernelContext* ctx, std::vector<Tensor>* args,
- std::vector<TensorShape>* arg_shapes,
- std::vector<PartialTensorShape>* output_shapes, OpOutputList* output,
- mutex* mu, DoneCallback& done, const Status& status) {
- delete args;
- delete arg_shapes;
- delete output;
- delete output_shapes;
- delete mu;
+ [](OpKernelContext* ctx, ComputeOptions* compute_opts,
+ DoneCallback& done, const Status& status) {
+ delete compute_opts;
ctx->SetStatus(status);
done();
},
- ctx, args, arg_shapes, output_shapes, output, mu, std::move(done),
- std::placeholders::_1);
+ ctx, compute_opts, std::move(done), std::placeholders::_1);
auto* refcounted = new ReffedStatusCallback(std::move(callback));
- for (size_t i = 1; i < static_cast<size_t>(batch_size); ++i) {
- // Start from i = 1 because refcounted is initialized with refcount = 1
- refcounted->Ref();
- }
+ CancellationManager* parent_mgr = ctx->cancellation_manager();
- for (size_t i = 0; i < static_cast<size_t>(batch_size); ++i) {
- auto* call_frame = new MapFunctionCallFrame(
- *args, *arg_shapes, output_shapes, mu, output, this, i,
- static_cast<size_t>(batch_size));
+ for (size_t i = 0; i < static_cast<size_t>(compute_opts->batch_size); ++i) {
+ // We use a different cancellation manager each time the function is run
+ // to avoid the race condition between a function run error and other
+ // functions being cancelled as a result.
CancellationManager* c_mgr = new CancellationManager;
- opts_.cancellation_manager = c_mgr;
- ctx->function_library()->Run(
- opts_, func_handle_, call_frame,
- [call_frame, refcounted, c_mgr](const Status& func_status) {
- delete call_frame;
- delete c_mgr;
- refcounted->UpdateStatus(func_status);
- refcounted->Unref();
- });
+ CancellationToken token = parent_mgr->get_cancellation_token();
+ const bool success = parent_mgr->RegisterCallback(
+ token, [c_mgr]() { c_mgr->StartCancel(); });
+
+ opts.cancellation_manager = c_mgr;
+ if (!success) {
+ delete c_mgr;
+ refcounted->UpdateStatus(errors::Cancelled(
+ "MapDefunOp functions cancelled because parent graph cancelled"));
+ break;
+ }
+
+ auto* call_frame = new MapFunctionCallFrame(compute_opts, this, i);
+
+ refcounted->Ref();
+ ctx->function_library()->Run(opts, func_handle_, call_frame,
+ [call_frame, refcounted, c_mgr, parent_mgr,
+ token](const Status& func_status) {
+ parent_mgr->DeregisterCallback(token);
+ delete c_mgr;
+ delete call_frame;
+ refcounted->UpdateStatus(func_status);
+ refcounted->Unref();
+ });
}
+
+ // Unref 1 because refcounted is initialized with refcount = 1
+ refcounted->Unref();
}
private:
FunctionLibraryRuntime::Handle func_handle_;
- FunctionLibraryRuntime::Options opts_;
std::vector<PartialTensorShape> output_shapes_;
+ struct ComputeOptions {
+ // These vary per MapDefunOp::ComputeAsync call, but must persist until
+ // all calls to the function are complete. This struct also encapsulates
+ // all the components that need to be passed to each MapFunctionCallFrame.
+
+ const std::vector<Tensor> args;
+ const std::vector<TensorShape> arg_shapes;
+ const int64 batch_size;
+
+ // Output of a compute call
+ std::vector<PartialTensorShape> output_shapes GUARDED_BY(mu);
+ OpOutputList output GUARDED_BY(mu);
+ mutex mu;
+
+ // Create a copy of output_shapes because every `Compute` may expect a
+ // different output shape.
+ ComputeOptions(std::vector<Tensor> args,
+ std::vector<TensorShape> arg_shapes, int64 batch_size,
+ const std::vector<PartialTensorShape>& output_shapes_attr)
+ : args(std::move(args)),
+ arg_shapes(std::move(arg_shapes)),
+ batch_size(batch_size),
+ output_shapes(output_shapes_attr) {}
+ };
+
+ // Get inputs to Compute and check that they are valid.
+ Status SetupArgs(OpKernelContext* ctx, ComputeOptions** compute_opts) {
+ int64 batch_size =
+ ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1;
+
+ for (size_t i = 0; i < ctx->num_inputs(); ++i) {
+ if (ctx->input(i).dims() == 0) {
+ return errors::InvalidArgument(
+ "All inputs must have rank at least 1. Input ", i,
+ " has a rank of 0.");
+ } else if (ctx->input(i).dim_size(0) != batch_size) {
+ return errors::InvalidArgument(
+ "All inputs must have the same dimension 0. Input ", i,
+ " has leading dimension ", ctx->input(i).dim_size(0),
+ ", while all previous inputs have leading dimension ", batch_size);
+ }
+ }
+
+ std::vector<Tensor> args;
+ std::vector<TensorShape> arg_shapes;
+ args.reserve(ctx->num_inputs());
+ arg_shapes.reserve(ctx->num_inputs());
+
+ for (size_t i = 0; i < ctx->num_inputs(); ++i) {
+ args.push_back(ctx->input(i));
+ arg_shapes.push_back(ctx->input(i).shape());
+ arg_shapes.at(i).RemoveDim(0);
+ }
+
+ *compute_opts = new ComputeOptions(std::move(args), std::move(arg_shapes),
+ batch_size, output_shapes_);
+ return Status::OK();
+ }
+
+ Status SetupOutputs(OpKernelContext* ctx, ComputeOptions* opts) {
+ mutex_lock l(opts->mu);
+ TF_RETURN_IF_ERROR(ctx->output_list("output", &opts->output));
+
+ for (size_t i = 0; i < output_types().size(); ++i) {
+ if (output_shapes_.at(i).IsFullyDefined()) {
+ Tensor* out = nullptr;
+ TensorShape output_shape;
+ output_shapes_.at(i).AsTensorShape(&output_shape);
+ output_shape.InsertDim(0, opts->batch_size);
+ TF_RETURN_IF_ERROR(opts->output.allocate(i, output_shape, &out));
+ }
+ }
+ return Status::OK();
+ }
+
class MapFunctionCallFrame : public CallFrameInterface {
public:
- MapFunctionCallFrame(const std::vector<Tensor>& args,
- const std::vector<TensorShape>& arg_shapes,
- std::vector<PartialTensorShape>* output_shapes,
- mutex* output_shapes_mutex, OpOutputList* output,
- OpKernel* kernel, size_t iter, size_t batch_size)
- : args_(args),
- arg_shapes_(arg_shapes),
- output_shapes_(output_shapes),
- output_shapes_mutex_(output_shapes_mutex),
- output_(output),
- kernel_(kernel),
- iter_(iter),
- batch_size_(batch_size) {}
+ MapFunctionCallFrame(ComputeOptions* compute_opts, OpKernel* kernel,
+ size_t iter)
+ : compute_opts_(compute_opts), kernel_(kernel), iter_(iter) {}
~MapFunctionCallFrame() override {}
- size_t num_args() const override { return args_.size(); }
+ size_t num_args() const override { return compute_opts_->args.size(); }
+
size_t num_retvals() const override {
return static_cast<size_t>(kernel_->num_outputs());
}
Status GetArg(int index, Tensor* val) const override {
- if (index < 0 || index >= args_.size()) {
+ if (index < 0 || index >= compute_opts_->args.size()) {
return errors::InvalidArgument(
"Mismatch in number of function inputs.");
}
- bool result = val->CopyFrom(args_.at(index).Slice(iter_, iter_ + 1),
- arg_shapes_.at(index));
+ bool result =
+ val->CopyFrom(compute_opts_->args.at(index).Slice(iter_, iter_ + 1),
+ compute_opts_->arg_shapes.at(index));
if (!result) {
return errors::Internal("GetArg failed.");
} else if (!val->IsAligned()) {
@@ -217,36 +265,34 @@ class MapDefunOp : public AsyncOpKernel {
index);
}
{ // Locking scope
- mutex_lock l(*output_shapes_mutex_);
- if (!output_shapes_->at(index).IsCompatibleWith(val.shape())) {
+ mutex_lock l(compute_opts_->mu);
+ if (!compute_opts_->output_shapes.at(index).IsCompatibleWith(
+ val.shape())) {
return errors::InvalidArgument(
"Mismatch in function retval shape, ", val.shape(),
- ", and expected output shape,",
- output_shapes_->at(index).DebugString(), ".");
+ ", and expected output shape, ",
+ compute_opts_->output_shapes.at(index).DebugString(), ".");
}
- if (!output_shapes_->at(index).IsFullyDefined()) {
+ if (!compute_opts_->output_shapes.at(index).IsFullyDefined()) {
// Given val, we have new information about the output shape at
// this index. Store the shape and allocate the output accordingly.
- output_shapes_->at(index) = val.shape();
+ compute_opts_->output_shapes.at(index) = val.shape();
Tensor* out = nullptr;
TensorShape actual_shape = val.shape();
- actual_shape.InsertDim(0, batch_size_);
- TF_RETURN_IF_ERROR(output_->allocate(index, actual_shape, &out));
+ actual_shape.InsertDim(0, compute_opts_->batch_size);
+ TF_RETURN_IF_ERROR(
+ compute_opts_->output.allocate(index, actual_shape, &out));
}
+ return batch_util::CopyElementToSlice(
+ val, (compute_opts_->output)[index], iter_);
}
- return batch_util::CopyElementToSlice(val, (*output_)[index], iter_);
}
private:
- const std::vector<Tensor>& args_;
- const std::vector<TensorShape>& arg_shapes_;
- std::vector<PartialTensorShape>* output_shapes_;
- mutex* output_shapes_mutex_;
- OpOutputList* output_;
+ ComputeOptions* const compute_opts_; // Not owned
const OpKernel* kernel_;
const size_t iter_;
- const size_t batch_size_;
};
};
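
The rewritten loop gives every function invocation its own CancellationManager and chains it to the parent's manager through a registered callback, so a parent-graph cancellation still fans out while one failed run no longer cancels its siblings. A sketch of that chaining, with run_body standing in for the actual ctx->function_library()->Run(...) call:

#include <functional>
#include "tensorflow/core/framework/cancellation.h"

void RunWithLinkedCancellation(tensorflow::CancellationManager* parent,
                               const std::function<void()>& run_body) {
  auto* child = new tensorflow::CancellationManager;
  const tensorflow::CancellationToken token = parent->get_cancellation_token();
  if (!parent->RegisterCallback(token, [child]() { child->StartCancel(); })) {
    delete child;  // parent was already cancelled; never start the run
    return;
  }
  run_body();  // the real code runs the function with `child` in its options
  parent->DeregisterCallback(token);  // unlink before tearing the child down
  delete child;
}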
diff --git a/tensorflow/core/kernels/data/model_dataset_op.cc b/tensorflow/core/kernels/data/model_dataset_op.cc
new file mode 100644
index 0000000000..c7f929dbc1
--- /dev/null
+++ b/tensorflow/core/kernels/data/model_dataset_op.cc
@@ -0,0 +1,127 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/lib/random/random.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+class ModelDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit ModelDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {}
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ *output = new Dataset(ctx, input);
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ explicit Dataset(OpKernelContext* ctx, const DatasetBase* input)
+ : DatasetBase(DatasetContext(ctx)), input_(input) {
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::Model")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return input_->output_dtypes();
+ }
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return input_->output_shapes();
+ }
+
+ string DebugString() const override { return "ModelDatasetOp::Dataset"; }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
+ TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params), model_(new model::Model()) {}
+
+ ~Iterator() override { model_->OutputToFile(); }
+
+ Status Initialize(IteratorContext* ctx) override {
+ IteratorContext ctx_with_model(CreateParams(ctx));
+ return dataset()->input_->MakeIterator(&ctx_with_model, prefix(),
+ &input_impl_);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ IteratorContext ctx_with_model(CreateParams(ctx));
+ return input_impl_->GetNext(&ctx_with_model, out_tensors,
+ end_of_sequence);
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+ return Status::OK();
+ }
+
+ IteratorContext::Params CreateParams(IteratorContext* ctx) {
+ IteratorContext::Params params = ctx->params();
+ params.model = model_;
+ return params;
+ }
+
+ private:
+ mutex mu_;
+ std::shared_ptr<model::Model> model_;
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ };
+
+ const DatasetBase* input_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("ModelDataset").Device(DEVICE_CPU),
+ ModelDatasetOp);
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/optional_ops.cc b/tensorflow/core/kernels/data/optional_ops.cc
index b372d31a93..6180df5af2 100644
--- a/tensorflow/core/kernels/data/optional_ops.cc
+++ b/tensorflow/core/kernels/data/optional_ops.cc
@@ -231,10 +231,9 @@ static Status OptionalDeviceCopy(
return Status::OK();
}
-#define REGISTER_OPTIONAL_COPY(DIRECTION) \
- INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
- OptionalVariant, DIRECTION, kOptionalVariantTypeName, \
- OptionalDeviceCopy)
+#define REGISTER_OPTIONAL_COPY(DIRECTION) \
+ INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
+ OptionalVariant, DIRECTION, OptionalDeviceCopy)
REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
index fd0e6c4cd0..73eeafd797 100644
--- a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
@@ -207,6 +207,7 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
+ SetMetadata(ctx, "batch_size", dataset()->batch_size_);
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
}
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index 640f1565b7..aa5e613e24 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -252,6 +252,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
}
Status Initialize(IteratorContext* ctx) override {
+ SetMetadata(ctx, "parallelism", dataset()->cycle_length_);
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
return dataset()->captured_func_->Instantiate(ctx);
@@ -351,11 +352,13 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
if (must_wait_for_input) {
// Wait for elements to become available.
+ StopWork(ctx);
if (dataset()->sloppy_) {
sloppy_cond_var_.wait(l);
} else {
workers_[interleave_indices_[next_index_]].cond_var.wait(l);
}
+ StartWork(ctx);
}
}
return errors::Cancelled(
@@ -484,10 +487,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
if (reader->Contains(full_name("worker_threads_running"))) {
worker_threads_.reserve(dataset()->num_threads());
for (size_t i = 0; i < dataset()->num_threads(); ++i) {
+ std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
worker_threads_.emplace_back(ctx->env()->StartThread(
{}, "worker_thread",
- std::bind(&Iterator::WorkerThread, this,
- new IteratorContext(*ctx), i)));
+ [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
}
}
return Status::OK();
@@ -583,10 +586,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
workers_[i].SetInputs(s, std::move(args));
+ std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
worker_threads_.emplace_back(ctx->env()->StartThread(
{}, "worker_thread",
- std::bind(&Iterator::WorkerThread, this,
- new IteratorContext(*ctx), i)));
+ [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
if (i < dataset()->cycle_length_) {
interleave_indices_.push_back(i);
} else {
@@ -601,7 +604,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
}
// Produces elements into the worker's output buffers.
- void WorkerThread(IteratorContext* ctx_ptr, const int64 thread_index) {
+ void WorkerThread(const std::shared_ptr<IteratorContext>& ctx,
+ const int64 thread_index) {
// Notes on checkpointing thread local state, i.e., `WorkerThreadState`:
//
// 1. Any local state that may need to be checkpointed should be kept
@@ -622,10 +626,11 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// std::function arguments are copy-constructable, so we pass raw
// pointers, and then immediately wrap them to ensure correct ownership.
- std::unique_ptr<IteratorContext> ctx(ctx_ptr);
- auto cleanup = gtl::MakeCleanup([this, thread_index] {
+ StartWork(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, thread_index, ctx] {
mutex_lock l(mu_);
workers_[thread_index].cond_var.notify_all();
+ StopWork(ctx.get());
});
bool make_new_iterator;
{
@@ -651,9 +656,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// 1. Build a new iterator or use the existing one.
if (make_new_iterator) {
            // 1a. Get new input tensors or use the existing ones.
-
bool read_new_input;
-
{
tf_shared_lock l(ckpt_mu_);
// worker_thread_states_[thread_index].input will be non-empty
@@ -665,7 +668,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
if (read_new_input) {
mutex_lock l(mu_);
while (!cancelled_ && !workers_[thread_index].is_producing) {
+ StopWork(ctx.get());
workers_[thread_index].cond_var.wait(l);
+ StartWork(ctx.get());
}
if (cancelled_) return;
// Copy the input tensors so that we do not need to block on `mu_`
@@ -715,7 +720,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// Wait for space in the prefetch queue.
while (!cancelled_ && workers_[thread_index].outputs.size() ==
dataset()->buffer_output_elements_) {
+ StopWork(ctx.get());
workers_[thread_index].cond_var.wait(l);
+ StartWork(ctx.get());
}
if (cancelled_) return;
tf_shared_lock ckpt_l(ckpt_mu_);
@@ -764,7 +771,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// Wait for space in the prefetch queue.
while (!cancelled_ && workers_[thread_index].outputs.size() ==
dataset()->buffer_output_elements_) {
+ StopWork(ctx.get());
workers_[thread_index].cond_var.wait(l);
+ StartWork(ctx.get());
}
if (cancelled_) return;
@@ -1241,6 +1250,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
}
Status Initialize(IteratorContext* ctx) override {
+ SetMetadata(ctx, "parallelism", dataset()->num_parallel_calls_);
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
return dataset()->captured_func_->Instantiate(ctx);
@@ -1256,7 +1266,9 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
EnsureRunnerThreadStarted(ctx);
while (invocation_results_.empty() &&
(!end_of_input_ || num_open_ > 0)) {
+ StopWork(ctx);
cond_var_.wait(l);
+ StartWork(ctx);
}
if (!invocation_results_.empty()) {
std::swap(result, invocation_results_.front());
@@ -1267,7 +1279,9 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
}
}
cond_var_.notify_all();
+ StopWork(ctx);
result->notification.WaitForNotification();
+ StartWork(ctx);
} while (result->skip);
if (result->status.ok()) {
@@ -1391,6 +1405,8 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
const std::shared_ptr<IteratorContext>& ctx, int64 cycle_index,
const std::vector<std::shared_ptr<InvocationResult>>& results)
LOCKS_EXCLUDED(mu_) {
+ StartWork(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); });
bool end_of_input = false;
for (auto& result : results) {
if (!end_of_input) {
@@ -1433,6 +1449,8 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
//
// This method runs in the `runner_thread` background thread.
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
+ StartWork(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); });
while (true) {
{
mutex_lock l(mu_);
@@ -1443,7 +1461,9 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
(element_in_use_[cycle_index_] ||
num_calls_ >= dataset()->num_parallel_calls_ ||
invocation_results_.size() >= MaxInvocationResults())) {
+ StopWork(ctx.get());
cond_var_.wait(l);
+ StartWork(ctx.get());
}
if (cancelled_ || (end_of_input_ && num_open_ == 0)) {
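
Several hunks above replace std::bind(&Iterator::WorkerThread, this, new IteratorContext(*ctx), i) with a lambda that captures a std::shared_ptr<IteratorContext>, tying the context's lifetime to the thread body instead of relying on manual deletion of a raw pointer. A std::thread analogue of the same pattern, using placeholder types:

#include <memory>
#include <thread>

struct IterCtx { int id; };

void Worker(const std::shared_ptr<IterCtx>& ctx, int thread_index) { /* ... */ }

std::thread StartWorker(const IterCtx& ctx, int thread_index) {
  auto new_ctx = std::make_shared<IterCtx>(ctx);  // copy once, share ownership
  // new_ctx stays alive exactly as long as the thread body needs it.
  return std::thread([new_ctx, thread_index]() { Worker(new_ctx, thread_index); });
}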
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
index a0cb179eb8..0795987431 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
@@ -97,31 +97,26 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
return captured_func_->Instantiate(ctx);
};
- ParallelMapIteratorFunction map_func;
- if (use_inter_op_parallelism_) {
- map_func = [this](IteratorContext* ctx,
- std::vector<Tensor> input_element,
- std::vector<Tensor>* result, StatusCallback done) {
- captured_func_->RunAsync(ctx, std::move(input_element), result,
- std::move(done));
- };
- } else {
- map_func = [this](IteratorContext* ctx,
- std::vector<Tensor> input_element,
- std::vector<Tensor>* result, StatusCallback done) {
- (*ctx->runner())(std::bind(
- [this, ctx, result](std::vector<Tensor>& input_element,
- StatusCallback& done) {
- captured_func_->RunAsync(ctx, std::move(input_element), result,
- std::move(done));
- },
- std::move(input_element), std::move(done)));
+ const string& new_prefix = strings::StrCat(prefix, "::ParallelMap");
+ ParallelMapIteratorFunction map_func =
+ [this, new_prefix](IteratorContext* ctx,
+ std::vector<Tensor> input_element,
+ std::vector<Tensor>* result, StatusCallback done) {
+ captured_func_->RunAsync(ctx, std::move(input_element), result,
+ std::move(done), new_prefix);
+ };
+ if (!use_inter_op_parallelism_) {
+ map_func = [map_func](
+ IteratorContext* ctx, std::vector<Tensor> input_element,
+ std::vector<Tensor>* result, StatusCallback done) {
+ (*ctx->runner())(std::bind(map_func, ctx, std::move(input_element),
+ result, std::move(done)));
};
}
- return NewParallelMapIterator(
- {this, strings::StrCat(prefix, "::ParallelMap")}, input_,
- std::move(init_func), std::move(map_func), num_parallel_calls_);
+ return NewParallelMapIterator({this, new_prefix}, input_,
+ std::move(init_func), std::move(map_func),
+ num_parallel_calls_);
}
const DataTypeVector& output_dtypes() const override {
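
With this change there is a single map function; when inter-op parallelism is disabled it is simply wrapped so the call is re-dispatched through the context's runner rather than invoked inline. A generic sketch of that wrapping, assuming simplified std::function signatures:

#include <functional>
#include <utility>

using Runner = std::function<void(std::function<void()>)>;
using MapFn = std::function<void(int /*element*/, std::function<void()> /*done*/)>;

// Returns a map function that schedules the original one on the runner.
MapFn WrapWithRunner(MapFn map_fn, Runner* runner) {
  return [map_fn, runner](int element, std::function<void()> done) {
    (*runner)(std::bind(map_fn, element, std::move(done)));
  };
}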
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc
index 4ae742aaaf..0b6e587881 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.cc
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc
@@ -19,6 +19,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "tensorflow/core/lib/gtl/cleanup.h"
+
namespace tensorflow {
namespace data {
namespace {
@@ -53,6 +55,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
Status Initialize(IteratorContext* ctx) override {
+ SetMetadata(ctx, "parallelism", num_parallel_calls_);
TF_RETURN_IF_ERROR(
input_dataset_->MakeIterator(ctx, prefix(), &input_impl_));
if (init_func_) {
@@ -68,13 +71,17 @@ class ParallelMapIterator : public DatasetBaseIterator {
mutex_lock l(mu_);
EnsureRunnerThreadStarted(ctx);
while (invocation_results_.empty()) {
+ StopWork(ctx);
cond_var_.wait(l);
+ StartWork(ctx);
}
std::swap(result, invocation_results_.front());
invocation_results_.pop_front();
}
cond_var_.notify_all();
+ StopWork(ctx);
result->notification.WaitForNotification();
+ StartWork(ctx);
return ProcessResult(result, out_tensors, end_of_sequence);
}
@@ -87,9 +94,8 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
CHECK_EQ(num_calls_, 0);
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
- TF_RETURN_IF_ERROR(
- writer->WriteScalar(full_name("invocation_results.size"),
- invocation_results_.size()));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("invocation_results.size"),
+ invocation_results_.size()));
for (size_t i = 0; i < invocation_results_.size(); i++) {
std::shared_ptr<InvocationResult> result = invocation_results_[i];
TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status));
@@ -226,6 +232,8 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
+ StartWork(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); });
std::vector<std::shared_ptr<InvocationResult>> new_calls;
new_calls.reserve(num_parallel_calls_);
while (true) {
@@ -234,7 +242,9 @@ class ParallelMapIterator : public DatasetBaseIterator {
while (!cancelled_ &&
(num_calls_ >= num_parallel_calls_ ||
invocation_results_.size() >= MaxInvocationResults())) {
+ StopWork(ctx.get());
cond_var_.wait(l);
+ StartWork(ctx.get());
}
if (cancelled_) {
return;
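
The StopWork/StartWork pairs added throughout these iterators bracket every blocking wait, so time spent parked on a condition variable is not charged to the iterator as processing time by the new performance model. A generic sketch of the bracketing idea, with a hypothetical WorkClock standing in for the iterator hooks:

#include <condition_variable>
#include <mutex>

struct WorkClock {
  void StartWork() { /* record resume timestamp */ }
  void StopWork()  { /* record pause timestamp  */ }
};

template <typename Pred>
void WaitWithoutCountingAsWork(std::condition_variable& cv,
                               std::unique_lock<std::mutex>& l, Pred ready,
                               WorkClock& clock) {
  while (!ready()) {
    clock.StopWork();   // pause the processing-time clock before blocking
    cv.wait(l);
    clock.StartWork();  // resume once the lock is reacquired
  }
}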
diff --git a/tensorflow/core/kernels/data/prefetch_autotuner.cc b/tensorflow/core/kernels/data/prefetch_autotuner.cc
index 533d0bd5d2..da357339c9 100644
--- a/tensorflow/core/kernels/data/prefetch_autotuner.cc
+++ b/tensorflow/core/kernels/data/prefetch_autotuner.cc
@@ -26,6 +26,13 @@ PrefetchAutotuner::PrefetchAutotuner(int64 initial_buffer_size)
}
}
+namespace {
+// Determines what strategy to use for increasing the buffer size limit. For
+// limits less than the threshold, an exponential increase is used, while for
+// limits greater than or equal to the threshold, a linear increase is used.
+size_t kBufferLimitThreshold = 2048;
+} // namespace
+
void PrefetchAutotuner::RecordConsumption(size_t current_buffer_size) {
switch (mode_) {
case Mode::kDisabled:
@@ -37,7 +44,11 @@ void PrefetchAutotuner::RecordConsumption(size_t current_buffer_size) {
return;
case Mode::kDownswing:
if (current_buffer_size == 0) {
- buffer_limit_ *= 2; // Increase the buffer size.
+ if (buffer_limit_ >= kBufferLimitThreshold) {
+ buffer_limit_ += kBufferLimitThreshold;
+ } else {
+ buffer_limit_ *= 2;
+ }
mode_ = Mode::kUpswing;
}
return;
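
The autotuner now doubles the buffer limit while it is below the 2048 threshold and grows it by 2048 once it reaches it, e.g. starting from 1: 1 -> 2 -> 4 -> ... -> 2048 -> 4096 -> 6144. The same update rule as a standalone sketch:

#include <cstddef>

constexpr std::size_t kThreshold = 2048;

std::size_t NextBufferLimit(std::size_t limit) {
  return limit >= kThreshold ? limit + kThreshold   // linear growth region
                             : limit * 2;           // exponential growth region
}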
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
index baf448e572..52c421caee 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
@@ -12,13 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <deque>
-
#include "tensorflow/core/kernels/data/prefetch_dataset_op.h"
+#include <deque>
+
#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
namespace data {
@@ -71,7 +74,11 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
- auto_tuner_(params.dataset->buffer_size_) {}
+ auto_tuner_(params.dataset->buffer_size_) {
+ std::vector<string> components =
+ str_util::Split(params.prefix, "::", str_util::SkipEmpty());
+ prefix_end_ = components.back();
+ }
~Iterator() override {
// Signal the prefetch thread to terminate it. We will then
@@ -98,13 +105,16 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
bool* end_of_sequence) override {
{
mutex_lock l(mu_);
+ auto stats_aggregator = ctx->stats_aggregator();
TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx));
// Wait until the next element in the buffer has been
// produced, or we are shutting down.
while (!cancelled_ && buffer_.empty() && !prefetch_thread_finished_ &&
auto_tuner_.buffer_limit() != 0) {
auto_tuner_.RecordEmpty();
+ StopWork(ctx);
cond_var_.wait(l);
+ StartWork(ctx);
}
if (cancelled_) {
@@ -113,7 +123,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
}
if (!buffer_.empty()) {
- return Consume(out_tensors, end_of_sequence);
+ return Consume(out_tensors, end_of_sequence, stats_aggregator);
}
if (prefetch_thread_finished_) {
@@ -201,8 +211,15 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
std::vector<Tensor> value;
};
- Status Consume(std::vector<Tensor>* out_tensors, bool* end_of_sequence)
+ Status Consume(std::vector<Tensor>* out_tensors, bool* end_of_sequence,
+ const std::shared_ptr<StatsAggregator>& stats_aggregator)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (stats_aggregator) {
+ stats_aggregator->AddToHistogram(
+ strings::StrCat(prefix_end_, "::buffer_utilization"),
+ {static_cast<float>(buffer_.size()) /
+ static_cast<float>(auto_tuner_.buffer_limit())});
+ }
// A new element is available. Forward the status from computing it, and
// (if we successfully got an element) the output values.
Status s = buffer_.front().status;
@@ -225,10 +242,10 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
Status EnsurePrefetchThreadStarted(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (!prefetch_thread_) {
- prefetch_thread_.reset(
- ctx->env()->StartThread({}, "prefetch_thread",
- std::bind(&Iterator::PrefetchThread, this,
- new IteratorContext(*ctx))));
+ std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
+ prefetch_thread_.reset(ctx->env()->StartThread(
+ {}, "prefetch_thread",
+ [this, new_ctx]() { PrefetchThread(new_ctx); }));
}
return Status::OK();
}
@@ -237,8 +254,9 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
// buffer.
//
// It owns the iterator context passed to it.
- void PrefetchThread(IteratorContext* ctx) {
- std::unique_ptr<IteratorContext> cleanup(ctx);
+ void PrefetchThread(const std::shared_ptr<IteratorContext>& ctx) {
+ StartWork(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); });
while (true) {
std::vector<Tensor> value;
@@ -246,7 +264,9 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
{
mutex_lock l(mu_);
while (!cancelled_ && buffer_.size() >= auto_tuner_.buffer_limit()) {
+ StopWork(ctx.get());
cond_var_.wait(l);
+ StartWork(ctx.get());
}
if (cancelled_) {
@@ -263,8 +283,8 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
mutex_lock parent_l(parent_mu_);
bool end_of_sequence;
BufferElement buffer_element;
- buffer_element.status =
- input_impl_->GetNext(ctx, &buffer_element.value, &end_of_sequence);
+ buffer_element.status = input_impl_->GetNext(
+ ctx.get(), &buffer_element.value, &end_of_sequence);
if (buffer_element.status.ok() && end_of_sequence) {
mutex_lock l(mu_);
prefetch_thread_finished_ = true;
@@ -326,6 +346,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
mutex parent_mu_ ACQUIRED_BEFORE(mu_);
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(parent_mu_);
condition_variable cond_var_;
+ string prefix_end_;
PrefetchAutotuner auto_tuner_ GUARDED_BY(mu_);
std::deque<BufferElement> buffer_ GUARDED_BY(mu_);
std::unique_ptr<Thread> prefetch_thread_ GUARDED_BY(mu_);
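
The prefetch iterator now records a buffer-utilization histogram keyed by the last "::"-separated component of its prefix, with the recorded value being the buffer fill ratio at consume time. A self-contained sketch of those two pieces using only the standard library (the real code uses str_util::Split and the StatsAggregator API):

#include <cstddef>
#include <string>

// Last "::"-separated component, e.g. "Iterator::Model::Prefetch" -> "Prefetch".
std::string PrefixEnd(const std::string& prefix) {
  const std::size_t pos = prefix.rfind("::");
  return pos == std::string::npos ? prefix : prefix.substr(pos + 2);
}

// Fraction of the autotuned limit currently buffered, e.g. 3 of 8 -> 0.375.
float BufferUtilization(std::size_t buffered, std::size_t limit) {
  return limit == 0 ? 0.0f
                    : static_cast<float>(buffered) / static_cast<float>(limit);
}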
diff --git a/tensorflow/core/kernels/decode_bmp_op.cc b/tensorflow/core/kernels/decode_bmp_op.cc
index b4dcf0a74b..ae451be7e2 100644
--- a/tensorflow/core/kernels/decode_bmp_op.cc
+++ b/tensorflow/core/kernels/decode_bmp_op.cc
@@ -91,8 +91,10 @@ class DecodeBmpOp : public OpKernel {
errors::InvalidArgument(
"Number of channels must be 1, 3 or 4, was ", channels_));
- OP_REQUIRES(context, width > 0 && header_size >= 0,
+ OP_REQUIRES(context, width > 0,
errors::InvalidArgument("Width must be positive"));
+ OP_REQUIRES(context, height != 0,
+ errors::InvalidArgument("Height must be nonzero"));
OP_REQUIRES(context, header_size >= 0,
errors::InvalidArgument("header size must be nonnegative"));
@@ -108,8 +110,7 @@ class DecodeBmpOp : public OpKernel {
const int32 abs_height = abs(height);
// there may be padding bytes when the width is not a multiple of 4 bytes
- // 8 * channels == bits per pixel
- const int row_size = (8 * channels_ * width + 31) / 32 * 4;
+ const int row_size = (channels_ * width + 3) / 4 * 4;
const int64 last_pixel_offset = static_cast<int64>(header_size) +
(abs_height - 1) * row_size +
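
The simplified row stride is the usual BMP rule of padding each row of channels_ * width payload bytes up to a 4-byte boundary, and it agrees with the old bit-based formula for whole-byte channels. A few compile-time checks of the arithmetic:

constexpr int RowSize(int channels, int width) {
  return (channels * width + 3) / 4 * 4;
}

static_assert(RowSize(3, 5) == 16, "15 payload bytes pad to 16");
static_assert(RowSize(1, 7) == 8,  "7 payload bytes pad to 8");
static_assert(RowSize(4, 6) == 24, "24 payload bytes need no padding");
static_assert((8 * 3 * 5 + 31) / 32 * 4 == RowSize(3, 5), "matches the old formula");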
diff --git a/tensorflow/core/kernels/decode_csv_op.cc b/tensorflow/core/kernels/decode_csv_op.cc
index 3eed847c16..6bfb5bd5bc 100644
--- a/tensorflow/core/kernels/decode_csv_op.cc
+++ b/tensorflow/core/kernels/decode_csv_op.cc
@@ -61,6 +61,9 @@ class DecodeCSVOp : public OpKernel {
OP_REQUIRES_OK(ctx, ctx->input_list("record_defaults", &record_defaults));
for (int i = 0; i < record_defaults.size(); ++i) {
+ OP_REQUIRES(ctx, record_defaults[i].dims() <= 1,
+ errors::InvalidArgument(
+ "Each record default should be at most rank 1"));
OP_REQUIRES(ctx, record_defaults[i].NumElements() < 2,
errors::InvalidArgument(
"There should only be 1 default per field but field ", i,
diff --git a/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h
index 27918b410b..8edf7d4a2c 100644
--- a/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h
+++ b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h
@@ -59,12 +59,12 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional<
const array<
typename internal::traits<OutputBackward>::Index, 5>,
const TensorReverseOp<const Eigen::array<bool, 5>,
- const Kernel> > > >,
+ const Kernel>>>>,
const TensorReshapingOp<
const DSizes<typename internal::traits<OutputBackward>::Index,
2>,
const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
- const OutputBackward> > > >,
+ const OutputBackward>>>>,
TensorReshapingOp<
const DSizes<typename internal::traits<OutputBackward>::Index,
internal::traits<OutputBackward>::NumDimensions>,
@@ -75,7 +75,7 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional<
const DSizes<typename internal::traits<OutputBackward>::Index,
2>,
const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
- const OutputBackward> >,
+ const OutputBackward>>,
const Eigen::TensorForcedEvalOp<const TensorReshapingOp<
const DSizes<typename internal::traits<OutputBackward>::Index,
2>,
@@ -83,7 +83,7 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional<
const array<
typename internal::traits<OutputBackward>::Index, 5>,
const TensorReverseOp<const Eigen::array<bool, 5>,
- const Kernel> > > > > > >::type
+ const Kernel>>>>>>>::type
CuboidConvolutionBackwardInput(
const Kernel& kernel, const OutputBackward& output_backward,
typename internal::traits<OutputBackward>::Index inputPlanes,
@@ -94,12 +94,12 @@ CuboidConvolutionBackwardInput(
typedef typename internal::traits<OutputBackward>::Index TensorIndex;
const TensorRef<const Tensor<typename internal::traits<Kernel>::Scalar,
internal::traits<Kernel>::NumDimensions,
- internal::traits<Kernel>::Layout, TensorIndex> >
+ internal::traits<Kernel>::Layout, TensorIndex>>
kern(kernel);
const TensorRef<
const Tensor<typename internal::traits<OutputBackward>::Scalar,
internal::traits<OutputBackward>::NumDimensions,
- internal::traits<OutputBackward>::Layout, TensorIndex> >
+ internal::traits<OutputBackward>::Layout, TensorIndex>>
out(output_backward);
EIGEN_STATIC_ASSERT(internal::traits<Kernel>::Layout ==
@@ -239,8 +239,8 @@ CuboidConvolutionBackwardInput(
}
}
- // We will contract along the fused dimension that contains the kernelFilters,
- // kernelPlanes, kernelRows and kernelCols.
+ // We will contract along the collapsed dimension that contains the
+ // kernelFilters, kernelPlanes, kernelRows and kernelCols.
array<IndexPair<TensorIndex>, 1> contract_dims;
if (isColMajor) {
// col-major: kernel.contract(output.patches)
@@ -323,35 +323,69 @@ CuboidConvolutionBackwardInput(
*/
template <typename OutputBackward, typename Input>
EIGEN_ALWAYS_INLINE static const typename internal::conditional<
- internal::traits<OutputBackward>::Layout == ColMajor,
- TensorReshapingOp<
- const DSizes<typename internal::traits<Input>::Index, 5>,
- const TensorContractionOp<
- const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
- const TensorReshapingOp<
- const DSizes<typename internal::traits<Input>::Index, 2>,
- const OutputBackward>,
- const TensorShufflingOp<
- const array<typename internal::traits<OutputBackward>::Index,
- 2>,
- const TensorReshapingOp<
- const DSizes<typename internal::traits<Input>::Index, 2>,
- const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
- const Input> > > > >,
- TensorReshapingOp<
- const DSizes<typename internal::traits<Input>::Index, 5>,
- const TensorContractionOp<
- const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
- const TensorShufflingOp<
- const array<typename internal::traits<OutputBackward>::Index,
- 2>,
- const TensorReshapingOp<
- const DSizes<typename internal::traits<Input>::Index, 2>,
- const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
- const Input> > >,
- const TensorReshapingOp<
- const DSizes<typename internal::traits<Input>::Index, 2>,
- const OutputBackward> > > >::type
+ internal::traits<Input>::Layout == ColMajor,
+ const TensorReverseOp<
+ const Eigen::array<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const Eigen::TensorShufflingOp<
+ const Eigen::array<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const Eigen::TensorReshapingOp<
+ const Eigen::DSizes<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const TensorContractionOp<
+ const array<
+ IndexPair<typename internal::traits<Input>::Index>, 1>,
+ const Eigen::TensorForcedEvalOp<const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index,
+ 2>,
+ const Eigen::TensorShufflingOp<
+ const Eigen::array<
+ typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const OutputBackward>>>,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index,
+ 2>,
+ const TensorVolumePatchOp<
+ Dynamic, Dynamic, Dynamic,
+ const Eigen::TensorForcedEvalOp<
+ const Eigen::TensorShufflingOp<
+ const Eigen::array<
+ typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const Input>>>>>>>>,
+ const TensorReverseOp<
+ const Eigen::array<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const Eigen::TensorShufflingOp<
+ const Eigen::array<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const Eigen::TensorReshapingOp<
+ const Eigen::DSizes<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const TensorContractionOp<
+ const array<
+ IndexPair<typename internal::traits<Input>::Index>, 1>,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index,
+ 2>,
+ const TensorVolumePatchOp<
+ Dynamic, Dynamic, Dynamic,
+ const Eigen::TensorForcedEvalOp<
+ const Eigen::TensorShufflingOp<
+ const Eigen::array<
+ typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const Input>>>>,
+ const Eigen::TensorForcedEvalOp<const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index,
+ 2>,
+ const Eigen::TensorShufflingOp<
+ const Eigen::array<
+ typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const OutputBackward>>>>>>>>::type
CuboidConvolutionBackwardKernel(
const Input& input, const OutputBackward& output_backward,
typename internal::traits<Input>::Index kernelPlanes,
@@ -362,11 +396,11 @@ CuboidConvolutionBackwardKernel(
typedef typename internal::traits<Input>::Index TensorIndex;
TensorRef<Tensor<typename internal::traits<Input>::Scalar,
internal::traits<Input>::NumDimensions,
- internal::traits<Input>::Layout, TensorIndex> >
+ internal::traits<Input>::Layout, TensorIndex>>
in(input);
TensorRef<Tensor<typename internal::traits<OutputBackward>::Scalar,
internal::traits<OutputBackward>::NumDimensions,
- internal::traits<OutputBackward>::Layout, TensorIndex> >
+ internal::traits<OutputBackward>::Layout, TensorIndex>>
out(output_backward);
EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout ==
@@ -380,6 +414,13 @@ CuboidConvolutionBackwardKernel(
internal::traits<OutputBackward>::NumDimensions,
YOU_MADE_A_PROGRAMMING_MISTAKE);
+  // We do not support higher-dimensional backward convolutions, or convolutions
+  // without a batch dimension.
+ // TODO(ezhulenev): Relax this constraint, and turn on tests without batch
+ // dimension in eigen_backward_cuboid_convolutions_test.cc.
+ EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions == 5,
+ YOU_MADE_A_PROGRAMMING_MISTAKE);
+
const TensorIndex inputPlanes =
isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
const TensorIndex inputRows =
@@ -401,6 +442,10 @@ CuboidConvolutionBackwardKernel(
const TensorIndex kernelChannels =
isColMajor ? in.dimension(0) : in.dimension(NumDims - 1);
+ // Number of batches in the input tensor.
+ const TensorIndex batch =
+ isColMajor ? in.dimension(4) : in.dimension(NumDims - 5);
+
// TODO(ezhulenev): Add support for inflated strides. Without inflated strides
// effective kernel planes/rows/cols are always the same as the kernel itself
// (see eigen_spatial_convolutions for details).
@@ -408,6 +453,7 @@ CuboidConvolutionBackwardKernel(
const TensorIndex kernelRowsEff = kernelRows;
const TensorIndex kernelColsEff = kernelCols;
+ // Compute forward padding from input and output_backward dimensions.
const TensorIndex padPlanes = numext::maxi<Index>(
0, (outputPlanes - 1) * stridePlanes + kernelPlanesEff - inputPlanes);
const TensorIndex padRows = numext::maxi<Index>(
@@ -416,92 +462,147 @@ CuboidConvolutionBackwardKernel(
0, (outputCols - 1) * strideCols + kernelColsEff - inputCols);
const TensorIndex padding_top_z = padPlanes / 2;
- const TensorIndex padding_bottom_z = padPlanes - padding_top_z;
const TensorIndex padding_top = padRows / 2;
- const TensorIndex padding_bottom = padRows - padding_top;
const TensorIndex padding_left = padCols / 2;
- const TensorIndex padding_right = padCols - padding_left;
- // Reshaped output_backward before contraction.
- DSizes<TensorIndex, 2> output_dims;
+ // Compute paddings for output_backward before extracting patches.
+ const auto expanded_out_planes = (outputPlanes - 1) * stridePlanes + 1;
+ const auto expanded_out_rows = (outputRows - 1) * strideRows + 1;
+ const auto expanded_out_cols = (outputCols - 1) * strideCols + 1;
+ const auto padded_out_planes = inputPlanes + kernelPlanes - 1;
+ const auto padded_out_rows = inputRows + kernelRows - 1;
+ const auto padded_out_cols = inputCols + kernelCols - 1;
+ const auto top_pad_planes = kernelPlanes - 1 - padding_top_z;
+ const auto top_pad_rows = kernelRows - 1 - padding_top;
+ const auto left_pad_cols = kernelCols - 1 - padding_left;
+ const auto bottom_pad_planes =
+ padded_out_planes - expanded_out_planes - top_pad_planes;
+ const auto bottom_pad_rows =
+ padded_out_rows - expanded_out_rows - top_pad_rows;
+ const auto right_pad_cols =
+ padded_out_cols - expanded_out_cols - left_pad_cols;
+
+ // Reorder output_backward dimensions.
+ array<TensorIndex, 5> output_backward_shuffle;
if (isColMajor) {
- output_dims[0] = kernelFilters;
- output_dims[1] = outputPlanes * outputRows * outputCols;
- for (int i = 4; i < NumDims; ++i) {
- output_dims[1] *= out.dimension(i);
- }
+ // From: [out_depth, out_planes, out_rows, out_cols, batch]
+ // To: [batch, out_planes, out_rows, out_cols, out_depth]
+ output_backward_shuffle = {4, 1, 2, 3, 0};
} else {
- output_dims[1] = kernelFilters;
- output_dims[0] = outputCols * outputRows * outputPlanes;
- for (int i = 0; i < NumDims - 4; ++i) {
- output_dims[0] *= out.dimension(i);
- }
+ // From: [batch, out_cols, out_rows, out_planes, out_depth]
+ // To: [out_depth, out_cols, out_rows, out_planes, batch]
+ output_backward_shuffle = {4, 1, 2, 3, 0};
}
- // Reshaped extract_volume_patches(in)
- DSizes<TensorIndex, 2> pre_contract_dims;
+ // Reorder input dimensions.
+ array<TensorIndex, 5> input_shuffle;
if (isColMajor) {
- pre_contract_dims[0] =
- kernelChannels * kernelPlanes * kernelRows * kernelCols;
- pre_contract_dims[1] = outputPlanes * outputRows * outputCols;
- for (int i = 4; i < NumDims; ++i) {
- pre_contract_dims[1] *= in.dimension(i);
- }
- eigen_assert(output_dims[1] == pre_contract_dims[1]);
+ // From: [in_depth, in_planes, in_rows, in_cols, batch]
+ // To: [in_depth, batch, in_planes, in_rows, in_cols]
+ input_shuffle = {0, 4, 1, 2, 3};
} else {
- pre_contract_dims[1] =
- kernelCols * kernelRows * kernelPlanes * kernelChannels;
- pre_contract_dims[0] = outputCols * outputRows * outputPlanes;
- for (int i = 0; i < NumDims - 4; ++i) {
- pre_contract_dims[0] *= in.dimension(i);
- }
- eigen_assert(output_dims[0] == pre_contract_dims[0]);
+ // From: [batch, in_cols, in_rows, in_planes, in_depth]
+ // To: [in_cols, in_rows, in_planes, batch, in_depth]
+ input_shuffle = {1, 2, 3, 0, 4};
}
- array<TensorIndex, 2> shuffle_dims;
- shuffle_dims[0] = 1;
- shuffle_dims[1] = 0;
+ // Input is playing the role of a "kernel" in this convolution.
+ DSizes<TensorIndex, 2> input_dims;
+ if (isColMajor) {
+ input_dims[0] = kernelChannels;
+ input_dims[1] = batch * inputPlanes * inputRows * inputCols;
+ } else {
+ input_dims[1] = kernelChannels;
+ input_dims[0] = inputCols * inputRows * inputPlanes * batch;
+ }
+ // Molds the result of the patch extraction into a 2D tensor:
+ // - the first dimension (dims[0]): the patch values to be multiplied with the
+ // kernels
+ // - the second dimension (dims[1]): everything else
+ DSizes<TensorIndex, 2> pre_contract_dims;
+ if (isColMajor) {
+ pre_contract_dims[0] = batch * inputPlanes * inputRows * inputCols;
+ pre_contract_dims[1] =
+ kernelPlanes * kernelRows * kernelCols * kernelFilters;
+ } else {
+ pre_contract_dims[1] = inputCols * inputRows * inputPlanes * batch;
+ pre_contract_dims[0] =
+ kernelFilters * kernelCols * kernelRows * kernelPlanes;
+ }
+
+ // We will contract along the collapsed dimension that contains the
+ // batch, inputPlanes, inputRows and inputCols.
array<IndexPair<TensorIndex>, 1> contract_dims;
contract_dims[0] = IndexPair<TensorIndex>(1, 0);
- DSizes<TensorIndex, 5> kernel_dims;
+ // Dimensions after contraction.
+ DSizes<TensorIndex, NumDims> post_contract_dims;
if (isColMajor) {
- kernel_dims[0] = kernelFilters;
- kernel_dims[1] = kernelChannels;
- kernel_dims[2] = kernelPlanes;
- kernel_dims[3] = kernelRows;
- kernel_dims[4] = kernelCols;
+ post_contract_dims[0] = kernelChannels;
+ post_contract_dims[1] = kernelPlanes;
+ post_contract_dims[2] = kernelRows;
+ post_contract_dims[3] = kernelCols;
+ post_contract_dims[4] = kernelFilters;
} else {
- kernel_dims[4] = kernelFilters;
- kernel_dims[3] = kernelChannels;
- kernel_dims[2] = kernelPlanes;
- kernel_dims[1] = kernelRows;
- kernel_dims[0] = kernelCols;
+ post_contract_dims[0] = kernelFilters;
+ post_contract_dims[1] = kernelCols;
+ post_contract_dims[2] = kernelRows;
+ post_contract_dims[3] = kernelPlanes;
+ post_contract_dims[4] = kernelChannels;
}
- return choose(
- Cond<internal::traits<Input>::Layout == ColMajor>(),
- output_backward.reshape(output_dims)
- .contract(input
+ // Reorder output of contraction to valid filter shape.
+ array<TensorIndex, 5> kernel_shuffle;
+ if (isColMajor) {
+ // From: [in_depth, kernel_planes, kernel_rows, kernel_cols, out_depth]
+ // To: [out_depth, in_depth, kernel_planes, kernel_rows, kernel_cols]
+ kernel_shuffle = {4, 0, 1, 2, 3};
+ } else {
+ // From: [out_depth, kernel_cols, kernel_rows, kernel_planes, in_depth]
+ // To: [kernel_cols, kernel_rows, kernel_planes, in_depth, out_depth]
+ kernel_shuffle = {1, 2, 3, 4, 0};
+ }
+
+ // Reverse kernel backprop dimensions.
+ array<TensorIndex, 5> kernel_reverse;
+ if (isColMajor) {
+ kernel_reverse = {false, false, true, true, true};
+ } else {
+ kernel_reverse = {true, true, true, false, false};
+ }
+
+ // Create convolution input (aka source of patches) from output backward
+ // tensor by shuffling dimensions.
+ const auto the_input =
+ output_backward.shuffle(output_backward_shuffle).eval();
+
+ // Create convolution kernel (aka filter) from input by shuffling and
+ // reshaping.
+ const auto the_kernel =
+ input.shuffle(input_shuffle).reshape(input_dims).eval();
+
+ return choose(Cond<internal::traits<Input>::Layout == ColMajor>(),
+ the_kernel.contract(
+ the_input
.extract_volume_patches(
- kernelPlanes, kernelRows, kernelCols, stridePlanes,
- strideRows, strideCols, 1, 1, 1, padding_top_z,
- padding_bottom_z, padding_top, padding_bottom,
- padding_left, padding_right)
- .reshape(pre_contract_dims)
- .shuffle(shuffle_dims),
- contract_dims)
- .reshape(kernel_dims),
- input
- .extract_volume_patches(kernelPlanes, kernelRows, kernelCols,
- stridePlanes, strideRows, strideCols, 1, 1, 1,
- padding_top_z, padding_bottom_z, padding_top,
- padding_bottom, padding_left, padding_right)
- .reshape(pre_contract_dims)
- .shuffle(shuffle_dims)
- .contract(output_backward.reshape(output_dims), contract_dims)
- .reshape(kernel_dims));
+ inputPlanes, inputRows, inputCols, 1, 1, 1,
+ stridePlanes, strideRows, strideCols,
+ top_pad_planes, bottom_pad_planes, top_pad_rows,
+ bottom_pad_rows, left_pad_cols, right_pad_cols)
+ .reshape(pre_contract_dims),
+ contract_dims),
+ the_input
+ .extract_volume_patches(
+ inputPlanes, inputRows, inputCols, 1, 1, 1,
+ stridePlanes, strideRows, strideCols, top_pad_planes,
+ bottom_pad_planes, top_pad_rows, bottom_pad_rows,
+ left_pad_cols, right_pad_cols)
+ .reshape(pre_contract_dims)
+ .contract(the_kernel, contract_dims))
+ .reshape(post_contract_dims)
+ .shuffle(kernel_shuffle)
+ .reverse(kernel_reverse);
}
} // end namespace Eigen
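Editorial sketch, not part of the patch above: the padding arithmetic in CuboidConvolutionBackwardKernel is easiest to verify with concrete numbers. The standalone snippet below (hypothetical sizes, a single spatial dimension) mirrors the expanded/padded extent computation and checks that the dilated output plus the derived pads exactly covers the extent required by patch extraction.

#include <algorithm>
#include <cassert>

int main() {
  // Hypothetical forward convolution: size-5 input, size-3 kernel, stride 2.
  const int input = 5, kernel = 3, stride = 2;
  const int output = (input - kernel) / stride + 1;                     // 2
  const int pad = std::max(0, (output - 1) * stride + kernel - input);  // forward padding
  const int pad_top = pad / 2;
  const int expanded_out = (output - 1) * stride + 1;  // output dilated by the stride
  const int padded_out = input + kernel - 1;           // extent seen by patch extraction
  const int top_pad = kernel - 1 - pad_top;
  const int bottom_pad = padded_out - expanded_out - top_pad;
  // The dilated output plus both pads must cover exactly padded_out elements.
  assert(top_pad + expanded_out + bottom_pad == padded_out);
  return 0;
}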
diff --git a/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h
index 8d06107553..960920c55b 100644
--- a/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h
+++ b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h
@@ -238,8 +238,8 @@ SpatialConvolutionBackwardInput(
}
}
- // We will contract along the fused dimension that contains the kernelFilters,
- // the kernelRows and the kernelCols.
+ // We will contract along the collapsed dimension that contains the
+ // kernelFilters, the kernelRows and the kernelCols.
array<IndexPair<TensorIndex>, 1> contract_dims;
if (isColMajor) {
// col-major: kernel.contract(output.patches)
@@ -332,23 +332,16 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional<
const TensorReshapingOp<
const DSizes<typename internal::traits<Input>::Index, 2>,
const OutputBackward>,
- const TensorShufflingOp<
- const array<typename internal::traits<OutputBackward>::Index,
- 2>,
- const TensorReshapingOp<
- const DSizes<typename internal::traits<Input>::Index, 2>,
- const TensorImagePatchOp<Dynamic, Dynamic,
- const Input> > > > >,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index, 2>,
+ const TensorImagePatchOp<Dynamic, Dynamic, const Input> > > >,
TensorReshapingOp<
const DSizes<typename internal::traits<Input>::Index, 4>,
const TensorContractionOp<
const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
- const TensorShufflingOp<
- const array<typename internal::traits<OutputBackward>::Index,
- 2>,
- const TensorReshapingOp<
- const DSizes<typename internal::traits<Input>::Index, 2>,
- const TensorImagePatchOp<Dynamic, Dynamic, const Input> > >,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index, 2>,
+ const TensorImagePatchOp<Dynamic, Dynamic, const Input> >,
const TensorReshapingOp<
const DSizes<typename internal::traits<Input>::Index, 2>,
const OutputBackward> > > >::type
@@ -456,12 +449,16 @@ SpatialConvolutionBackwardKernel(
eigen_assert(output_dims[0] == pre_contract_dims[0]);
}
- array<TensorIndex, 2> shuffle_dims;
- shuffle_dims[0] = 1;
- shuffle_dims[1] = 0;
-
+ // We will contract along the collapsed dimension that contains the
+ // outputCols, outputRows and OTHERS.
array<IndexPair<TensorIndex>, 1> contract_dims;
- contract_dims[0] = IndexPair<TensorIndex>(1, 0);
+ if (isColMajor) {
+ // col-major: output_backward.contract(input.patches)
+ contract_dims[0] = IndexPair<TensorIndex>(1, 1);
+ } else {
+ // row-major: input.patches.contract(output_backward)
+ contract_dims[0] = IndexPair<TensorIndex>(0, 0);
+ }
// After the contraction, the kernel will have the desired shape
// out_depth X in_shape X kernel_rows X kernel_cols
@@ -487,8 +484,7 @@ SpatialConvolutionBackwardKernel(
kernelRows, kernelCols, row_stride, col_stride,
row_in_stride, col_in_stride, 1, 1, padding_top,
padding_bottom, padding_left, padding_right, OutScalar(0))
- .reshape(pre_contract_dims)
- .shuffle(shuffle_dims),
+ .reshape(pre_contract_dims),
contract_dims)
.reshape(kernel_dims),
input
@@ -497,7 +493,6 @@ SpatialConvolutionBackwardKernel(
padding_top, padding_bottom, padding_left,
padding_right, OutScalar(0))
.reshape(pre_contract_dims)
- .shuffle(shuffle_dims)
.contract(output_backward.reshape(output_dims), contract_dims)
.reshape(kernel_dims));
}
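Editorial sketch, not part of the patch above: the switch from an explicit shuffle to layout-dependent IndexPair choices relies on Eigen's tensor contraction being able to contract any pair of axes directly. The standalone example below (hypothetical shapes) contracts along (1, 1), which is equivalent to first transposing the right operand and then contracting along (1, 0).

#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  Eigen::Tensor<float, 2> a(2, 3);  // stands in for output_backward.reshape(...)
  Eigen::Tensor<float, 2> b(4, 3);  // stands in for the reshaped input patches
  a.setRandom();
  b.setRandom();
  // Contract a's dimension 1 with b's dimension 1; no shuffle of b is needed.
  Eigen::array<Eigen::IndexPair<int>, 1> dims = {{Eigen::IndexPair<int>(1, 1)}};
  Eigen::Tensor<float, 2> c = a.contract(b, dims);
  std::cout << c.dimension(0) << " x " << c.dimension(1) << std::endl;  // prints "2 x 4"
  return 0;
}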
diff --git a/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc b/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc
index 2229ec9659..673ec1458b 100644
--- a/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc
+++ b/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc
@@ -1248,11 +1248,14 @@ TEST(EigenBackwardSpatialConvolutionsTest,
const int output_cols = input_cols - patch_cols + 1;
const int output_planes = input_planes - patch_planes + 1;
- Tensor<float, 4> input(input_depth, input_planes, input_rows, input_cols);
+ // TODO(ezhulenev): Support backward kernel convolution without batch
+ // dimension.
+ Tensor<float, 5> input(input_depth, input_planes, input_rows, input_cols,
+ /*num_batches*/ 1);
Tensor<float, 5> kernel(output_depth, input_depth, patch_planes, patch_rows,
patch_cols);
- Tensor<float, 4> output_backward(output_depth, output_planes, output_rows,
- output_cols);
+ Tensor<float, 5> output_backward(output_depth, output_planes, output_rows,
+ output_cols, /*num_batches*/ 1);
output_backward = output_backward.constant(11.0f) + output_backward.random();
input = input.constant(2.0f) + input.random();
@@ -1282,9 +1285,9 @@ TEST(EigenBackwardSpatialConvolutionsTest,
if (output_i >= 0 && output_i < output_planes &&
output_j >= 0 && output_j < output_rows &&
output_k >= 0 && output_k < output_cols) {
- expected +=
- input(id, i, j, k) *
- output_backward(od, output_i, output_j, output_k);
+ expected += input(id, i, j, k, /*batch*/ 0) *
+ output_backward(od, output_i, output_j,
+ output_k, /*batch*/ 0);
}
}
}
@@ -1311,12 +1314,14 @@ TEST(EigenBackwardSpatialConvolutionsTest,
const int output_cols = input_cols - patch_cols + 1;
const int output_planes = input_planes - patch_planes + 1;
- Tensor<float, 4, RowMajor> input(input_cols, input_rows, input_planes,
- input_depth);
+ // TODO(ezhulenev): Support backward kernel convolution without batch
+ // dimension.
+ Tensor<float, 5, RowMajor> input(/*num_batches*/ 1, input_cols, input_rows,
+ input_planes, input_depth);
Tensor<float, 5, RowMajor> kernel(patch_cols, patch_rows, patch_planes,
input_depth, output_depth);
- Tensor<float, 4, RowMajor> output_backward(output_cols, output_rows,
- output_planes, output_depth);
+ Tensor<float, 5, RowMajor> output_backward(
+ /*num_batches*/ 1, output_cols, output_rows, output_planes, output_depth);
output_backward = output_backward.constant(11.0f) + output_backward.random();
input = input.constant(2.0f) + input.random();
@@ -1346,9 +1351,9 @@ TEST(EigenBackwardSpatialConvolutionsTest,
if (output_i >= 0 && output_i < output_planes &&
output_j >= 0 && output_j < output_rows &&
output_k >= 0 && output_k < output_cols) {
- expected +=
- input(k, j, i, id) *
- output_backward(output_k, output_j, output_i, od);
+ expected += input(/*batch*/ 0, k, j, i, id) *
+ output_backward(/*batch*/ 0, output_k, output_j,
+ output_i, od);
}
}
}
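Editorial sketch, not part of the patch above: the updated tests declare rank-5 tensors with an explicit unit batch dimension. An existing rank-4 tensor can be viewed the same way with a reshape, which is one possible route for the TODO about supporting the batch-less case.

#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  const int depth = 2, planes = 3, rows = 4, cols = 5;
  Eigen::Tensor<float, 4> input(depth, planes, rows, cols);
  input.setRandom();
  // Append a unit batch dimension (col-major: batch becomes the outermost dim).
  Eigen::Tensor<float, 5> batched = input.reshape(
      Eigen::DSizes<Eigen::DenseIndex, 5>(depth, planes, rows, cols, 1));
  return batched.dimension(4) == 1 ? 0 : 1;
}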
diff --git a/tensorflow/core/kernels/eigen_cuboid_convolution.h b/tensorflow/core/kernels/eigen_cuboid_convolution.h
index 62e9f9123d..c41fbc42d3 100644
--- a/tensorflow/core/kernels/eigen_cuboid_convolution.h
+++ b/tensorflow/core/kernels/eigen_cuboid_convolution.h
@@ -21,6 +21,1362 @@ limitations under the License.
namespace Eigen {
+namespace internal {
+
+// WARNING: Most of the code here implicitly assumes that the matrix is in
+// ColMajor layout. This is guaranteed by the tensor contraction (see
+// TensorContraction.h).
+//
+// Inside Eigen a tensor contraction is represented by a matrix multiplication.
+// We don't want to actually extract volume patches and reshape the result into
+// a matrix (this involves allocating huge extra memory), so the patch
+// extraction and reshape operations are implicit.
+//
+// TensorContractionInputMapper takes a matrix index and returns the coefficient
+// (or the packet) of the "virtual tensor", that would be at that index if we
+// were to actually reshape the result of patch extraction.
+//
+// TensorContractionSubMapper provides a similar view into the "virtual matrix"
+// at the given vertical and horizontal offsets.
+//
+// "Virtual matrix" dimensions:
+// *0: kernelChannels * kernelDepth * kernelRows * kernelCols;
+// 1: out_depth * out_height * out_width * OTHERS (e.g. batches, etc...)
+//
+// *) extracted patches are contiguous in memory (innermost dimension assuming
+// col major layout)
+//
+// With these dimensions:
+// row - offset within a single patch (in code: patchId)
+// col - index of the extracted patch (in code: patchIndex)
+// patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions)
+//
+template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
+ DenseIndex Cols, typename ArgType, typename Device, typename Scalar_,
+ typename Index, typename nocontract_t, typename contract_t, int Side,
+ int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered,
+ int Alignment>
+class TensorContractionInputMapper<
+ Scalar_, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<NewDimension,
+ const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment> {
+ public:
+ typedef Scalar_ Scalar;
+ typedef TensorContractionInputMapper<
+ Scalar, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>
+ Self;
+ typedef TensorContractionSubMapper<
+ Scalar, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>
+ SubMapper;
+ typedef SubMapper VectorMapper;
+ typedef SubMapper LinearMapper;
+ typedef typename packet_traits<Scalar>::type Packet;
+
+ EIGEN_DEVICE_FUNC
+ TensorContractionInputMapper(
+ const TensorEvaluator<
+ const TensorReshapingOp<
+ NewDimension,
+ const TensorVolumePatchOp<Planes, Rows, Cols, ArgType> >,
+ Device>& tensor,
+ const nocontract_t&, const nocontract_t&, const contract_t&,
+ const contract_t&)
+ : m_impl(tensor.impl().impl()) {
+ if (internal::traits<ArgType>::Layout == ColMajor) {
+ m_patch_depth = tensor.impl().dimensions()[0];
+ m_patch_planes = tensor.impl().dimensions()[1];
+ m_patch_rows = tensor.impl().dimensions()[2];
+ m_patch_cols = tensor.impl().dimensions()[3];
+ m_num_patches = tensor.impl().dimensions()[4];
+ } else {
+ const int NumDims = tensor.impl().dimensions().size();
+ m_patch_depth = tensor.impl().dimensions()[NumDims - 1];
+ m_patch_planes = tensor.impl().dimensions()[NumDims - 2];
+ m_patch_rows = tensor.impl().dimensions()[NumDims - 3];
+ m_patch_cols = tensor.impl().dimensions()[NumDims - 4];
+ m_num_patches = tensor.impl().dimensions()[NumDims - 5];
+ }
+
+ // Strides for the output tensor.
+ // IMPORTANT: These strides are used to locate an element in a patch at
+ // depth zero (channel 0), which is not quite the same as a "traditional"
+ // stride.
+ m_rowStride = m_patch_planes;
+ m_colStride = m_patch_rows * m_rowStride;
+ m_patchStride = m_colStride * m_patch_cols * m_patch_depth;
+ m_otherStride = m_patchStride * m_num_patches;
+
+ m_outputPlanes = tensor.impl().outputPlanes();
+ m_outputRows = tensor.impl().outputRows();
+ m_outputCols = tensor.impl().outputCols();
+
+ m_outputPlanesRows = m_outputPlanes * m_outputRows;
+
+ m_plane_strides = tensor.impl().userPlaneStride();
+ m_row_strides = tensor.impl().userRowStride();
+ m_col_strides = tensor.impl().userColStride();
+
+ m_in_plane_strides = tensor.impl().userInPlaneStride();
+ m_in_row_strides = tensor.impl().userInRowStride();
+ m_in_col_strides = tensor.impl().userInColStride();
+
+ m_patch_plane_inflate_strides = tensor.impl().planeInflateStride();
+ m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
+ m_patch_col_inflate_strides = tensor.impl().colInflateStride();
+
+ if (internal::traits<ArgType>::Layout == ColMajor) {
+ m_inputDepth = tensor.impl().impl().dimensions()[0];
+ m_inputPlanes = tensor.impl().impl().dimensions()[1];
+ m_inputRows = tensor.impl().impl().dimensions()[2];
+ m_inputCols = tensor.impl().impl().dimensions()[3];
+ } else {
+ const int NumDims = tensor.impl().impl().dimensions().size();
+ m_inputDepth = tensor.impl().impl().dimensions()[NumDims - 1];
+ m_inputPlanes = tensor.impl().impl().dimensions()[NumDims - 2];
+ m_inputRows = tensor.impl().impl().dimensions()[NumDims - 3];
+ m_inputCols = tensor.impl().impl().dimensions()[NumDims - 4];
+ }
+
+ // Strides for navigating through the input tensor.
+ m_planeInputStride = m_inputDepth;
+ m_rowInputStride = m_inputDepth * m_inputPlanes;
+ m_colInputStride = m_inputDepth * m_inputRows * m_inputPlanes;
+ m_patchInputStride =
+ m_inputDepth * m_inputRows * m_inputCols * m_inputPlanes;
+
+ m_planePaddingTop = tensor.impl().planePaddingTop();
+ m_rowPaddingTop = tensor.impl().rowPaddingTop();
+ m_colPaddingLeft = tensor.impl().colPaddingLeft();
+
+ m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches);
+
+ m_fastInputPlaneStride =
+ internal::TensorIntDivisor<Index>(m_patch_plane_inflate_strides);
+ m_fastInputRowStride =
+ internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
+ m_fastInputColStride =
+ internal::TensorIntDivisor<Index>(m_patch_col_inflate_strides);
+
+ m_fastRowStride = internal::TensorIntDivisor<Index>(m_rowStride);
+ m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);
+
+ m_fastDimZero = internal::TensorIntDivisor<Index>(m_patch_depth);
+ m_fastOutputPlanes = internal::TensorIntDivisor<Index>(m_outputPlanes);
+ m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows);
+ m_fastOutputCols = internal::TensorIntDivisor<Index>(m_outputCols);
+
+ m_fastOutputPlanesRows =
+ internal::TensorIntDivisor<Index>(m_outputPlanesRows);
+ }
+
+ EIGEN_DEVICE_FUNC
+ TensorContractionInputMapper(const TensorContractionInputMapper& base_mapper)
+ : m_impl(base_mapper.m_impl) {
+ m_patch_depth = base_mapper.m_patch_depth;
+ m_patch_planes = base_mapper.m_patch_planes;
+ m_patch_rows = base_mapper.m_patch_rows;
+ m_patch_cols = base_mapper.m_patch_cols;
+ m_num_patches = base_mapper.m_num_patches;
+
+ m_rowStride = base_mapper.m_rowStride;
+ m_colStride = base_mapper.m_colStride;
+ m_patchStride = base_mapper.m_patchStride;
+ m_otherStride = base_mapper.m_otherStride;
+
+ m_planeInputStride = base_mapper.m_planeInputStride;
+ m_rowInputStride = base_mapper.m_rowInputStride;
+ m_colInputStride = base_mapper.m_colInputStride;
+ m_patchInputStride = base_mapper.m_patchInputStride;
+ m_otherInputStride = base_mapper.m_otherInputStride;
+
+ m_inputDepth = base_mapper.m_inputDepth;
+ m_inputPlanes = base_mapper.m_inputPlanes;
+ m_inputRows = base_mapper.m_inputRows;
+ m_inputCols = base_mapper.m_inputCols;
+
+ m_outputPlanes = base_mapper.m_outputPlanes;
+ m_outputRows = base_mapper.m_outputRows;
+ m_outputCols = base_mapper.m_outputCols;
+
+ m_plane_strides = base_mapper.m_plane_strides;
+ m_row_strides = base_mapper.m_row_strides;
+ m_col_strides = base_mapper.m_col_strides;
+
+ m_in_plane_strides = base_mapper.m_in_plane_strides;
+ m_in_row_strides = base_mapper.m_in_row_strides;
+ m_in_col_strides = base_mapper.m_in_col_strides;
+
+ m_patch_plane_inflate_strides = base_mapper.m_patch_plane_inflate_strides;
+ m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
+ m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;
+
+ m_planePaddingTop = base_mapper.m_planePaddingTop;
+ m_rowPaddingTop = base_mapper.m_rowPaddingTop;
+ m_colPaddingLeft = base_mapper.m_colPaddingLeft;
+
+ m_outputPlanesRows = base_mapper.m_outputPlanesRows;
+
+ m_fastNumPatches = base_mapper.m_fastNumPatches;
+ m_fastInputPlaneStride = base_mapper.m_fastInputPlaneStride;
+ m_fastInputRowStride = base_mapper.m_fastInputRowStride;
+ m_fastInputColStride = base_mapper.m_fastInputColStride;
+ m_fastRowStride = base_mapper.m_fastRowStride;
+ m_fastColStride = base_mapper.m_fastColStride;
+ m_fastOutputPlanes = base_mapper.m_fastOutputPlanes;
+ m_fastOutputRows = base_mapper.m_fastOutputRows;
+ m_fastOutputCols = base_mapper.m_fastOutputCols;
+ m_fastDimZero = base_mapper.m_fastDimZero;
+ m_fastOutputPlanesRows = base_mapper.m_fastOutputPlanesRows;
+ }
+
+ // If true, turns off some optimizations for loading packets since the
+ // patches are "non-standard", e.g. there are non-trivial strides or
+ // inflations in the input.
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
+ return m_in_plane_strides != 1 || m_in_row_strides != 1 ||
+ m_in_col_strides != 1 || m_patch_plane_inflate_strides != 1 ||
+ m_patch_row_inflate_strides != 1 || m_patch_col_inflate_strides != 1;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
+ return SubMapper(*this, i, j);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
+ return LinearMapper(*this, i, j);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const {
+ Index planeIndex, rowIndex, colIndex, otherIndex;
+ computeBaseIndices(0, planeIndex, rowIndex, colIndex, otherIndex);
+ return loadCoeff(row, planeIndex, rowIndex, colIndex, otherIndex);
+ }
+
+ // Load the coefficient at the patchIndex location instead of the usual
+ // m_rowIndex, m_colIndex, m_otherIndex. This is currently only used by the
+ // gpu code.
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const {
+ Index planeIndex, rowIndex, colIndex, otherIndex;
+ computeBaseIndices(patchIndex, planeIndex, rowIndex, colIndex, otherIndex);
+ return loadCoeff(row, planeIndex, rowIndex, colIndex, otherIndex);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const {
+ Index planeIndex, rowIndex, colIndex, otherIndex;
+ computeBaseIndices(0, planeIndex, rowIndex, colIndex, otherIndex);
+ return loadPacket(row, planeIndex, rowIndex, colIndex, otherIndex);
+ }
+
+ // Load the packet at the patchIndex location instead of the usual m_rowIndex,
+ // m_colIndex, m_otherIndex. This is currently only used by the gpu code.
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const {
+ Index planeIndex, rowIndex, colIndex, otherIndex;
+ computeBaseIndices(patchIndex, planeIndex, rowIndex, colIndex, otherIndex);
+ return loadPacket(row, planeIndex, rowIndex, colIndex, otherIndex);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE const TensorEvaluator<ArgType, Device>& impl() const {
+ return m_impl;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_patch_depth; }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchPlanes() const { return m_patch_planes; }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchRows() const { return m_patch_rows; }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
+ const Index baseIndex) const {
+ const Index inputIndex = depth + baseIndex;
+ return m_impl.template packet<Unaligned>(inputIndex);
+ }
+
+ private:
+ friend class TensorContractionSubMapper<
+ Scalar, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>;
+
+ // Load coefficient from a patch specified by the "within patch offset"
+ // (patchId) and the precomputed indices of the first element of the patch.
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index planeIndex,
+ Index rowIndex, Index colIndex,
+ Index otherIndex) const {
+ // Find the offset of the element wrt the location of the first element.
+ const Index patchOffset = patchId / m_fastDimZero;
+
+ const Index colOffset = patchOffset / m_fastColStride;
+ const Index inputCol = colIndex + colOffset * m_in_col_strides;
+ const Index origInputCol =
+ (m_patch_col_inflate_strides == 1)
+ ? inputCol
+ : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);
+
+ const Index rowOffset =
+ (patchOffset - colOffset * m_colStride) / m_fastRowStride;
+ const Index inputRow = rowIndex + rowOffset * m_in_row_strides;
+ const Index origInputRow =
+ (m_patch_row_inflate_strides == 1)
+ ? inputRow
+ : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);
+
+ const Index planeOffset =
+ patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
+ const Index inputPlane = planeIndex + planeOffset * m_in_plane_strides;
+ const Index origInputPlane =
+ (m_patch_plane_inflate_strides == 1)
+ ? inputPlane
+ : ((inputPlane >= 0) ? (inputPlane / m_fastInputPlaneStride) : 0);
+
+ if (origInputCol < 0 || origInputRow < 0 || origInputPlane < 0 ||
+ origInputCol >= m_inputCols || origInputRow >= m_inputRows ||
+ origInputPlane >= m_inputPlanes ||
+ (inputCol != origInputCol * m_patch_col_inflate_strides) ||
+ (inputRow != origInputRow * m_patch_row_inflate_strides) ||
+ (inputPlane != origInputPlane * m_patch_plane_inflate_strides)) {
+ return Scalar(0);
+ }
+
+ const Index depth = patchId - patchOffset * patchDepth();
+ const Index inputIndex = depth + origInputPlane * m_planeInputStride +
+ origInputRow * m_rowInputStride +
+ origInputCol * m_colInputStride + otherIndex;
+
+ return m_impl.coeff(inputIndex);
+ }
+
+ // This is the same as loadCoeff(...), but optimized for the case where all
+ // `inflate_strides` and `in_strides` are equal to 1 (in effect a manual
+ // specialization without templates).
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index planeIndex,
+ Index rowIndex, Index colIndex,
+ Index otherIndex) const {
+ eigen_assert(!nonStandardPatches());
+
+ // Find the offset of the element wrt the location of the first element.
+ const Index patchOffset = patchId / m_fastDimZero;
+
+ const Index colOffset = patchOffset / m_fastColStride;
+ const Index inputCol = colIndex + colOffset;
+
+ const Index rowOffset =
+ (patchOffset - colOffset * m_colStride) / m_fastRowStride;
+ const Index inputRow = rowIndex + rowOffset;
+
+ const Index planeOffset =
+ patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
+ const Index inputPlane = planeIndex + planeOffset;
+
+ if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 ||
+ inputRow >= m_inputRows || inputPlane < 0 ||
+ inputPlane >= m_inputPlanes) {
+ return Scalar(0);
+ }
+
+ const Index depth = patchId - patchOffset * patchDepth();
+ const Index inputIndex = depth + inputPlane * m_planeInputStride +
+ inputRow * m_rowInputStride +
+ inputCol * m_colInputStride + otherIndex;
+
+ return m_impl.coeff(inputIndex);
+ }
+
+ // Load packet from a patch specified by the "within patch offset"
+ // (patchId) and the precomputed indices of the first element of the patch.
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index planeIndex,
+ Index rowIndex, Index colIndex,
+ Index otherIndex) const {
+ const Index packetSize = internal::unpacket_traits<Packet>::size;
+
+ EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
+ eigen_assert(patchId <
+ patchDepth() * patchPlanes() * patchRows() * patchCols());
+
+ if (nonStandardPatches()) {
+ return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex,
+ otherIndex);
+ }
+ return loadPacketStandard(patchId, planeIndex, rowIndex, colIndex,
+ otherIndex);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet loadPacketStandard(Index patchId, Index planeIndex,
+ Index rowIndex, Index colIndex,
+ Index otherIndex) const {
+ const Index packetSize = internal::unpacket_traits<Packet>::size;
+ EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
+ eigen_assert(patchId <
+ patchDepth() * patchPlanes() * patchRows() * patchCols());
+ eigen_assert(!nonStandardPatches());
+
+ if ((patchDepth() % packetSize) == 0) {
+ return loadPacketFast(patchId, planeIndex, rowIndex, colIndex,
+ otherIndex);
+ } else {
+ // Offsets and input calculation here are identical to
+ // loadCoeffStandard(...), but repeated twice.
+
+ const Index patchOffsets[2] = {
+ patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero};
+
+ const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
+ patchOffsets[1] / m_fastColStride};
+ eigen_assert(colOffsets[0] <= colOffsets[1]);
+
+ const Index inputCols[2] = {colIndex + colOffsets[0],
+ colIndex + colOffsets[1]};
+ if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
+ return internal::pset1<Packet>(Scalar(0));
+ }
+
+ if (inputCols[0] == inputCols[1]) {
+ const Index rowOffsets[2] = {
+ (patchOffsets[0] - colOffsets[0] * m_colStride) / m_fastRowStride,
+ (patchOffsets[1] - colOffsets[1] * m_colStride) / m_fastRowStride};
+ eigen_assert(rowOffsets[0] <= rowOffsets[1]);
+ const Index inputRows[2] = {rowIndex + rowOffsets[0],
+ rowIndex + rowOffsets[1]};
+
+ if (inputRows[0] >= m_inputRows || inputRows[1] < 0) {
+ return internal::pset1<Packet>(Scalar(0));
+ }
+
+ if (inputRows[0] == inputRows[1]) {
+ const Index planeOffsets[2] = {
+ patchOffsets[0] - colOffsets[0] * m_colStride -
+ rowOffsets[0] * m_rowStride,
+ patchOffsets[1] - colOffsets[1] * m_colStride -
+ rowOffsets[1] * m_rowStride};
+ eigen_assert(planeOffsets[0] <= planeOffsets[1]);
+ const Index inputPlanes[2] = {planeIndex + planeOffsets[0],
+ planeIndex + planeOffsets[1]};
+
+ if (inputPlanes[0] >= m_inputPlanes || inputPlanes[1] < 0) {
+ return internal::pset1<Packet>(Scalar(0));
+ }
+
+ if (inputPlanes[0] >= 0 && inputPlanes[1] < m_inputPlanes) {
+ const Index depth = patchId - patchOffsets[0] * patchDepth();
+ const Index inputIndex =
+ depth + inputPlanes[0] * m_planeInputStride +
+ inputRows[0] * m_rowInputStride +
+ inputCols[0] * m_colInputStride + otherIndex;
+ return m_impl.template packet<Unaligned>(inputIndex);
+ }
+ }
+ }
+ }
+
+ return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex,
+ otherIndex);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index planeIndex,
+ Index rowIndex, Index colIndex,
+ Index otherIndex) const {
+ const Index packetSize = internal::unpacket_traits<Packet>::size;
+ EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
+ eigen_assert(patchId <
+ patchDepth() * patchPlanes() * patchRows() * patchCols());
+
+ eigen_assert(!nonStandardPatches());
+ eigen_assert((patchDepth() % packetSize) == 0);
+
+ // Find the offset of the element wrt the location of the first element.
+ const Index patchOffset = patchId / m_fastDimZero;
+ eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset);
+
+ const Index colOffset = patchOffset / m_fastColStride;
+ const Index inputCol = colIndex + colOffset;
+ const Index rowOffset =
+ (patchOffset - colOffset * m_colStride) / m_fastRowStride;
+ const Index inputRow = rowIndex + rowOffset;
+ const Index planeOffset =
+ patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
+ const Index inputPlane = planeIndex + planeOffset;
+
+ if (inputCol < 0 || inputRow < 0 || inputPlane < 0 ||
+ inputCol >= m_inputCols || inputRow >= m_inputRows ||
+ inputPlane >= m_inputPlanes) {
+ return internal::pset1<Packet>(Scalar(0));
+ }
+
+ const Index depth = patchId - patchOffset * patchDepth();
+ const Index inputIndex = depth + inputPlane * m_planeInputStride +
+ inputRow * m_rowInputStride +
+ inputCol * m_colInputStride + otherIndex;
+ return m_impl.template packet<Unaligned>(inputIndex);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet
+ packetWithPossibleZero(Index patchId, Index planeIndex, Index rowIndex,
+ Index colIndex, Index otherIndex) const {
+ const int packetSize = internal::unpacket_traits<Packet>::size;
+ EIGEN_ALIGN_MAX
+ typename internal::remove_const<Scalar>::type values[packetSize];
+ for (int i = 0; i < packetSize; ++i) {
+ values[i] =
+ loadCoeff(patchId + i, planeIndex, rowIndex, colIndex, otherIndex);
+ }
+ Packet rslt = internal::pload<Packet>(values);
+ return rslt;
+ }
+
+ // Precompute the indices (plane, row, col, other) of the first element of
+ // the given patch index, within the output tensor of the TensorVolumePatchOp.
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices(
+ Index patchIndex, Index& planeIndex, Index& rowIndex, Index& colIndex,
+ Index& otherIndex) const {
+ const int NumInputDims = array_size<
+ typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
+
+ // Check if patchIndex might contain batch and other dimensions.
+ otherIndex = (NumInputDims == 4) ? 0 : patchIndex / m_fastNumPatches;
+
+ // Compute index of the patch within the batch (and other dimensions).
+ const Index patch3DIndex = (NumInputDims == 4)
+ ? patchIndex
+ : (patchIndex - otherIndex * m_num_patches);
+
+ otherIndex *= m_patchInputStride;
+
+ colIndex = patch3DIndex / m_fastOutputPlanesRows;
+ rowIndex =
+ (patch3DIndex - colIndex * m_outputPlanesRows) / m_fastOutputPlanes;
+ planeIndex =
+ patch3DIndex - (colIndex * m_outputRows + rowIndex) * m_outputPlanes;
+
+ colIndex = colIndex * m_col_strides - m_colPaddingLeft;
+ rowIndex = rowIndex * m_row_strides - m_rowPaddingTop;
+ planeIndex = planeIndex * m_plane_strides - m_planePaddingTop;
+ }
+
+ Index m_patch_depth; // number of channels in the patch
+ Index m_patch_planes; // number of planes in the patch
+ Index m_patch_rows; // number of rows in the patch
+ Index m_patch_cols; // number of columns in the patch
+ Index m_num_patches; // number of patches to extract
+
+ // Strides for the output tensor.
+ Index m_rowStride;
+ Index m_colStride;
+ Index m_patchStride;
+ Index m_otherStride;
+
+ Index m_planeInputStride; // Plane stride in the input tensor
+ Index m_rowInputStride; // Row stride in the input tensor
+ Index m_colInputStride; // Col stride in the input tensor
+ Index m_patchInputStride; // Patch stride in the input tensor
+ Index m_otherInputStride;
+
+ Index m_inputDepth; // Depth of the input tensor
+ Index m_inputPlanes; // Number of planes in the input tensor
+ Index m_inputRows; // Number of rows in the input tensor
+ Index m_inputCols; // Number of cols in the input tensor
+
+ Index m_outputPlanes; // Number of output planes
+ Index m_outputRows; // Number of output rows
+ Index m_outputCols; // Number of output cols
+ Index m_outputPlanesRows; // Cached outputPlanes * outputRows.
+
+ Index m_plane_strides; // User specified plane stride
+ Index m_row_strides; // User specified row stride
+ Index m_col_strides; // User specified col stride
+
+ // User specified plane/row/col atrous convolution strides.
+ Index m_in_plane_strides;
+ Index m_in_row_strides;
+ Index m_in_col_strides;
+
+ // User specified plane/row/col inflation strides in the image patch.
+ Index m_patch_plane_inflate_strides;
+ Index m_patch_row_inflate_strides;
+ Index m_patch_col_inflate_strides;
+
+ Index m_planePaddingTop; // Plane padding
+ Index m_rowPaddingTop; // Row padding
+ Index m_colPaddingLeft; // Column padding
+
+ // Fast representation of various divisors.
+ internal::TensorIntDivisor<Index> m_fastNumPatches;
+
+ internal::TensorIntDivisor<Index> m_fastInputPlaneStride;
+ internal::TensorIntDivisor<Index> m_fastInputRowStride;
+ internal::TensorIntDivisor<Index> m_fastInputColStride;
+
+ internal::TensorIntDivisor<Index> m_fastRowStride;
+ internal::TensorIntDivisor<Index> m_fastColStride;
+
+ internal::TensorIntDivisor<Index> m_fastDimZero;  // aka patch depth (channels)
+ internal::TensorIntDivisor<Index> m_fastOutputPlanes;
+ internal::TensorIntDivisor<Index> m_fastOutputRows;
+ internal::TensorIntDivisor<Index> m_fastOutputCols;
+ internal::TensorIntDivisor<Index> m_fastOutputPlanesRows;
+
+ const TensorEvaluator<ArgType, Device> m_impl;
+};
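// Editorial illustration, not part of this patch: the index arithmetic in
// loadCoeffStandard() above is easier to follow with concrete (hypothetical)
// patch sizes. This standalone sketch decomposes one "virtual matrix" row
// index into (depth, plane, row, col) offsets and checks that the
// decomposition round-trips.
#include <cassert>

int main() {
  const int patch_depth = 8, patch_planes = 3, patch_rows = 3;
  const int row_stride = patch_planes;             // corresponds to m_rowStride
  const int col_stride = patch_rows * row_stride;  // corresponds to m_colStride

  const int patch_id = 100;  // a row of the virtual matrix within one patch
  const int patch_offset = patch_id / patch_depth;
  const int col_offset = patch_offset / col_stride;
  const int row_offset = (patch_offset - col_offset * col_stride) / row_stride;
  const int plane_offset =
      patch_offset - col_offset * col_stride - row_offset * row_stride;
  const int depth = patch_id - patch_offset * patch_depth;

  // Recomposing the offsets must give back the original row index.
  assert(patch_id == depth + patch_depth * (plane_offset + row_offset * row_stride +
                                            col_offset * col_stride));
  return 0;
}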
+
+template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
+ DenseIndex Cols, typename ArgType, typename Device, typename Scalar,
+ typename Index, typename nocontract_t, typename contract_t, int Side,
+ int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered,
+ int Alignment>
+class TensorContractionSubMapper<
+ Scalar, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<NewDimension,
+ const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment> {
+ public:
+ typedef typename packet_traits<Scalar>::type Packet;
+ typedef typename packet_traits<Scalar>::half HalfPacket;
+
+ typedef TensorContractionInputMapper<
+ Scalar, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>
+ ParentMapper;
+ typedef TensorContractionSubMapper<
+ Scalar, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>
+ Self;
+ typedef Self LinearMapper;
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
+ const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
+ : m_base_mapper(base_mapper),
+ m_depth_offset(vert_offset),
+ m_col_offset(horiz_offset) {
+ m_base_mapper.computeBaseIndices(m_col_offset, m_planeIndex, m_rowIndex,
+ m_colIndex, m_otherIndex);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
+ const Self& base_mapper, Index vert_offset, Index horiz_offset)
+ : m_base_mapper(base_mapper.m_base_mapper),
+ m_depth_offset(vert_offset + base_mapper.m_depth_offset),
+ m_col_offset(horiz_offset + base_mapper.m_col_offset) {
+ m_base_mapper.computeBaseIndices(m_col_offset, m_planeIndex, m_rowIndex,
+ m_colIndex, m_otherIndex);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
+ return m_base_mapper.loadCoeff(i + m_depth_offset, m_planeIndex, m_rowIndex,
+ m_colIndex, m_otherIndex);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i,
+ Index j) const {
+ return m_base_mapper(i + m_depth_offset, j + m_col_offset);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
+ return m_base_mapper.loadPacket(i + m_depth_offset, m_planeIndex,
+ m_rowIndex, m_colIndex, m_otherIndex);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i,
+ Index j) const {
+ return m_base_mapper.template loadPacket<Alignment>(i + m_depth_offset,
+ j + m_col_offset);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar
+ loadCoeffStandard(Index i) const {
+ return m_base_mapper.loadCoeffStandard(
+ i + m_depth_offset, m_planeIndex, m_rowIndex, m_colIndex, m_otherIndex);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index i) const {
+ return m_base_mapper.loadPacketFast(i + m_depth_offset, m_planeIndex,
+ m_rowIndex, m_colIndex, m_otherIndex);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet
+ loadPacketStandard(Index i) const {
+ return m_base_mapper.loadPacketStandard(
+ i + m_depth_offset, m_planeIndex, m_rowIndex, m_colIndex, m_otherIndex);
+ }
+ template <typename Packet>
+ EIGEN_DEVICE_FUNC bool aligned(Index) const {
+ return false;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
+ return m_base_mapper.nonStandardPatches();
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchDepth() const {
+ return m_base_mapper.m_patch_depth;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchPlanes() const {
+ return m_base_mapper.m_patch_planes;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchRows() const {
+ return m_base_mapper.m_patch_rows;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchCols() const {
+ return m_base_mapper.m_patch_cols;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
+ const Index baseIndex) const {
+ const Index inputIndex = depth + baseIndex;
+ return m_base_mapper.m_impl.template packet<Unaligned>(inputIndex);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE bool padPlane(const Index plane) const {
+ const Index p = m_planeIndex + plane;
+ return p < 0 || p >= m_base_mapper.m_inputPlanes;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE bool padRow(const Index row) const {
+ const Index r = m_rowIndex + row;
+ return r < 0 || r >= m_base_mapper.m_inputRows;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE bool padCol(const Index col) const {
+ const Index c = m_colIndex + col;
+ return c < 0 || c >= m_base_mapper.m_inputCols;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index baseIndex(const Index plane, const Index row,
+ const Index col) const {
+ const Index p = m_planeIndex + plane;
+ const Index r = m_rowIndex + row;
+ const Index c = m_colIndex + col;
+ return p * m_base_mapper.m_planeInputStride +
+ r * m_base_mapper.m_rowInputStride +
+ c * m_base_mapper.m_colInputStride + m_otherIndex;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index planeOffset() const {
+ const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
+ const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
+ const Index rowOffset =
+ (patchOffset - colOffset * m_base_mapper.m_colStride) /
+ m_base_mapper.m_fastRowStride;
+ const Index planeOffset = patchOffset -
+ colOffset * m_base_mapper.m_colStride -
+ rowOffset * m_base_mapper.m_rowStride;
+ return planeOffset;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index rowOffset() const {
+ const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
+ const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
+ const Index rowOffset =
+ (patchOffset - colOffset * m_base_mapper.m_colStride) /
+ m_base_mapper.m_fastRowStride;
+ return rowOffset;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index colOffset() const {
+ const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
+ const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
+ return colOffset;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index depthOffset() const {
+ const Index patchOffset = m_depth_offset % m_base_mapper.patchDepth();
+ return patchOffset;
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper
+ getLinearMapper(Index i, Index j) const {
+ return LinearMapper(m_base_mapper, i + m_depth_offset, j + m_col_offset);
+ }
+
+ private:
+ const ParentMapper& m_base_mapper;
+ Index m_depth_offset; // First row in the input matrix
+ Index m_col_offset; // First col in the input matrix
+
+ // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base
+ // indices for the first element in a patch specified by col_offset
+ // (see computeBaseIndices(...) for details).
+ Index m_planeIndex;
+ Index m_rowIndex;
+ Index m_colIndex;
+ Index m_otherIndex;
+};
+
+// Arrange a block of the right input matrix (in our case it's always a "virtual
+// matrix" constructed from extracted volume patches) in contiguous memory.
+//
+// Given column major input (A0 beside A1 in memory):
+// A0 B0 C0 D0 E0 F0 G0 H0 ...
+// A1 B1 C1 D1 E1 F1 G1 H1 ...
+// A2 B2 C2 D2 E2 F2 G2 H2 ...
+// A3 B3 C3 D3 E3 F3 G3 H3 ...
+// A4 B4 C4 D4 E4 F4 G4 H4 ...
+// A5 B5 C5 D5 E5 F5 G5 H5 ...
+// A6 B6 C6 D6 E6 F6 G6 H6 ...
+// A7 B7 C7 D7 E7 F7 G7 H7 ...
+// A8 ...
+// ...
+//
+// Packing yields row major output (A0 beside A1 in memory):
+// A0 A1 A2 A3 A4 A5 A6 A7
+// B0 B1 B2 B3 B4 B5 B6 B7
+// C0 ...
+// ...
+//
+// *) A, B, C, ... - patches extracted from the original input.
+// *) nr - number of registers along the 'n' dimension.
+// See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix
+// Multiplication" paper.
+template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
+ DenseIndex Cols, typename ArgType, typename Device, typename Scalar,
+ typename Index, typename nocontract_t, typename contract_t,
+ int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered,
+ int Alignment, int nr>
+struct gemm_pack_rhs<
+ Scalar, Index,
+ TensorContractionSubMapper<
+ Scalar, Index, Rhs,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>,
+ nr, ColMajor, false, false> {
+ typedef TensorContractionSubMapper<
+ Scalar, Index, Rhs,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>
+ SubMapper;
+ typedef SubMapper DataMapper;
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
+ Index depth, Index cols, Index stride = 0,
+ Index offset = 0) const {
+ eigen_assert(stride == 0);
+ eigen_assert(offset == 0);
+
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
+ typedef typename packet_traits<Scalar>::type Packet;
+
+ const Index packet_cols4 = (cols / 4) * 4;
+ const Index peeled_k = (depth / packet_size) * packet_size;
+ const bool non_standard_patches = rhs.nonStandardPatches();
+
+ for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
+ const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
+ const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
+ const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
+ const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
+
+ Index k = 0;
+ if ((packet_size % 4) == 0 && !non_standard_patches) {
+ const Index patch_depth = rhs.patchDepth();
+
+ if ((patch_depth % packet_size) == 0) {
+ const Index patch_cols = rhs.patchCols();
+ const Index patch_rows = rhs.patchRows();
+ const Index patch_planes = rhs.patchPlanes();
+
+ const Index startCol = rhs.colOffset();
+ const Index max_cols = std::min<Index>(
+ Eigen::divup(peeled_k, patch_rows * patch_planes * patch_depth) +
+ startCol,
+ patch_cols);
+
+ for (Index c = startCol; c < max_cols; ++c) {
+ eigen_assert(k < peeled_k);
+
+ const Index startRow = (c == startCol) ? rhs.rowOffset() : 0;
+ const Index max_rows = std::min<Index>(
+ Eigen::divup(
+ peeled_k - c * patch_rows * patch_planes * patch_depth,
+ patch_planes * patch_depth) +
+ startRow,
+ patch_rows);
+
+ const bool pad_col0 = dm0.padCol(c);
+ const bool pad_col1 = dm1.padCol(c);
+ const bool pad_col2 = dm2.padCol(c);
+ const bool pad_col3 = dm3.padCol(c);
+
+ for (Index r = startRow; r < max_rows; ++r) {
+ eigen_assert(k < peeled_k);
+
+ const Index startPlane =
+ ((c == startCol) && (r == startRow)) ? rhs.planeOffset() : 0;
+ const Index max_planes = std::min<Index>(
+ Eigen::divup(
+ peeled_k -
+ c * patch_rows * patch_planes * patch_depth - // col
+ r * patch_planes * patch_depth, // row
+ patch_depth) +
+ startPlane,
+ patch_planes);
+
+ const bool pad_row0 = dm0.padRow(r);
+ const bool pad_row1 = dm1.padRow(r);
+ const bool pad_row2 = dm2.padRow(r);
+ const bool pad_row3 = dm3.padRow(r);
+
+ for (Index p = startPlane; p < max_planes; ++p) {
+ eigen_assert(k < peeled_k);
+
+ const bool pad0 = pad_col0 || pad_row0 || dm0.padPlane(p);
+ const bool pad1 = pad_col1 || pad_row1 || dm1.padPlane(p);
+ const bool pad2 = pad_col2 || pad_row2 || dm2.padPlane(p);
+ const bool pad3 = pad_col3 || pad_row3 || dm3.padPlane(p);
+
+ const Index idx0 = dm0.baseIndex(p, r, c);
+ const Index idx1 = dm1.baseIndex(p, r, c);
+ const Index idx2 = dm2.baseIndex(p, r, c);
+ const Index idx3 = dm3.baseIndex(p, r, c);
+
+ const Index startDepth =
+ ((c == startCol) && (r == startRow) && (p == startPlane))
+ ? rhs.depthOffset()
+ : 0;
+ const Index max_depth = std::min<Index>(
+ peeled_k -
+ c * patch_rows * patch_planes * patch_depth - // col
+ r * patch_planes * patch_depth - // row
+ p * patch_depth + // plane
+ startDepth,
+ patch_depth);
+ eigen_assert((max_depth - startDepth) % packet_size == 0);
+
+ for (Index d = startDepth; d < max_depth; d += packet_size) {
+ eigen_assert(k < peeled_k);
+ PacketBlock<Packet, 4> kernel;
+ kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx0);
+ kernel.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx1);
+ kernel.packet[2] = pad2 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx2);
+ kernel.packet[3] = pad3 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx3);
+ ptranspose(kernel);
+ pstoreu(block + 0 * packet_size, kernel.packet[0]);
+ pstoreu(block + 1 * packet_size, kernel.packet[1]);
+ pstoreu(block + 2 * packet_size, kernel.packet[2]);
+ pstoreu(block + 3 * packet_size, kernel.packet[3]);
+ block += 4 * packet_size;
+ k += packet_size;
+ }
+ }
+ }
+ }
+
+ for (; k < peeled_k; k += packet_size) {
+ PacketBlock<Packet, 4> kernel;
+ kernel.packet[0] = dm0.loadPacketFast(k);
+ kernel.packet[1] = dm1.loadPacketFast(k);
+ kernel.packet[2] = dm2.loadPacketFast(k);
+ kernel.packet[3] = dm3.loadPacketFast(k);
+ ptranspose(kernel);
+ pstoreu(block + 0 * packet_size, kernel.packet[0]);
+ pstoreu(block + 1 * packet_size, kernel.packet[1]);
+ pstoreu(block + 2 * packet_size, kernel.packet[2]);
+ pstoreu(block + 3 * packet_size, kernel.packet[3]);
+ block += 4 * packet_size;
+ }
+ } else {
+ for (; k < peeled_k; k += packet_size) {
+ PacketBlock<Packet, 4> kernel;
+ kernel.packet[0] = dm0.loadPacketStandard(k);
+ kernel.packet[1] = dm1.loadPacketStandard(k);
+ kernel.packet[2] = dm2.loadPacketStandard(k);
+ kernel.packet[3] = dm3.loadPacketStandard(k);
+ ptranspose(kernel);
+ pstoreu(block + 0 * packet_size, kernel.packet[0]);
+ pstoreu(block + 1 * packet_size, kernel.packet[1]);
+ pstoreu(block + 2 * packet_size, kernel.packet[2]);
+ pstoreu(block + 3 * packet_size, kernel.packet[3]);
+ block += 4 * packet_size;
+ }
+ }
+ }
+ if (!rhs.nonStandardPatches()) {
+ for (; k < depth; k++) {
+ block[0] = dm0.loadCoeffStandard(k);
+ block[1] = dm1.loadCoeffStandard(k);
+ block[2] = dm2.loadCoeffStandard(k);
+ block[3] = dm3.loadCoeffStandard(k);
+ block += 4;
+ }
+ } else {
+ for (; k < depth; k++) {
+ block[0] = dm0(k);
+ block[1] = dm1(k);
+ block[2] = dm2(k);
+ block[3] = dm3(k);
+ block += 4;
+ }
+ }
+ }
+
+ // copy the remaining columns one at a time (nr==1)
+ for (Index j2 = packet_cols4; j2 < cols; ++j2) {
+ const SubMapper dm0 = rhs.getLinearMapper(0, j2);
+ for (Index k = 0; k < depth; k++) {
+ *block = dm0(k);
+ block += 1;
+ }
+ }
+ }
+};
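// Editorial illustration, not part of this patch: the packed layout produced
// by the gemm_pack_rhs specializations above, shown with plain scalars. For
// each group of nr = 4 columns, the values of all four columns at depth k are
// stored contiguously, i.e. a column-major block becomes row-major in the
// pack; the vectorized paths reach the same layout via PacketBlock and
// ptranspose.
#include <cassert>
#include <vector>

int main() {
  const int depth = 8, cols = 4;         // one nr == 4 column group
  std::vector<float> rhs(depth * cols);  // column-major "virtual matrix" block
  for (int j = 0; j < cols; ++j)
    for (int k = 0; k < depth; ++k) rhs[j * depth + k] = float(j * depth + k);

  std::vector<float> block(depth * cols);
  float* out = block.data();
  for (int k = 0; k < depth; ++k) {  // mirrors the scalar tail loop above
    out[0] = rhs[0 * depth + k];
    out[1] = rhs[1 * depth + k];
    out[2] = rhs[2 * depth + k];
    out[3] = rhs[3 * depth + k];
    out += 4;
  }
  // Element (k, j) of the source ends up at position k * 4 + j of the pack.
  assert(block[5 * 4 + 2] == rhs[2 * depth + 5]);
  return 0;
}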
+
+// Template specialization for packet_size = 2. We must special-case packet
+// blocks with nr > packet_size, e.g. PacketBlock<Packet2d, 4>.
+template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
+ DenseIndex Cols, typename ArgType, typename Device, typename Scalar,
+ typename Index, typename nocontract_t, typename contract_t,
+ bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
+ int nr>
+struct gemm_pack_rhs<
+ Scalar, Index,
+ TensorContractionSubMapper<
+ Scalar, Index, Rhs,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, /*packet_size*/ 2, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>,
+ nr, ColMajor, false, false> {
+ typedef TensorContractionSubMapper<
+ Scalar, Index, Rhs,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, /*packet_size*/ 2, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>
+ SubMapper;
+ typedef SubMapper DataMapper;
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
+ Index depth, Index cols, Index stride = 0,
+ Index offset = 0) const {
+ eigen_assert(stride == 0);
+ eigen_assert(offset == 0);
+
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
+ typedef typename packet_traits<Scalar>::type Packet;
+
+ const int packet_size = 2;
+
+ const Index packet_cols4 = (cols / 4) * 4;
+ const Index peeled_k = (depth / packet_size) * packet_size;
+ const bool non_standard_patches = rhs.nonStandardPatches();
+
+ for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
+ const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
+ const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
+ const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
+ const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
+
+ Index k = 0;
+ if (!non_standard_patches) {
+ const Index patch_depth = rhs.patchDepth();
+
+ if ((patch_depth % packet_size) == 0) {
+ const Index patch_cols = rhs.patchCols();
+ const Index patch_rows = rhs.patchRows();
+ const Index patch_planes = rhs.patchPlanes();
+
+ const Index startCol = rhs.colOffset();
+ const Index max_cols = std::min<Index>(
+ Eigen::divup(peeled_k, patch_rows * patch_planes * patch_depth) +
+ startCol,
+ patch_cols);
+
+ for (Index c = startCol; c < max_cols; ++c) {
+ eigen_assert(k < peeled_k);
+
+ const Index startRow = (c == startCol) ? rhs.rowOffset() : 0;
+ const Index max_rows = std::min<Index>(
+ Eigen::divup(
+ peeled_k - c * patch_rows * patch_planes * patch_depth,
+ patch_planes * patch_depth) +
+ startRow,
+ patch_rows);
+
+ const bool pad_col0 = dm0.padCol(c);
+ const bool pad_col1 = dm1.padCol(c);
+ const bool pad_col2 = dm2.padCol(c);
+ const bool pad_col3 = dm3.padCol(c);
+
+ for (Index r = startRow; r < max_rows; ++r) {
+ eigen_assert(k < peeled_k);
+
+ const Index startPlane =
+ ((c == startCol) && (r == startRow)) ? rhs.planeOffset() : 0;
+ const Index max_planes = std::min<Index>(
+ Eigen::divup(
+ peeled_k -
+ c * patch_rows * patch_planes * patch_depth - // col
+ r * patch_planes * patch_depth, // row
+ patch_depth) +
+ startPlane,
+ patch_planes);
+
+ const bool pad_row0 = dm0.padRow(r);
+ const bool pad_row1 = dm1.padRow(r);
+ const bool pad_row2 = dm2.padRow(r);
+ const bool pad_row3 = dm3.padRow(r);
+
+ for (Index p = startPlane; p < max_planes; ++p) {
+ eigen_assert(k < peeled_k);
+
+ const bool pad0 = pad_col0 || pad_row0 || dm0.padPlane(p);
+ const bool pad1 = pad_col1 || pad_row1 || dm1.padPlane(p);
+ const bool pad2 = pad_col2 || pad_row2 || dm2.padPlane(p);
+ const bool pad3 = pad_col3 || pad_row3 || dm3.padPlane(p);
+
+ const Index idx0 = dm0.baseIndex(p, r, c);
+ const Index idx1 = dm1.baseIndex(p, r, c);
+ const Index idx2 = dm2.baseIndex(p, r, c);
+ const Index idx3 = dm3.baseIndex(p, r, c);
+
+ const Index startDepth =
+ ((c == startCol) && (r == startRow) && (p == startPlane))
+ ? rhs.depthOffset()
+ : 0;
+ const Index max_depth = std::min<Index>(
+ peeled_k -
+ c * patch_rows * patch_planes * patch_depth - // col
+ r * patch_planes * patch_depth - // row
+ p * patch_depth + // plane
+ startDepth,
+ patch_depth);
+ eigen_assert((max_depth - startDepth) % packet_size == 0);
+
+ for (Index d = startDepth; d < max_depth; d += packet_size) {
+ eigen_assert(k < peeled_k);
+ PacketBlock<Packet, 2> kernel0;
+ PacketBlock<Packet, 2> kernel1;
+ kernel0.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx0);
+ kernel0.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx1);
+ kernel1.packet[0] = pad2 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx2);
+ kernel1.packet[1] = pad3 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx3);
+ ptranspose(kernel0);
+ ptranspose(kernel1);
+ pstoreu(block + 0 * packet_size, kernel0.packet[0]);
+ pstoreu(block + 1 * packet_size, kernel1.packet[0]);
+ pstoreu(block + 2 * packet_size, kernel0.packet[1]);
+ pstoreu(block + 3 * packet_size, kernel1.packet[1]);
+ block += 4 * packet_size;
+ k += packet_size;
+ }
+ }
+ }
+ }
+
+ for (; k < peeled_k; k += packet_size) {
+ PacketBlock<Packet, 2> kernel0;
+ PacketBlock<Packet, 2> kernel1;
+ kernel0.packet[0] = dm0.loadPacketFast(k);
+ kernel0.packet[1] = dm1.loadPacketFast(k);
+ kernel1.packet[0] = dm2.loadPacketFast(k);
+ kernel1.packet[1] = dm3.loadPacketFast(k);
+ ptranspose(kernel0);
+ ptranspose(kernel1);
+ pstoreu(block + 0 * packet_size, kernel0.packet[0]);
+ pstoreu(block + 1 * packet_size, kernel1.packet[0]);
+ pstoreu(block + 2 * packet_size, kernel0.packet[1]);
+ pstoreu(block + 3 * packet_size, kernel1.packet[1]);
+ block += 4 * packet_size;
+ }
+ } else {
+ for (; k < peeled_k; k += packet_size) {
+ PacketBlock<Packet, 2> kernel0;
+ PacketBlock<Packet, 2> kernel1;
+ kernel0.packet[0] = dm0.loadPacketStandard(k);
+ kernel0.packet[1] = dm1.loadPacketStandard(k);
+ kernel1.packet[0] = dm2.loadPacketStandard(k);
+ kernel1.packet[1] = dm3.loadPacketStandard(k);
+ ptranspose(kernel0);
+ ptranspose(kernel1);
+ pstoreu(block + 0 * packet_size, kernel0.packet[0]);
+ pstoreu(block + 1 * packet_size, kernel1.packet[0]);
+ pstoreu(block + 2 * packet_size, kernel0.packet[1]);
+ pstoreu(block + 3 * packet_size, kernel1.packet[1]);
+ block += 4 * packet_size;
+ }
+ }
+ }
+ if (!rhs.nonStandardPatches()) {
+ for (; k < depth; k++) {
+ block[0] = dm0.loadCoeffStandard(k);
+ block[1] = dm1.loadCoeffStandard(k);
+ block[2] = dm2.loadCoeffStandard(k);
+ block[3] = dm3.loadCoeffStandard(k);
+ block += 4;
+ }
+ } else {
+ for (; k < depth; k++) {
+ block[0] = dm0(k);
+ block[1] = dm1(k);
+ block[2] = dm2(k);
+ block[3] = dm3(k);
+ block += 4;
+ }
+ }
+ }
+
+ // copy the remaining columns one at a time (nr==1)
+ for (Index j2 = packet_cols4; j2 < cols; ++j2) {
+ const SubMapper dm0 = rhs.getLinearMapper(0, j2);
+ for (Index k = 0; k < depth; k++) {
+ *block = dm0(k);
+ block += 1;
+ }
+ }
+ }
+};
+
+// Special case for non-vectorized types such as float16 (packet_size = 1).
+template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
+ DenseIndex Cols, typename ArgType, typename Device, typename Scalar,
+ typename Index, typename nocontract_t, typename contract_t,
+ bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
+ int nr>
+struct gemm_pack_rhs<
+ Scalar, Index,
+ TensorContractionSubMapper<
+ Scalar, Index, Rhs,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, /*packet_size*/ 1, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>,
+ nr, ColMajor, false, false> {
+ typedef TensorContractionSubMapper<
+ Scalar, Index, Rhs,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered,
+ Alignment>
+ SubMapper;
+ typedef SubMapper DataMapper;
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
+ Index depth, Index cols, Index stride = 0,
+ Index offset = 0) const {
+ eigen_assert(stride == 0);
+ eigen_assert(offset == 0);
+
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
+
+ const Index packet_cols4 = (cols / 4) * 4;
+
+ for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
+ const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
+ const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
+ const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
+ const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
+
+ if (!rhs.nonStandardPatches()) {
+ for (Index k = 0; k < depth; k++) {
+ block[0] = dm0.loadCoeffStandard(k);
+ block[1] = dm1.loadCoeffStandard(k);
+ block[2] = dm2.loadCoeffStandard(k);
+ block[3] = dm3.loadCoeffStandard(k);
+ block += 4;
+ }
+ } else {
+ for (Index k = 0; k < depth; k++) {
+ block[0] = dm0(k);
+ block[1] = dm1(k);
+ block[2] = dm2(k);
+ block[3] = dm3(k);
+ block += 4;
+ }
+ }
+ }
+
+ // copy the remaining columns one at a time (nr==1)
+ for (Index j2 = packet_cols4; j2 < cols; ++j2) {
+ const SubMapper dm0 = rhs.getLinearMapper(0, j2);
+ for (Index k = 0; k < depth; k++) {
+ *block = dm0(k);
+ block += 1;
+ }
+ }
+ }
+};
+
+} // namespace internal
+
/** CuboidConvolution
* \ingroup CXX11_NeuralNetworks_Module
*
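// Editor's sketch (not part of the patch above): a scalar illustration of the
// memory layout the gemm_pack_rhs specializations above produce. The SIMD path
// uses PacketBlock/ptranspose, but the resulting order is the same as this
// plain loop: groups of 4 columns are interleaved per depth index, then the
// leftover columns are copied one at a time. The names pack_rhs_4cols and
// ColumnAccessor are illustrative assumptions, not Eigen/TensorFlow API.
template <typename Scalar, typename ColumnAccessor>
void pack_rhs_4cols(Scalar* block, const ColumnAccessor& rhs, int depth,
                    int cols) {
  const int packet_cols4 = (cols / 4) * 4;
  // Interleave groups of 4 columns: block = c0(k), c1(k), c2(k), c3(k), ...
  for (int j = 0; j < packet_cols4; j += 4) {
    for (int k = 0; k < depth; ++k) {
      block[0] = rhs(k, j + 0);
      block[1] = rhs(k, j + 1);
      block[2] = rhs(k, j + 2);
      block[3] = rhs(k, j + 3);
      block += 4;
    }
  }
  // Remaining columns are copied one at a time (the nr == 1 tail above).
  for (int j = packet_cols4; j < cols; ++j) {
    for (int k = 0; k < depth; ++k) {
      *block++ = rhs(k, j);
    }
  }
}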
diff --git a/tensorflow/core/kernels/gather_functor.h b/tensorflow/core/kernels/gather_functor.h
index cd2873bdca..7710cf93d6 100644
--- a/tensorflow/core/kernels/gather_functor.h
+++ b/tensorflow/core/kernels/gather_functor.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/type_traits.h"
+#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/platform/prefetch.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/core/kernels/list_kernels.cc b/tensorflow/core/kernels/list_kernels.cc
index bca1cff41c..2088c13586 100644
--- a/tensorflow/core/kernels/list_kernels.cc
+++ b/tensorflow/core/kernels/list_kernels.cc
@@ -77,9 +77,9 @@ static Status TensorListDeviceCopy(
return Status::OK();
}
-#define REGISTER_LIST_COPY(DIRECTION) \
- INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
- TensorList, DIRECTION, TensorList::kTypeName, TensorListDeviceCopy)
+#define REGISTER_LIST_COPY(DIRECTION) \
+ INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(TensorList, DIRECTION, \
+ TensorListDeviceCopy)
REGISTER_LIST_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
@@ -92,8 +92,7 @@ Status TensorListShape(const TensorList& t, TensorShape* s) {
return Status::OK();
}
-REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(TensorList, TensorList::kTypeName,
- TensorListShape);
+REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(TensorList, TensorListShape);
bool TensorList::Decode(const VariantTensorData& data) {
tensors = data.tensors();
@@ -625,12 +624,11 @@ REGISTER_TENSOR_LIST_FROM_TENSOR_CPU(bfloat16);
#undef REGISTER_TENSOR_LIST_FROM_TENSOR_CPU
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
- TensorList, TensorList::kTypeName,
+ TensorList,
TensorListBinaryAdd<CPUDevice>);
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
DEVICE_CPU, TensorList,
- TensorList::kTypeName,
TensorListZerosLike<CPUDevice>);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/list_kernels.cu.cc b/tensorflow/core/kernels/list_kernels.cu.cc
index c591226b76..a00bf700ca 100644
--- a/tensorflow/core/kernels/list_kernels.cu.cc
+++ b/tensorflow/core/kernels/list_kernels.cu.cc
@@ -94,11 +94,10 @@ REGISTER_TENSOR_LIST_FROM_TENSOR_GPU(bool);
#undef REGISTER_TENSOR_LIST_FROM_TENSOR_GPU
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_GPU,
- TensorList, TensorList::kTypeName,
+ TensorList,
TensorListBinaryAdd<GPUDevice>);
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
DEVICE_GPU, TensorList,
- TensorList::kTypeName,
TensorListZerosLike<GPUDevice>);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc
index 7bb403290d..fc1c9003aa 100644
--- a/tensorflow/core/kernels/partitioned_function_ops.cc
+++ b/tensorflow/core/kernels/partitioned_function_ops.cc
@@ -127,12 +127,12 @@ class PartitionedCallOp : public AsyncOpKernel {
optimization_options.graph = &graph;
optimization_options.flib_def = overlay_lib;
optimization_options.device_set = &device_set;
- Placer placer(graph.get(), &device_set);
OP_REQUIRES_OK_ASYNC(
ctx,
OptimizationPassRegistry::Global()->RunGrouping(
OptimizationPassRegistry::PRE_PLACEMENT, optimization_options),
done);
+ Placer placer(graph.get(), &device_set);
OP_REQUIRES_OK_ASYNC(ctx, placer.Run(), done);
OP_REQUIRES_OK_ASYNC(
ctx,
@@ -210,7 +210,7 @@ class PartitionedCallOp : public AsyncOpKernel {
TF_RETURN_IF_ERROR(node->attrs().Find("T", &attr_value));
DataType dtype = attr_value->type();
if (dtype == DT_RESOURCE) {
- ResourceHandle handle = args[index].flat<ResourceHandle>()(0);
+ const ResourceHandle& handle = args[index].flat<ResourceHandle>()(0);
node->set_assigned_device_name(handle.device());
}
}
diff --git a/tensorflow/core/kernels/queue_ops.cc b/tensorflow/core/kernels/queue_ops.cc
index c4d404259b..97ddc852f7 100644
--- a/tensorflow/core/kernels/queue_ops.cc
+++ b/tensorflow/core/kernels/queue_ops.cc
@@ -65,7 +65,7 @@ class FakeQueueOp : public OpKernel {
}
void Compute(OpKernelContext* context) override {
- ResourceHandle ref = context->input(0).flat<ResourceHandle>()(0);
+ const ResourceHandle& ref = context->input(0).flat<ResourceHandle>()(0);
handle_.AccessTensor(context)->flat<string>()(0) = ref.container();
handle_.AccessTensor(context)->flat<string>()(1) = ref.name();
context->set_output_ref(0, &mu_, handle_.AccessTensor(context));
diff --git a/tensorflow/core/kernels/reduction_ops_sum.cc b/tensorflow/core/kernels/reduction_ops_sum.cc
index 5318d8c133..e4ca89eca3 100644
--- a/tensorflow/core/kernels/reduction_ops_sum.cc
+++ b/tensorflow/core/kernels/reduction_ops_sum.cc
@@ -76,7 +76,15 @@ REGISTER_KERNEL_BUILDER(
.HostMemory("output")
.HostMemory("reduction_indices"),
ReductionOp<CPUDevice, int32, int64, Eigen::internal::SumReducer<int32>>);
-
+REGISTER_KERNEL_BUILDER(
+ Name("Sum")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<int64>("T")
+ .TypeConstraint<int32>("Tidx")
+ .HostMemory("input")
+ .HostMemory("output")
+ .HostMemory("reduction_indices"),
+ ReductionOp<CPUDevice, int64, int32, Eigen::internal::SumReducer<int64>>);
#endif
#ifdef TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index ebcfb673d1..26705a8d34 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -79,7 +79,7 @@ ReadVariableOp::ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) {
void ReadVariableOp::Compute(OpKernelContext* ctx) {
Var* variable = nullptr;
- ResourceHandle handle = HandleFromInput(ctx, 0);
+ const ResourceHandle& handle = HandleFromInput(ctx, 0);
const auto status = LookupResource(ctx, handle, &variable);
OP_REQUIRES(ctx, status.ok(),
errors::FailedPrecondition(
diff --git a/tensorflow/core/kernels/reverse_sequence_op.cc b/tensorflow/core/kernels/reverse_sequence_op.cc
index 15a707a9c6..cded417986 100644
--- a/tensorflow/core/kernels/reverse_sequence_op.cc
+++ b/tensorflow/core/kernels/reverse_sequence_op.cc
@@ -64,7 +64,7 @@ void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) {
OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim),
errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim,
"), ", "(", seq_lens.NumElements(),
- " vs. ", input.dim_size(batch_dim)));
+ " vs. ", input.dim_size(batch_dim), ")"));
for (size_t d = 0; d < seq_lens_vec.size(); ++d) {
OP_REQUIRES(context, seq_lens_vec[d] >= 0,
@@ -91,7 +91,7 @@ void CheckErrorsGPU(OpKernelContext* context, int batch_dim, int seq_dim) {
OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim),
errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim,
"), ", "(", seq_lens.NumElements(),
- " vs. ", input.dim_size(batch_dim)));
+ " vs. ", input.dim_size(batch_dim), ")"));
}
template <>
@@ -127,6 +127,7 @@ class ReverseSequenceOp : public OpKernel {
auto seq_lens_t = seq_lens.vec<Tlen>();
CheckErrors<Device, Tlen>(context, batch_dim_, seq_dim_);
+ if (!context->status().ok()) return;
const int input_dims = input.dims();
diff --git a/tensorflow/core/kernels/shape_op_test.cc b/tensorflow/core/kernels/shape_op_test.cc
index 9cd590ae61..30cb1e0a7f 100644
--- a/tensorflow/core/kernels/shape_op_test.cc
+++ b/tensorflow/core/kernels/shape_op_test.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/abi.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -60,8 +61,7 @@ Status GetShapeFromKnownVecSize(const KnownVecSize& ks, TensorShape* s) {
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(KnownVecSize, "KNOWN VECTOR SIZE TYPE");
-REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(KnownVecSize, "KNOWN VECTOR SIZE TYPE",
- GetShapeFromKnownVecSize);
+REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(KnownVecSize, GetShapeFromKnownVecSize);
static void ExpectHasError(const Status& s, StringPiece substr) {
EXPECT_TRUE(str_util::StrContains(s.ToString(), substr))
@@ -94,9 +94,9 @@ TEST_F(ShapeOpTest, Simple) {
Status s = session.Run({{input, variant_tensor}}, {shape_output}, &outputs);
EXPECT_FALSE(s.ok());
ExpectHasError(
- s,
- "No unary variant shape function found for Variant type_name: "
- "NO KNOWN SHAPE");
+ s, strings::StrCat(
+ "No unary variant shape function found for Variant type_index: ",
+ port::MaybeAbiDemangle(MakeTypeIndex<NoKnownShape>().name())));
}
{
diff --git a/tensorflow/core/kernels/split_op.cc b/tensorflow/core/kernels/split_op.cc
index 7cc3c532c9..11db72bfa3 100644
--- a/tensorflow/core/kernels/split_op.cc
+++ b/tensorflow/core/kernels/split_op.cc
@@ -49,7 +49,12 @@ class SplitOpBase : public OpKernel {
void ComputeEasyCases(OpKernelContext* context, bool* done) {
const Tensor& input = context->input(1);
const TensorShape& input_shape = input.shape();
- const int32 split_dim_orig = context->input(0).flat<int32>()(0);
+ const Tensor& split_dim_tensor = context->input(0);
+ OP_REQUIRES(
+ context, split_dim_tensor.shape().dims() == 0,
+ errors::InvalidArgument("split_dim must be a scalar but has rank ",
+ split_dim_tensor.shape().dims()));
+ const int32 split_dim_orig = split_dim_tensor.flat<int32>()(0);
const int32 split_dim =
split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
const int32 num_split = num_outputs();
diff --git a/tensorflow/core/kernels/stack_ops.cc b/tensorflow/core/kernels/stack_ops.cc
index 65296f61fd..add4afafc9 100644
--- a/tensorflow/core/kernels/stack_ops.cc
+++ b/tensorflow/core/kernels/stack_ops.cc
@@ -131,10 +131,8 @@ class Stack : public ResourceBase {
};
Status GetStack(OpKernelContext* ctx, Stack** stack) {
- string key;
if (ctx->input_dtype(0) == DT_RESOURCE) {
- auto resource = ctx->input(0).flat<ResourceHandle>()(0);
- key = resource.name();
+ return LookupResource(ctx, HandleFromInput(ctx, 0), stack);
} else {
Tensor Tstack_handle = ctx->mutable_input(0, false);
if (Tstack_handle.NumElements() != 2) {
@@ -144,18 +142,18 @@ Status GetStack(OpKernelContext* ctx, Stack** stack) {
}
const string& container = Tstack_handle.flat<string>()(0);
const string& stack_name = Tstack_handle.flat<string>()(1);
- key = strings::StrCat(container, stack_name);
- }
- ResourceMgr* rm = ctx->resource_manager();
- if (rm == nullptr) {
- return errors::Internal("No resource manager.");
- }
- auto* step_container = ctx->step_container();
- if (step_container == nullptr) {
- return errors::Internal("No step container.");
+ string key = strings::StrCat(container, stack_name);
+ ResourceMgr* rm = ctx->resource_manager();
+ if (rm == nullptr) {
+ return errors::Internal("No resource manager.");
+ }
+ auto* step_container = ctx->step_container();
+ if (step_container == nullptr) {
+ return errors::Internal("No step container.");
+ }
+ TF_RETURN_IF_ERROR(rm->Lookup(step_container->name(), key, stack));
+ return Status::OK();
}
- TF_RETURN_IF_ERROR(rm->Lookup(step_container->name(), key, stack));
- return Status::OK();
}
std::atomic<int64> Stack::stack_counter{0};
diff --git a/tensorflow/core/kernels/substr_op.cc b/tensorflow/core/kernels/substr_op.cc
index 22e45918a0..07f1d6e767 100644
--- a/tensorflow/core/kernels/substr_op.cc
+++ b/tensorflow/core/kernels/substr_op.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <cstddef>
+#include <cstdlib>
#include <string>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -25,6 +27,8 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/bcast.h"
namespace tensorflow {
@@ -64,26 +68,28 @@ class SubstrOp : public OpKernel {
const T len =
tensorflow::internal::SubtleMustCopy(len_tensor.scalar<T>()());
for (size_t i = 0; i < input_tensor.NumElements(); ++i) {
- string in = input(i);
+ StringPiece in(input(i));
OP_REQUIRES(
- context, FastBoundsCheck(pos, in.size() + 1),
+ context, FastBoundsCheck(std::abs(pos), in.size() + 1),
errors::InvalidArgument("pos ", pos, " out of range for string",
"b'", in, "' at index ", i));
- output(i) = in.substr(pos, len);
+ StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ output(i).assign(sub_in.data(), sub_in.size());
}
} else {
// Perform Op element-wise with tensor pos/len
auto pos_flat = pos_tensor.flat<T>();
auto len_flat = len_tensor.flat<T>();
for (size_t i = 0; i < input_tensor.NumElements(); ++i) {
- string in = input(i);
+ StringPiece in(input(i));
const T pos = tensorflow::internal::SubtleMustCopy(pos_flat(i));
const T len = tensorflow::internal::SubtleMustCopy(len_flat(i));
OP_REQUIRES(
- context, FastBoundsCheck(pos, in.size() + 1),
+ context, FastBoundsCheck(std::abs(pos), in.size() + 1),
errors::InvalidArgument("pos ", pos, " out of range for string",
"b'", in, "' at index ", i));
- output(i) = in.substr(pos, len);
+ StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ output(i).assign(sub_in.data(), sub_in.size());
}
}
} else {
@@ -142,14 +148,16 @@ class SubstrOp : public OpKernel {
// Iterate through broadcasted tensors and perform substr
for (int i = 0; i < output_shape.dim_size(0); ++i) {
- string in = input_bcast(i);
+ StringPiece in(input_bcast(i));
const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i));
const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i));
OP_REQUIRES(
- context, FastBoundsCheck(pos, input_bcast(i).size() + 1),
+ context,
+ FastBoundsCheck(std::abs(pos), input_bcast(i).size() + 1),
errors::InvalidArgument("pos ", pos, " out of range for string",
"b'", in, "' at index ", i));
- output(i) = in.substr(pos, len);
+ StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ output(i).assign(sub_in.data(), sub_in.size());
}
break;
}
@@ -192,16 +200,18 @@ class SubstrOp : public OpKernel {
// Iterate through broadcasted tensors and perform substr
for (int i = 0; i < output_shape.dim_size(0); ++i) {
for (int j = 0; j < output_shape.dim_size(1); ++j) {
- string in = input_bcast(i, j);
+ StringPiece in(input_bcast(i, j));
const T pos =
tensorflow::internal::SubtleMustCopy(pos_bcast(i, j));
const T len =
tensorflow::internal::SubtleMustCopy(len_bcast(i, j));
- OP_REQUIRES(context, FastBoundsCheck(pos, in.size() + 1),
- errors::InvalidArgument(
- "pos ", pos, " out of range for ", "string b'",
- in, "' at index (", i, ", ", j, ")"));
- output(i, j) = in.substr(pos, len);
+ OP_REQUIRES(
+ context, FastBoundsCheck(std::abs(pos), in.size() + 1),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string b'", in, "' at index (", i,
+ ", ", j, ")"));
+ StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ output(i, j).assign(sub_in.data(), sub_in.size());
}
}
break;
@@ -213,6 +223,16 @@ class SubstrOp : public OpKernel {
}
}
}
+
+ private:
+  // Adjusts the requested position so that a negative value counts back from
+  // the end of the string. Note: this does not perform any bounds checks.
+ T AdjustedPosIndex(const T pos_requested, const StringPiece s) {
+ if (pos_requested < 0) {
+ return s.size() + pos_requested;
+ }
+ return pos_requested;
+ }
};
#define REGISTER_SUBSTR(type) \
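// Editor's sketch of the negative-position handling introduced above: a
// negative `pos` is interpreted relative to the end of the string, mirroring
// AdjustedPosIndex. The standalone AdjustedPos function below is illustrative
// only and assumes the caller has already bounds-checked |pos| against
// size() + 1, as the kernel does.
#include <cstdint>

#include "absl/strings/string_view.h"

int64_t AdjustedPos(int64_t pos, absl::string_view s) {
  return pos < 0 ? static_cast<int64_t>(s.size()) + pos : pos;
}

// Example: AdjustedPos(-2, "hello") == 3, so a substr of length 2 starting
// there yields "lo".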
diff --git a/tensorflow/core/kernels/substr_op_test.cc b/tensorflow/core/kernels/substr_op_test.cc
new file mode 100644
index 0000000000..2e07050260
--- /dev/null
+++ b/tensorflow/core/kernels/substr_op_test.cc
@@ -0,0 +1,105 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+// Test data from the TensorFlow README.md.
+const char* lines[] = {
+ "**TensorFlow** is an open source software library for numerical "
+ "computation using data flow graphs.",
+ "The graph nodes represent mathematical operations, while the graph edges "
+ "represent the multidimensional data arrays (tensors) that flow between "
+ "them.",
+ "This flexible architecture enables you to deploy computation to one or "
+ "more CPUs or GPUs in a desktop, server, or mobile device without "
+ "rewriting code.",
+ "TensorFlow also includes "
+ "[TensorBoard](https://www.tensorflow.org/guide/"
+ "summaries_and_tensorboard), a data visualization toolkit.",
+ "TensorFlow was originally developed by researchers and engineers working "
+ "on the Google Brain team within Google's Machine Intelligence Research "
+ "organization for the purposes of conducting machine learning and deep "
+ "neural networks research.",
+ "The system is general enough to be applicable in a wide variety of other "
+ "domains, as well.",
+    "TensorFlow provides stable Python and C APIs, as well as APIs without "
+    "backwards compatibility guarantees, for C++, Go, Java, JavaScript and "
+    "Swift."};
+
+Tensor GetTestTensor(int batch) {
+ const int sz = TF_ARRAYSIZE(lines);
+ Tensor t(DT_STRING, {batch});
+ auto s = t.flat<string>();
+ for (int i = 0; i < batch; ++i) {
+ s(i) = lines[i % sz];
+ }
+ return t;
+}
+
+Graph* SetupSubstrGraph(const Tensor& input, const int32 pos, const int32 len) {
+ Graph* g = new Graph(OpRegistry::Global());
+ Tensor position(DT_INT32, TensorShape({}));
+ position.flat<int32>().setConstant(pos);
+ Tensor length(DT_INT32, TensorShape({}));
+ length.flat<int32>().setConstant(len);
+
+ TF_CHECK_OK(NodeBuilder("substr_op", "Substr")
+ .Input(test::graph::Constant(g, input))
+ .Input(test::graph::Constant(g, position))
+ .Input(test::graph::Constant(g, length))
+ .Finalize(g, nullptr /* node */));
+ return g;
+}
+
+void BM_Substr(int iters, int batch_size) {
+ testing::StopTiming();
+ testing::ItemsProcessed(static_cast<int64>(iters));
+ testing::UseRealTime();
+ Tensor input = GetTestTensor(batch_size);
+ Graph* g = SetupSubstrGraph(input, 3, 30);
+ testing::StartTiming();
+ test::Benchmark("cpu", g).Run(iters);
+}
+
+BENCHMARK(BM_Substr)->Arg(1)->Arg(8)->Arg(16)->Arg(32)->Arg(64)->Arg(128)->Arg(
+ 256);
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc
index 2ec2651c04..fe93b91eb8 100644
--- a/tensorflow/core/kernels/tensor_array_ops.cc
+++ b/tensorflow/core/kernels/tensor_array_ops.cc
@@ -290,7 +290,7 @@ class TensorArrayGradOp : public TensorArrayCreationOp {
}
} else {
container = "_tensor_arrays";
- auto resource = ctx->input(0).flat<ResourceHandle>()(0);
+ const auto& resource = ctx->input(0).flat<ResourceHandle>()(0);
if (StringPiece(resource.name()).substr(0, container.size()) !=
container) {
return errors::InvalidArgument("Wrong input container. ",
diff --git a/tensorflow/core/lib/core/stringpiece.cc b/tensorflow/core/lib/core/stringpiece.cc
deleted file mode 100644
index 4c488066e4..0000000000
--- a/tensorflow/core/lib/core/stringpiece.cc
+++ /dev/null
@@ -1,54 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/lib/core/stringpiece.h"
-
-#include <algorithm>
-#include <iostream>
-
-namespace tensorflow {
-
-std::ostream& operator<<(std::ostream& o, StringPiece piece) {
- o.write(piece.data(), piece.size());
- return o;
-}
-
-size_t StringPiece::find(char c, size_t pos) const {
- if (pos >= size_) {
- return npos;
- }
- const char* result =
- reinterpret_cast<const char*>(memchr(data_ + pos, c, size_ - pos));
- return result != nullptr ? result - data_ : npos;
-}
-
-// Search range is [0..pos] inclusive. If pos == npos, search everything.
-size_t StringPiece::rfind(char c, size_t pos) const {
- if (size_ == 0) return npos;
- for (const char* p = data_ + std::min(pos, size_ - 1); p >= data_; p--) {
- if (*p == c) {
- return p - data_;
- }
- }
- return npos;
-}
-
-StringPiece StringPiece::substr(size_t pos, size_t n) const {
- if (pos > size_) pos = size_;
- if (n > size_ - pos) n = size_ - pos;
- return StringPiece(data_ + pos, n);
-}
-
-} // namespace tensorflow
diff --git a/tensorflow/core/lib/core/stringpiece.h b/tensorflow/core/lib/core/stringpiece.h
index 02dded42c1..e7b17c9b36 100644
--- a/tensorflow/core/lib/core/stringpiece.h
+++ b/tensorflow/core/lib/core/stringpiece.h
@@ -31,124 +31,13 @@ limitations under the License.
#include <string.h>
#include <iosfwd>
#include <string>
-#include <type_traits>
+#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
-class StringPiece {
- public:
- typedef size_t size_type;
-
- // Create an empty slice.
- StringPiece() : data_(nullptr), size_(0) {}
-
- // Create a slice that refers to d[0,n-1].
- StringPiece(const char* d, size_t n) : data_(d), size_(n) {}
-
- // Create a slice that refers to the contents of "s"
- StringPiece(const string& s) : data_(s.data()), size_(s.size()) {}
-
- // Create a slice that refers to s[0,strlen(s)-1]
- StringPiece(const char* s) : data_(s), size_(strlen(s)) {}
-
- // Return a pointer to the beginning of the referenced data
- const char* data() const { return data_; }
-
- // Return the length (in bytes) of the referenced data
- size_t size() const { return size_; }
-
- // Return true iff the length of the referenced data is zero
- bool empty() const { return size_ == 0; }
-
- typedef const char* const_iterator;
- typedef const char* iterator;
- iterator begin() const { return data_; }
- iterator end() const { return data_ + size_; }
-
- static const size_t npos = size_type(-1);
-
- // Return the ith byte in the referenced data.
- // REQUIRES: n < size()
- char operator[](size_t n) const {
- assert(n < size());
- return data_[n];
- }
-
- // Drop the first "n" bytes from this slice.
- void remove_prefix(size_t n) {
- assert(n <= size());
- data_ += n;
- size_ -= n;
- }
-
- void remove_suffix(size_t n) {
- assert(size_ >= n);
- size_ -= n;
- }
-
- size_t find(char c, size_t pos = 0) const;
- size_t rfind(char c, size_t pos = npos) const;
-
- StringPiece substr(size_t pos, size_t n = npos) const;
-
- // Three-way comparison. Returns value:
- // < 0 iff "*this" < "b",
- // == 0 iff "*this" == "b",
- // > 0 iff "*this" > "b"
- int compare(StringPiece b) const;
-
- // Converts to various kinds of strings, including `std::basic_string`.
- template <typename S>
- explicit operator S() const {
- static_assert(
- std::is_same<char, typename S::value_type>::value,
- "Type mismatch: S must be a string with character type char.");
- static_assert(
- std::is_same<std::char_traits<char>, typename S::traits_type>::value,
- "Type mismatch: S must be a string with traits type "
- "std::char_traits<char>.");
- if (!data()) return {};
- return S(data(), size());
- }
-
- private:
- const char* data_;
- size_t size_;
-
- // Intentionally copyable
-};
-
-inline bool operator==(StringPiece x, StringPiece y) {
- return ((x.size() == y.size()) &&
- (memcmp(x.data(), y.data(), x.size()) == 0));
-}
-
-inline bool operator!=(StringPiece x, StringPiece y) { return !(x == y); }
-
-inline bool operator<(StringPiece x, StringPiece y) { return x.compare(y) < 0; }
-inline bool operator>(StringPiece x, StringPiece y) { return x.compare(y) > 0; }
-inline bool operator<=(StringPiece x, StringPiece y) {
- return x.compare(y) <= 0;
-}
-inline bool operator>=(StringPiece x, StringPiece y) {
- return x.compare(y) >= 0;
-}
-
-inline int StringPiece::compare(StringPiece b) const {
- const size_t min_len = (size_ < b.size_) ? size_ : b.size_;
- int r = memcmp(data_, b.data_, min_len);
- if (r == 0) {
- if (size_ < b.size_)
- r = -1;
- else if (size_ > b.size_)
- r = +1;
- }
- return r;
-}
-
-// allow StringPiece to be logged
-extern std::ostream& operator<<(std::ostream& o, tensorflow::StringPiece piece);
+// Deprecated: please use absl::string_view directly.
+using StringPiece = absl::string_view;
} // namespace tensorflow
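// Editor's sketch (assumes absl is on the include path): with the alias above,
// tensorflow::StringPiece and absl::string_view are the same type, so the two
// can be passed interchangeably and substr() returns a view without copying
// character data. StringPieceAliasExample is an illustrative name only.
#include <string>

#include "tensorflow/core/lib/core/stringpiece.h"

void StringPieceAliasExample(const std::string& owned) {
  tensorflow::StringPiece piece(owned);                  // now absl::string_view
  tensorflow::StringPiece prefix = piece.substr(0, 3);   // a view, no copy
  std::string copy(prefix.data(), prefix.size());        // copy only when needed
  (void)copy;
}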
diff --git a/tensorflow/core/lib/io/record_reader.cc b/tensorflow/core/lib/io/record_reader.cc
index c24628be57..f93ebea771 100644
--- a/tensorflow/core/lib/io/record_reader.cc
+++ b/tensorflow/core/lib/io/record_reader.cc
@@ -109,9 +109,6 @@ Status RecordReader::ReadChecksummed(uint64 offset, size_t n, string* result) {
}
Status RecordReader::ReadRecord(uint64* offset, string* record) {
- static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32);
- static const size_t kFooterSize = sizeof(uint32);
-
// Position the input stream.
int64 curr_pos = input_stream_->Tell();
int64 desired_pos = static_cast<int64>(*offset);
diff --git a/tensorflow/core/lib/io/record_reader.h b/tensorflow/core/lib/io/record_reader.h
index c05f9e1b36..11af1366b0 100644
--- a/tensorflow/core/lib/io/record_reader.h
+++ b/tensorflow/core/lib/io/record_reader.h
@@ -58,6 +58,14 @@ class RecordReaderOptions {
// Note: this class is not thread safe; external synchronization required.
class RecordReader {
public:
+ // Format of a single record:
+ // uint64 length
+ // uint32 masked crc of length
+ // byte data[length]
+ // uint32 masked crc of data
+ static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32);
+ static const size_t kFooterSize = sizeof(uint32);
+
// Create a reader that will return log records from "*file".
// "*file" must remain live while this Reader is in use.
explicit RecordReader(
diff --git a/tensorflow/core/lib/io/record_writer.cc b/tensorflow/core/lib/io/record_writer.cc
index 6e71d23e71..2c6db2487e 100644
--- a/tensorflow/core/lib/io/record_writer.cc
+++ b/tensorflow/core/lib/io/record_writer.cc
@@ -88,10 +88,6 @@ RecordWriter::~RecordWriter() {
}
}
-static uint32 MaskedCrc(const char* data, size_t n) {
- return crc32c::Mask(crc32c::Value(data, n));
-}
-
Status RecordWriter::WriteRecord(StringPiece data) {
if (dest_ == nullptr) {
return Status(::tensorflow::error::FAILED_PRECONDITION,
@@ -102,13 +98,10 @@ Status RecordWriter::WriteRecord(StringPiece data) {
// uint32 masked crc of length
// byte data[length]
// uint32 masked crc of data
- char header[sizeof(uint64) + sizeof(uint32)];
- core::EncodeFixed64(header + 0, data.size());
- core::EncodeFixed32(header + sizeof(uint64),
- MaskedCrc(header, sizeof(uint64)));
- char footer[sizeof(uint32)];
- core::EncodeFixed32(footer, MaskedCrc(data.data(), data.size()));
-
+ char header[kHeaderSize];
+ char footer[kFooterSize];
+ PopulateHeader(header, data.data(), data.size());
+ PopulateFooter(footer, data.data(), data.size());
TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
TF_RETURN_IF_ERROR(dest_->Append(data));
return dest_->Append(StringPiece(footer, sizeof(footer)));
diff --git a/tensorflow/core/lib/io/record_writer.h b/tensorflow/core/lib/io/record_writer.h
index 6a2bf66d12..1212e1fafb 100644
--- a/tensorflow/core/lib/io/record_writer.h
+++ b/tensorflow/core/lib/io/record_writer.h
@@ -16,8 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_LIB_IO_RECORD_WRITER_H_
#define TENSORFLOW_CORE_LIB_IO_RECORD_WRITER_H_
+#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/hash/crc32c.h"
#if !defined(IS_SLIM_BUILD)
#include "tensorflow/core/lib/io/zlib_compression_options.h"
#include "tensorflow/core/lib/io/zlib_outputbuffer.h"
@@ -47,6 +49,14 @@ class RecordWriterOptions {
class RecordWriter {
public:
+ // Format of a single record:
+ // uint64 length
+ // uint32 masked crc of length
+ // byte data[length]
+ // uint32 masked crc of data
+ static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32);
+ static const size_t kFooterSize = sizeof(uint32);
+
// Create a writer that will append data to "*dest".
// "*dest" must be initially empty.
// "*dest" must remain live while this Writer is in use.
@@ -72,13 +82,35 @@ class RecordWriter {
// are invalid.
Status Close();
+  // Utility method to populate TFRecord headers. Populates the record-header
+  // in "header[0,kHeaderSize-1]"; the header encodes the data length n and
+  // the masked CRC of that length (the data bytes themselves are not read).
+ inline static void PopulateHeader(char* header, const char* data, size_t n);
+
+ // Utility method to populate TFRecord footers. Populates record-footer in
+ // "footer[0,kFooterSize-1]". The record-footer is based on data[0, n-1].
+ inline static void PopulateFooter(char* footer, const char* data, size_t n);
+
private:
WritableFile* dest_;
RecordWriterOptions options_;
+ inline static uint32 MaskedCrc(const char* data, size_t n) {
+ return crc32c::Mask(crc32c::Value(data, n));
+ }
+
TF_DISALLOW_COPY_AND_ASSIGN(RecordWriter);
};
+void RecordWriter::PopulateHeader(char* header, const char* data, size_t n) {
+ core::EncodeFixed64(header + 0, n);
+ core::EncodeFixed32(header + sizeof(uint64),
+ MaskedCrc(header, sizeof(uint64)));
+}
+
+void RecordWriter::PopulateFooter(char* footer, const char* data, size_t n) {
+ core::EncodeFixed32(footer, MaskedCrc(data, n));
+}
+
} // namespace io
} // namespace tensorflow
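// Editor's sketch of framing a single TFRecord in memory with the helpers made
// public above (RecordWriter::kHeaderSize/kFooterSize, PopulateHeader,
// PopulateFooter). Buffer layout: [length | crc(length) | data | crc(data)].
// FrameRecord is an illustrative name, not part of the RecordWriter API.
#include <string>

#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/record_writer.h"

std::string FrameRecord(tensorflow::StringPiece data) {
  using tensorflow::io::RecordWriter;
  char header[RecordWriter::kHeaderSize];
  char footer[RecordWriter::kFooterSize];
  RecordWriter::PopulateHeader(header, data.data(), data.size());
  RecordWriter::PopulateFooter(footer, data.data(), data.size());
  std::string framed;
  framed.append(header, sizeof(header));   // 12 bytes: length + masked crc
  framed.append(data.data(), data.size()); // payload
  framed.append(footer, sizeof(footer));   // 4 bytes: masked crc of payload
  return framed;
}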
diff --git a/tensorflow/core/lib/io/recordio_test.cc b/tensorflow/core/lib/io/recordio_test.cc
index da514bd21c..946d7188d3 100644
--- a/tensorflow/core/lib/io/recordio_test.cc
+++ b/tensorflow/core/lib/io/recordio_test.cc
@@ -58,7 +58,7 @@ class StringDest : public WritableFile {
Status Close() override { return Status::OK(); }
Status Flush() override { return Status::OK(); }
Status Sync() override { return Status::OK(); }
- Status Append(const StringPiece& slice) override {
+ Status Append(StringPiece slice) override {
contents_->append(slice.data(), slice.size());
return Status::OK();
}
diff --git a/tensorflow/core/lib/io/table_test.cc b/tensorflow/core/lib/io/table_test.cc
index 877ac40f1c..9cebbf40c6 100644
--- a/tensorflow/core/lib/io/table_test.cc
+++ b/tensorflow/core/lib/io/table_test.cc
@@ -98,7 +98,7 @@ class StringSink : public WritableFile {
Status Flush() override { return Status::OK(); }
Status Sync() override { return Status::OK(); }
- Status Append(const StringPiece& data) override {
+ Status Append(StringPiece data) override {
contents_.append(data.data(), data.size());
return Status::OK();
}
diff --git a/tensorflow/core/lib/io/zlib_outputbuffer.cc b/tensorflow/core/lib/io/zlib_outputbuffer.cc
index 84b47c171f..cba139e6ad 100644
--- a/tensorflow/core/lib/io/zlib_outputbuffer.cc
+++ b/tensorflow/core/lib/io/zlib_outputbuffer.cc
@@ -143,7 +143,7 @@ Status ZlibOutputBuffer::FlushOutputBufferToFile() {
return Status::OK();
}
-Status ZlibOutputBuffer::Append(const StringPiece& data) {
+Status ZlibOutputBuffer::Append(StringPiece data) {
// If there is sufficient free space in z_stream_input_ to fit data we
// add it there and return.
// If there isn't enough space we deflate the existing contents of
diff --git a/tensorflow/core/lib/io/zlib_outputbuffer.h b/tensorflow/core/lib/io/zlib_outputbuffer.h
index 3d86d89a99..ccad2fda44 100644
--- a/tensorflow/core/lib/io/zlib_outputbuffer.h
+++ b/tensorflow/core/lib/io/zlib_outputbuffer.h
@@ -62,7 +62,7 @@ class ZlibOutputBuffer : public WritableFile {
// to file when the buffer is full.
//
// To immediately write contents to file call `Flush()`.
- Status Append(const StringPiece& data) override;
+ Status Append(StringPiece data) override;
// Deflates any cached input and writes all output to file.
Status Flush() override;
diff --git a/tensorflow/core/lib/strings/strcat.h b/tensorflow/core/lib/strings/strcat.h
index 351b6f5de3..a620f59447 100644
--- a/tensorflow/core/lib/strings/strcat.h
+++ b/tensorflow/core/lib/strings/strcat.h
@@ -124,6 +124,9 @@ class AlphaNum {
AlphaNum(const StringPiece &pc) : piece_(pc) {} // NOLINT(runtime/explicit)
AlphaNum(const tensorflow::string &str) // NOLINT(runtime/explicit)
: piece_(str) {}
+ template <typename A>
+ AlphaNum(const std::basic_string<char, std::char_traits<char>, A> &str)
+ : piece_(str) {} // NOLINT(runtime/explicit)
StringPiece::size_type size() const { return piece_.size(); }
const char *data() const { return piece_.data(); }
diff --git a/tensorflow/core/lib/wav/wav_io.cc b/tensorflow/core/lib/wav/wav_io.cc
index 36d939e061..c536b5688e 100644
--- a/tensorflow/core/lib/wav/wav_io.cc
+++ b/tensorflow/core/lib/wav/wav_io.cc
@@ -232,6 +232,11 @@ Status DecodeLin16WaveAsFloatVector(const string& wav_string,
"Bad audio format for WAV: Expected 1 (PCM), but got", audio_format);
}
TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, channel_count, &offset));
+ if (*channel_count < 1) {
+ return errors::InvalidArgument(
+ "Bad number of channels for WAV: Expected at least 1, but got ",
+ *channel_count);
+ }
TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, sample_rate, &offset));
uint32 bytes_per_second;
TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &bytes_per_second, &offset));
diff --git a/tensorflow/core/ops/boosted_trees_ops.cc b/tensorflow/core/ops/boosted_trees_ops.cc
index 01452b3e85..7c4184bff4 100644
--- a/tensorflow/core/ops/boosted_trees_ops.cc
+++ b/tensorflow/core/ops/boosted_trees_ops.cc
@@ -22,6 +22,10 @@ limitations under the License.
namespace tensorflow {
+using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+
REGISTER_RESOURCE_HANDLE_OP(BoostedTreesEnsembleResource);
REGISTER_OP("IsBoostedTreesEnsembleInitialized")
@@ -354,4 +358,125 @@ REGISTER_OP("BoostedTreesCenterBias")
return Status::OK();
});
+REGISTER_RESOURCE_HANDLE_OP(BoostedTreesQuantileStreamResource);
+
+REGISTER_OP("IsBoostedTreesQuantileStreamResourceInitialized")
+ .Input("quantile_stream_resource_handle: resource")
+ .Output("is_initialized: bool")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused_input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesCreateQuantileStreamResource")
+ .Attr("max_elements: int = 1099511627776") // 1 << 40
+ .Input("quantile_stream_resource_handle: resource")
+ .Input("epsilon: float")
+ .Input("num_streams: int64")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused_input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input));
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesMakeQuantileSummaries")
+ .Attr("num_features: int >= 0")
+ .Input("float_values: num_features * float")
+ .Input("example_weights: float")
+ .Input("epsilon: float")
+ .Output("summaries: num_features * float")
+ .SetShapeFn([](InferenceContext* c) {
+ int num_features;
+ TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
+ ShapeHandle example_weights_shape;
+ TF_RETURN_IF_ERROR(
+ c->WithRank(c->input(num_features), 1, &example_weights_shape));
+ for (int i = 0; i < num_features; ++i) {
+ ShapeHandle feature_shape;
+ DimensionHandle unused_dim;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &feature_shape));
+ TF_RETURN_IF_ERROR(c->Merge(c->Dim(feature_shape, 0),
+ c->Dim(example_weights_shape, 0),
+ &unused_dim));
+ // the columns are value, weight, min_rank, max_rank.
+ c->set_output(i, c->MakeShape({c->UnknownDim(), 4}));
+ }
+ // epsilon must be a scalar.
+ ShapeHandle unused_input;
+ TF_RETURN_IF_ERROR(
+ c->WithRank(c->input(num_features + 1), 0, &unused_input));
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesQuantileStreamResourceAddSummaries")
+ .Attr("num_features: int >= 0")
+ .Input("quantile_stream_resource_handle: resource")
+ .Input("summaries: num_features * float")
+ .SetShapeFn([](InferenceContext* c) {
+ int num_features;
+ TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
+ // resource handle must be a scalar.
+ shape_inference::ShapeHandle unused_input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+ // each summary must be rank 2.
+ for (int i = 1; i < num_features + 1; i++) {
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &unused_input));
+ }
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesQuantileStreamResourceFlush")
+ .Attr("generate_quantiles: bool = False")
+ .Input("quantile_stream_resource_handle: resource")
+ .Input("num_buckets: int64")
+ .SetShapeFn([](InferenceContext* c) {
+ // All the inputs are scalars.
+ shape_inference::ShapeHandle unused_input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesQuantileStreamResourceGetBucketBoundaries")
+ .Attr("num_features: int >= 0")
+ .Input("quantile_stream_resource_handle: resource")
+ .Output("bucket_boundaries: num_features * float")
+ .SetShapeFn([](InferenceContext* c) {
+ int num_features;
+ TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
+ shape_inference::ShapeHandle unused_input;
+ // resource handle must be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+ for (int i = 0; i < num_features; i++) {
+ c->set_output(i, c->Vector(c->UnknownDim()));
+ }
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesBucketize")
+ .Attr("num_features: int >= 0")
+ .Input("float_values: num_features * float")
+ .Input("bucket_boundaries: num_features * float")
+ .Output("buckets: num_features * int32")
+ .SetShapeFn([](InferenceContext* c) {
+ int num_features;
+ TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
+ ShapeHandle feature_shape;
+ DimensionHandle unused_dim;
+ for (int i = 0; i < num_features; i++) {
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &feature_shape));
+ TF_RETURN_IF_ERROR(c->Merge(c->Dim(feature_shape, 0),
+ c->Dim(c->input(0), 0), &unused_dim));
+ }
+ // Bucketized result should have same dimension as input.
+ for (int i = 0; i < num_features; i++) {
+ c->set_output(i, c->MakeShape({c->Dim(c->input(i), 0), 1}));
+ }
+ return Status::OK();
+ });
+
} // namespace tensorflow
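// Editor's sketch of the shape-function pattern the quantile-stream ops above
// follow: input 0 is a scalar resource handle and the "num_features * float"
// inputs/outputs are walked in a loop. "ExampleQuantileLikeOp" is a
// placeholder op name, not one of the ops registered in this change.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

REGISTER_OP("ExampleQuantileLikeOp")
    .Attr("num_features: int >= 0")
    .Input("quantile_stream_resource_handle: resource")
    .Input("float_values: num_features * float")
    .Output("results: num_features * float")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      int num_features;
      TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
      // The resource handle must be a scalar.
      shape_inference::ShapeHandle unused_input;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
      // Each per-feature output keeps the shape of its matching input.
      for (int i = 1; i < num_features + 1; ++i) {
        c->set_output(i - 1, c->input(i));
      }
      return Status::OK();
    });

}  // namespace tensorflow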
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 34e6b5560b..57c6bda98b 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -11360,6 +11360,29 @@ op {
is_commutative: true
}
op {
+ name: "BoostedTreesBucketize"
+ input_arg {
+ name: "float_values"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ input_arg {
+ name: "bucket_boundaries"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ output_arg {
+ name: "buckets"
+ type: DT_INT32
+ number_attr: "num_features"
+ }
+ attr {
+ name: "num_features"
+ type: "int"
+ has_minimum: true
+ }
+}
+op {
name: "BoostedTreesCalculateBestGainsPerFeature"
input_arg {
name: "node_id_range"
@@ -11469,6 +11492,29 @@ op {
is_stateful: true
}
op {
+ name: "BoostedTreesCreateQuantileStreamResource"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "epsilon"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "num_streams"
+ type: DT_INT64
+ }
+ attr {
+ name: "max_elements"
+ type: "int"
+ default_value {
+ i: 1099511627776
+ }
+ }
+ is_stateful: true
+}
+op {
name: "BoostedTreesDeserializeEnsemble"
input_arg {
name: "tree_ensemble_handle"
@@ -11562,6 +11608,32 @@ op {
is_stateful: true
}
op {
+ name: "BoostedTreesMakeQuantileSummaries"
+ input_arg {
+ name: "float_values"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ input_arg {
+ name: "example_weights"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "epsilon"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "summaries"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ attr {
+ name: "num_features"
+ type: "int"
+ has_minimum: true
+ }
+}
+op {
name: "BoostedTreesMakeStatsSummary"
input_arg {
name: "node_ids"
@@ -11631,6 +11703,83 @@ op {
is_stateful: true
}
op {
+ name: "BoostedTreesQuantileStreamResourceAddSummaries"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "summaries"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ attr {
+ name: "num_features"
+ type: "int"
+ has_minimum: true
+ }
+ is_stateful: true
+}
+op {
+ name: "BoostedTreesQuantileStreamResourceFlush"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "num_buckets"
+ type: DT_INT64
+ }
+ attr {
+ name: "generate_quantiles"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ is_stateful: true
+}
+op {
+ name: "BoostedTreesQuantileStreamResourceGetBucketBoundaries"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "bucket_boundaries"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ attr {
+ name: "num_features"
+ type: "int"
+ has_minimum: true
+ }
+ is_stateful: true
+}
+op {
+ name: "BoostedTreesQuantileStreamResourceHandleOp"
+ output_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ is_stateful: true
+}
+op {
name: "BoostedTreesSerializeEnsemble"
input_arg {
name: "tree_ensemble_handle"
@@ -27192,6 +27341,18 @@ op {
is_stateful: true
}
op {
+ name: "IsBoostedTreesQuantileStreamResourceInitialized"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "is_initialized"
+ type: DT_BOOL
+ }
+ is_stateful: true
+}
+op {
name: "IsFinite"
input_arg {
name: "x"
@@ -34950,6 +35111,29 @@ op {
}
}
op {
+ name: "ModelDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "Mul"
input_arg {
name: "x"
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 9d2b3af51d..7d9e7b2d3f 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -873,6 +873,13 @@ REGISTER_OP("IteratorGetNextAsOptional")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
+REGISTER_OP("ModelDataset")
+ .Input("input_dataset: variant")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
REGISTER_OP("MapDefun")
.Input("arguments: Targuments")
.Output("output: output_types")
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index c00c0030e6..190f6aaa5b 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -4272,6 +4272,29 @@ op {
is_commutative: true
}
op {
+ name: "BoostedTreesBucketize"
+ input_arg {
+ name: "float_values"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ input_arg {
+ name: "bucket_boundaries"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ output_arg {
+ name: "buckets"
+ type: DT_INT32
+ number_attr: "num_features"
+ }
+ attr {
+ name: "num_features"
+ type: "int"
+ has_minimum: true
+ }
+}
+op {
name: "BoostedTreesCalculateBestGainsPerFeature"
input_arg {
name: "node_id_range"
@@ -4381,6 +4404,29 @@ op {
is_stateful: true
}
op {
+ name: "BoostedTreesCreateQuantileStreamResource"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "epsilon"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "num_streams"
+ type: DT_INT64
+ }
+ attr {
+ name: "max_elements"
+ type: "int"
+ default_value {
+ i: 1099511627776
+ }
+ }
+ is_stateful: true
+}
+op {
name: "BoostedTreesDeserializeEnsemble"
input_arg {
name: "tree_ensemble_handle"
@@ -4474,6 +4520,32 @@ op {
is_stateful: true
}
op {
+ name: "BoostedTreesMakeQuantileSummaries"
+ input_arg {
+ name: "float_values"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ input_arg {
+ name: "example_weights"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "epsilon"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "summaries"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ attr {
+ name: "num_features"
+ type: "int"
+ has_minimum: true
+ }
+}
+op {
name: "BoostedTreesMakeStatsSummary"
input_arg {
name: "node_ids"
@@ -4543,6 +4615,83 @@ op {
is_stateful: true
}
op {
+ name: "BoostedTreesQuantileStreamResourceAddSummaries"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "summaries"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ attr {
+ name: "num_features"
+ type: "int"
+ has_minimum: true
+ }
+ is_stateful: true
+}
+op {
+ name: "BoostedTreesQuantileStreamResourceFlush"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "num_buckets"
+ type: DT_INT64
+ }
+ attr {
+ name: "generate_quantiles"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ is_stateful: true
+}
+op {
+ name: "BoostedTreesQuantileStreamResourceGetBucketBoundaries"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "bucket_boundaries"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ attr {
+ name: "num_features"
+ type: "int"
+ has_minimum: true
+ }
+ is_stateful: true
+}
+op {
+ name: "BoostedTreesQuantileStreamResourceHandleOp"
+ output_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ is_stateful: true
+}
+op {
name: "BoostedTreesSerializeEnsemble"
input_arg {
name: "tree_ensemble_handle"
@@ -13162,6 +13311,18 @@ op {
is_stateful: true
}
op {
+ name: "IsBoostedTreesQuantileStreamResourceInitialized"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "is_initialized"
+ type: DT_BOOL
+ }
+ is_stateful: true
+}
+op {
name: "IsFinite"
input_arg {
name: "x"
@@ -16560,6 +16721,29 @@ op {
}
}
op {
+ name: "ModelDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "Mul"
input_arg {
name: "x"
diff --git a/tensorflow/core/ops/parsing_ops.cc b/tensorflow/core/ops/parsing_ops.cc
index 79ca96d249..eff453241d 100644
--- a/tensorflow/core/ops/parsing_ops.cc
+++ b/tensorflow/core/ops/parsing_ops.cc
@@ -343,10 +343,11 @@ REGISTER_OP("DecodeCSV")
// Validate the record_defaults inputs.
for (int i = 1; i < c->num_inputs(); ++i) {
ShapeHandle v;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &v));
- if (c->Value(c->Dim(v, 0)) > 1) {
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v));
+ if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) {
return errors::InvalidArgument(
- "Shape of a default must be a length-0 or length-1 vector");
+ "Shape of a default must be a length-0 or length-1 vector, or a "
+ "scalar.");
}
}
diff --git a/tensorflow/core/ops/parsing_ops_test.cc b/tensorflow/core/ops/parsing_ops_test.cc
index c65e66d1a8..ba594e400c 100644
--- a/tensorflow/core/ops/parsing_ops_test.cc
+++ b/tensorflow/core/ops/parsing_ops_test.cc
@@ -52,9 +52,12 @@ TEST(ParsingOpsTest, DecodeCSV_ShapeFn) {
INFER_OK(op, "[1,2,?,4];?;?", "in0;in0");
INFER_OK(op, "[1,2,?,4];[?];[?]", "in0;in0");
+ // Scalar defaults are ok
+ INFER_OK(op, "?;?;[]", "in0;in0");
+
// Check errors in the record_defaults inputs.
- INFER_ERROR("must be rank 1", op, "?;?;[]");
- INFER_ERROR("must be rank 1", op, "?;[];?");
+ INFER_ERROR("must be at most rank 1 but is rank 2", op, "?;?;[1,2]");
+ INFER_ERROR("must be at most rank 1 but is rank 2", op, "?;[3,4];?");
INFER_ERROR("Shape of a default must be", op, "?;?;[2]");
INFER_ERROR("Shape of a default must be", op, "?;[2];?");
}
diff --git a/tensorflow/core/platform/abi.cc b/tensorflow/core/platform/abi.cc
index e597a490d6..d7a13a3528 100644
--- a/tensorflow/core/platform/abi.cc
+++ b/tensorflow/core/platform/abi.cc
@@ -37,13 +37,13 @@ extern "C" char* __unDName(char* output_string, const char* name,
namespace tensorflow {
namespace port {
-std::string MaybeAbiDemangle(const char* name) {
+string MaybeAbiDemangle(const char* name) {
#if defined(_MSC_VER)
std::unique_ptr<char> demangled{__unDName(nullptr, name, 0, std::malloc,
std::free,
static_cast<unsigned short>(0))};
- return std::string(demangled.get() != nullptr ? demangled.get() : name);
+ return string(demangled.get() != nullptr ? demangled.get() : name);
#else
int status = 0;
std::unique_ptr<char, void (*)(void*)> res{
diff --git a/tensorflow/core/platform/abi.h b/tensorflow/core/platform/abi.h
index 591e83b0c4..d1498a6a64 100644
--- a/tensorflow/core/platform/abi.h
+++ b/tensorflow/core/platform/abi.h
@@ -17,11 +17,12 @@ limitations under the License.
#define TENSORFLOW_CORE_PLATFORM_ABI_H_
#include <string>
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace port {
-std::string MaybeAbiDemangle(const char* name);
+string MaybeAbiDemangle(const char* name);
} // namespace port
} // namespace tensorflow
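// Editor's usage sketch for the narrowed signature above: MaybeAbiDemangle now
// returns tensorflow::string (an alias for std::string) and is typically fed a
// type_info/TypeIndex name, as in the shape_op_test change earlier in this
// diff. DemangleDemo and PrintDemangledTypeName are illustrative names only.
#include <iostream>
#include <typeinfo>

#include "tensorflow/core/platform/abi.h"

struct DemangleDemo {};

void PrintDemangledTypeName() {
  // On the Itanium ABI this prints "DemangleDemo"; on MSVC, the __unDName
  // result; if demangling fails, the raw mangled name is returned unchanged.
  std::cout << tensorflow::port::MaybeAbiDemangle(typeid(DemangleDemo).name())
            << std::endl;
}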
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index 8f959c018e..83228fab6f 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -371,7 +371,7 @@ class GcsWritableFile : public WritableFile {
~GcsWritableFile() override { Close().IgnoreError(); }
- Status Append(const StringPiece& data) override {
+ Status Append(StringPiece data) override {
TF_RETURN_IF_ERROR(CheckWritable());
sync_needed_ = true;
outfile_ << data;
diff --git a/tensorflow/core/platform/cloud/retrying_file_system.h b/tensorflow/core/platform/cloud/retrying_file_system.h
index 92aa72be89..941ab7ad65 100644
--- a/tensorflow/core/platform/cloud/retrying_file_system.h
+++ b/tensorflow/core/platform/cloud/retrying_file_system.h
@@ -177,7 +177,7 @@ class RetryingWritableFile : public WritableFile {
Close().IgnoreError();
}
- Status Append(const StringPiece& data) override {
+ Status Append(StringPiece data) override {
return RetryingUtils::CallWithRetries(
[this, &data]() { return base_file_->Append(data); },
initial_delay_microseconds_);
diff --git a/tensorflow/core/platform/cloud/retrying_file_system_test.cc b/tensorflow/core/platform/cloud/retrying_file_system_test.cc
index ec2c470db7..5910fef1d2 100644
--- a/tensorflow/core/platform/cloud/retrying_file_system_test.cc
+++ b/tensorflow/core/platform/cloud/retrying_file_system_test.cc
@@ -72,7 +72,7 @@ class MockRandomAccessFile : public RandomAccessFile {
class MockWritableFile : public WritableFile {
public:
explicit MockWritableFile(const ExpectedCalls& calls) : calls_(calls) {}
- Status Append(const StringPiece& data) override {
+ Status Append(StringPiece data) override {
return calls_.ConsumeNextCall("Append");
}
Status Close() override { return calls_.ConsumeNextCall("Close"); }
diff --git a/tensorflow/core/platform/cord.h b/tensorflow/core/platform/cord.h
new file mode 100644
index 0000000000..7c5c6655be
--- /dev/null
+++ b/tensorflow/core/platform/cord.h
@@ -0,0 +1,26 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_CORD_H_
+#define TENSORFLOW_CORE_PLATFORM_CORD_H_
+
+// Include appropriate platform-dependent implementations
+#if defined(PLATFORM_GOOGLE)
+#include "tensorflow/core/platform/google/cord.h"
+#else
+#include "tensorflow/core/platform/default/cord.h"
+#endif
+
+#endif // TENSORFLOW_CORE_PLATFORM_CORD_H_
diff --git a/tensorflow/core/platform/default/cord.h b/tensorflow/core/platform/default/cord.h
new file mode 100644
index 0000000000..1ab682182c
--- /dev/null
+++ b/tensorflow/core/platform/default/cord.h
@@ -0,0 +1,24 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_CORD_H_
+#define TENSORFLOW_CORE_PLATFORM_DEFAULT_CORD_H_
+
+class Cord;
+namespace absl {
+using ::Cord;
+} // namespace absl
+
+#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_CORD_H_
diff --git a/tensorflow/core/platform/env_test.cc b/tensorflow/core/platform/env_test.cc
index 305a9a682f..2e32abdffb 100644
--- a/tensorflow/core/platform/env_test.cc
+++ b/tensorflow/core/platform/env_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/cord.h"
#include "tensorflow/core/platform/null_file_system.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
@@ -345,7 +346,13 @@ TEST_F(DefaultEnvTest, LocalTempFilename) {
// Write something to the temporary file.
std::unique_ptr<WritableFile> file_to_write;
TF_CHECK_OK(env->NewWritableFile(filename, &file_to_write));
+#if defined(PLATFORM_GOOGLE)
+ TF_CHECK_OK(file_to_write->Append("Nu"));
+ TF_CHECK_OK(file_to_write->Append(absl::Cord("ll")));
+#else
+ // TODO(ebrevdo): Remove this version.
TF_CHECK_OK(file_to_write->Append("Null"));
+#endif
TF_CHECK_OK(file_to_write->Close());
TF_CHECK_OK(env->FileExists(filename));
diff --git a/tensorflow/core/platform/file_system.h b/tensorflow/core/platform/file_system.h
index 077b1d79cf..30059dc02e 100644
--- a/tensorflow/core/platform/file_system.h
+++ b/tensorflow/core/platform/file_system.h
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/cord.h"
#include "tensorflow/core/platform/file_statistics.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/platform.h"
@@ -252,7 +253,12 @@ class WritableFile {
virtual ~WritableFile();
/// \brief Append 'data' to the file.
- virtual Status Append(const StringPiece& data) = 0;
+ virtual Status Append(StringPiece data) = 0;
+
+ // \brief Append 'data' to the file.
+ virtual Status Append(const absl::Cord& cord) {
+ return errors::Unimplemented("Append(absl::Cord) is not implemented");
+ }
/// \brief Close the file.
///
diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system.cc b/tensorflow/core/platform/hadoop/hadoop_file_system.cc
index 8cdb08f51b..eb35531e9f 100644
--- a/tensorflow/core/platform/hadoop/hadoop_file_system.cc
+++ b/tensorflow/core/platform/hadoop/hadoop_file_system.cc
@@ -282,7 +282,7 @@ class HDFSWritableFile : public WritableFile {
}
}
- Status Append(const StringPiece& data) override {
+ Status Append(StringPiece data) override {
if (hdfs_->hdfsWrite(fs_, file_, data.data(),
static_cast<tSize>(data.size())) == -1) {
return IOError(filename_, errno);
diff --git a/tensorflow/core/platform/posix/posix_file_system.cc b/tensorflow/core/platform/posix/posix_file_system.cc
index 47bfa020ce..c7afab9583 100644
--- a/tensorflow/core/platform/posix/posix_file_system.cc
+++ b/tensorflow/core/platform/posix/posix_file_system.cc
@@ -91,7 +91,7 @@ class PosixWritableFile : public WritableFile {
}
}
- Status Append(const StringPiece& data) override {
+ Status Append(StringPiece data) override {
size_t r = fwrite(data.data(), 1, data.size(), file_);
if (r != data.size()) {
return IOError(filename_, errno);
diff --git a/tensorflow/core/platform/s3/s3_file_system.cc b/tensorflow/core/platform/s3/s3_file_system.cc
index ce0f6cd741..e0b8e37745 100644
--- a/tensorflow/core/platform/s3/s3_file_system.cc
+++ b/tensorflow/core/platform/s3/s3_file_system.cc
@@ -211,7 +211,7 @@ class S3WritableFile : public WritableFile {
std::ios_base::binary | std::ios_base::trunc | std::ios_base::in |
std::ios_base::out)) {}
- Status Append(const StringPiece& data) override {
+ Status Append(StringPiece data) override {
if (!outfile_) {
return errors::FailedPrecondition(
"The internal temporary file is not writable.");
diff --git a/tensorflow/core/platform/windows/windows_file_system.cc b/tensorflow/core/platform/windows/windows_file_system.cc
index 9079a5ccaa..6cf79634d7 100644
--- a/tensorflow/core/platform/windows/windows_file_system.cc
+++ b/tensorflow/core/platform/windows/windows_file_system.cc
@@ -150,7 +150,7 @@ class WindowsWritableFile : public WritableFile {
}
}
- Status Append(const StringPiece& data) override {
+ Status Append(StringPiece data) override {
DWORD bytes_written = 0;
DWORD data_size = static_cast<DWORD>(data.size());
BOOL write_result =
diff --git a/tensorflow/core/util/sparse/group_iterator.cc b/tensorflow/core/util/sparse/group_iterator.cc
index 204b933051..546b0a833c 100644
--- a/tensorflow/core/util/sparse/group_iterator.cc
+++ b/tensorflow/core/util/sparse/group_iterator.cc
@@ -21,8 +21,8 @@ namespace sparse {
void GroupIterable::IteratorStep::UpdateEndOfGroup() {
++next_loc_;
- int64 N = iter_->ix_.dim_size(0);
- auto ix_t = iter_->ix_.template matrix<int64>();
+ const auto& ix_t = iter_->ix_matrix_;
+ const int64 N = ix_t.dimension(0);
while (next_loc_ < N && iter_->GroupMatches(ix_t, loc_, next_loc_)) {
++next_loc_;
}
@@ -54,7 +54,7 @@ GroupIterable::IteratorStep GroupIterable::IteratorStep::operator++(
std::vector<int64> Group::group() const {
std::vector<int64> g;
- auto ix_t = iter_->ix_.template matrix<int64>();
+ const auto& ix_t = iter_->ix_matrix_;
for (const int d : iter_->group_dims_) {
g.push_back(ix_t(loc_, d));
}
@@ -62,8 +62,8 @@ std::vector<int64> Group::group() const {
}
TTypes<int64>::UnalignedConstMatrix Group::indices() const {
- return TTypes<int64>::UnalignedConstMatrix(
- &(iter_->ix_.matrix<int64>()(loc_, 0)), next_loc_ - loc_, iter_->dims_);
+ return TTypes<int64>::UnalignedConstMatrix(&(iter_->ix_matrix_(loc_, 0)),
+ next_loc_ - loc_, iter_->dims_);
}
} // namespace sparse
diff --git a/tensorflow/core/util/sparse/group_iterator.h b/tensorflow/core/util/sparse/group_iterator.h
index 3fa8cb6116..14610c61d9 100644
--- a/tensorflow/core/util/sparse/group_iterator.h
+++ b/tensorflow/core/util/sparse/group_iterator.h
@@ -79,6 +79,7 @@ class GroupIterable {
GroupIterable(Tensor ix, Tensor vals, int dims, const VarDimArray& group_dims)
: ix_(ix),
+ ix_matrix_(ix_.matrix<int64>()),
vals_(vals),
dims_(dims),
group_dims_(group_dims.begin(), group_dims.end()) {}
@@ -127,7 +128,8 @@ class GroupIterable {
private:
friend class Group;
- Tensor ix_;
+ const Tensor ix_;
+ const TTypes<int64>::ConstMatrix ix_matrix_;
Tensor vals_;
const int dims_;
const gtl::InlinedVector<int64, 8> group_dims_;
diff --git a/tensorflow/examples/android/README.md b/tensorflow/examples/android/README.md
index dac9b7ab82..82bc3ffda9 100644
--- a/tensorflow/examples/android/README.md
+++ b/tensorflow/examples/android/README.md
@@ -121,10 +121,6 @@ the Android NDK and SDK must be installed on your system.
2. The Android NDK is required to build the native (C/C++) TensorFlow code. The
current recommended version is 14b, which may be found
[here](https://developer.android.com/ndk/downloads/older_releases.html#ndk-14b-downloads).
-
- * NDK 16, the revision released in November 2017, is **incompatible** with
- Bazel. See [here](https://github.com/tensorflow/tensorflow/issues/14918).
-
3. The Android SDK and build tools may be obtained
[here](https://developer.android.com/tools/revisions/build-tools.html), or
alternatively as part of [Android
@@ -132,10 +128,6 @@ the Android NDK and SDK must be installed on your system.
23 is required to build the TF Android demo (though it will run on API >= 21
devices).
- - The Android Studio SDK Manager's NDK installer will install the latest
- revision of the NDK, which is **incompatible** with Bazel. You'll need
- to download an older version manually, as (2) suggests.
-
##### Edit WORKSPACE
NOTE: As long as you have the SDK and NDK installed, the `./configure` script
diff --git a/tensorflow/contrib/autograph/examples/integration_tests/BUILD b/tensorflow/examples/autograph/integration_tests/BUILD
index 3630b41fc8..3630b41fc8 100644
--- a/tensorflow/contrib/autograph/examples/integration_tests/BUILD
+++ b/tensorflow/examples/autograph/integration_tests/BUILD
diff --git a/tensorflow/contrib/autograph/examples/integration_tests/errors_test.py b/tensorflow/examples/autograph/integration_tests/errors_test.py
index 04a968be10..69e5936832 100644
--- a/tensorflow/contrib/autograph/examples/integration_tests/errors_test.py
+++ b/tensorflow/examples/autograph/integration_tests/errors_test.py
@@ -20,21 +20,18 @@ from __future__ import print_function
import tensorflow as tf
-from tensorflow.contrib import autograph as ag
-from tensorflow.python.util import tf_inspect
+from tensorflow.python import autograph as ag
class ErrorsTest(tf.test.TestCase):
def test_graph_construction_error_rewriting_call_tree(self):
- def innermost(x):
- if x > 0:
- return tf.random_normal((2, 3), mean=0.0, dtype=tf.int32)
- return tf.zeros((2, 3))
+ def test_fn():
+ return tf.random_normal((2, 3), mean=0.0, dtype=tf.int32)
def inner_caller():
- return innermost(1.0)
+ return test_fn()
def caller():
return inner_caller()
@@ -45,23 +42,21 @@ class ErrorsTest(tf.test.TestCase):
expected = error.exception
custom_traceback = expected.custom_traceback
found_correct_filename = False
- num_innermost_names = 0
+ num_test_fn_names = 0
num_inner_caller_names = 0
num_caller_names = 0
- ag_output_filename = tf_inspect.getsourcefile(graph)
for frame in custom_traceback:
filename, _, fn_name, _ = frame
- self.assertFalse('control_flow_ops.py' in filename)
- self.assertFalse(ag_output_filename in filename)
+ self.assertFalse('/tmp/' in filename)
found_correct_filename |= __file__ in filename
self.assertNotEqual('tf__test_fn', fn_name)
- num_innermost_names += int('innermost' == fn_name)
+ num_test_fn_names += int('test_fn' == fn_name)
self.assertNotEqual('tf__inner_caller', fn_name)
num_inner_caller_names += int('inner_caller' == fn_name)
self.assertNotEqual('tf__caller', fn_name)
num_caller_names += int('caller' == fn_name)
self.assertTrue(found_correct_filename)
- self.assertEqual(num_innermost_names, 1)
+ self.assertEqual(num_test_fn_names, 1)
self.assertEqual(num_inner_caller_names, 1)
self.assertEqual(num_caller_names, 1)
@@ -97,7 +92,7 @@ class ErrorsTest(tf.test.TestCase):
compiled_fn = ag.to_graph(test_fn)
with self.assertRaises(ag.TfRuntimeError) as error:
- with self.cached_session() as sess:
+ with self.test_session() as sess:
x = compiled_fn(tf.constant([4, 8]))
with ag.improved_errors(compiled_fn):
sess.run(x)
@@ -106,19 +101,14 @@ class ErrorsTest(tf.test.TestCase):
found_correct_filename = False
num_test_fn_frames = 0
num_g_frames = 0
- ag_output_filename = tf_inspect.getsourcefile(compiled_fn)
for frame in custom_traceback:
filename, _, fn_name, source_code = frame
- self.assertFalse(ag_output_filename in filename)
- self.assertFalse('control_flow_ops.py' in filename)
+ self.assertFalse('/tmp/' in filename)
+ self.assertFalse('control_flow.py' in filename)
self.assertFalse('ag__.' in fn_name)
- self.assertFalse('tf__g' in fn_name)
- self.assertFalse('tf__test_fn' in fn_name)
found_correct_filename |= __file__ in filename
num_test_fn_frames += int('test_fn' == fn_name and
'return g(x, 10)' in source_code)
- # This makes sure that the code is correctly rewritten from "x_1 //= 0" to
- # "x //= 0".
num_g_frames += int('g' == fn_name and 'x //= 0' in source_code)
self.assertTrue(found_correct_filename)
self.assertEqual(num_test_fn_frames, 1)
@@ -144,7 +134,7 @@ class ErrorsTest(tf.test.TestCase):
# frame with "g" as the function name but because we don't yet add
# try/except blocks to inner functions the name is "tf__g".
with self.assertRaises(ag.TfRuntimeError) as error:
- with self.cached_session() as sess:
+ with self.test_session() as sess:
x = compiled_fn(tf.constant([4, 8]))
with ag.improved_errors(compiled_fn):
sess.run(x)
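The rewritten test above exercises `ag.improved_errors` through the new `tensorflow.python.autograph` import path. A rough sketch of the same pattern outside the test harness, with a hypothetical function whose integer division fails at run time:

```
import tensorflow as tf
from tensorflow.python import autograph as ag

def f(x):
  return x // 0  # fails at run time with an integer-division-by-zero error

compiled_f = ag.to_graph(f)

with tf.Graph().as_default(), tf.Session() as sess:
  y = compiled_f(tf.constant([4, 8]))
  try:
    # improved_errors rewrites the traceback of errors raised inside run()
    # so frames point at f's original source rather than generated code.
    with ag.improved_errors(compiled_f):
      sess.run(y)
  except ag.TfRuntimeError as e:
    print(e.custom_traceback)
```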
diff --git a/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py b/tensorflow/examples/autograph/integration_tests/keras_test.py
index 7e7ef5a3e2..dca7c07b47 100644
--- a/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py
+++ b/tensorflow/examples/autograph/integration_tests/keras_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import tensorflow as tf
-from tensorflow.contrib import autograph
+from tensorflow.python import autograph
class MinimalKeras(tf.keras.Model):
diff --git a/tensorflow/contrib/autograph/examples/integration_tests/list_literals_test.py b/tensorflow/examples/autograph/integration_tests/list_literals_test.py
index 904246afb7..917f5ff9d8 100644
--- a/tensorflow/contrib/autograph/examples/integration_tests/list_literals_test.py
+++ b/tensorflow/examples/autograph/integration_tests/list_literals_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import tensorflow as tf
-from tensorflow.contrib import autograph as ag
+from tensorflow.python import autograph as ag
def list_used_as_tuple():
diff --git a/tensorflow/examples/speech_commands/freeze_test.py b/tensorflow/examples/speech_commands/freeze_test.py
index c8de6c2152..0c7ca9bc01 100644
--- a/tensorflow/examples/speech_commands/freeze_test.py
+++ b/tensorflow/examples/speech_commands/freeze_test.py
@@ -25,7 +25,7 @@ from tensorflow.python.platform import test
class FreezeTest(test.TestCase):
def testCreateInferenceGraphWithMfcc(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
freeze.create_inference_graph(
wanted_words='a,b,c,d',
sample_rate=16000,
@@ -44,7 +44,7 @@ class FreezeTest(test.TestCase):
self.assertEqual(1, ops.count('Mfcc'))
def testCreateInferenceGraphWithoutMfcc(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
freeze.create_inference_graph(
wanted_words='a,b,c,d',
sample_rate=16000,
@@ -63,7 +63,7 @@ class FreezeTest(test.TestCase):
self.assertEqual(0, ops.count('Mfcc'))
def testFeatureBinCount(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
freeze.create_inference_graph(
wanted_words='a,b,c,d',
sample_rate=16000,
diff --git a/tensorflow/examples/speech_commands/input_data_test.py b/tensorflow/examples/speech_commands/input_data_test.py
index 2e551be9a2..aa4e807779 100644
--- a/tensorflow/examples/speech_commands/input_data_test.py
+++ b/tensorflow/examples/speech_commands/input_data_test.py
@@ -32,7 +32,7 @@ from tensorflow.python.platform import test
class InputDataTest(test.TestCase):
def _getWavData(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sample_data = tf.zeros([32000, 2])
wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
wav_data = sess.run(wav_encoder)
@@ -75,7 +75,7 @@ class InputDataTest(test.TestCase):
self._saveTestWavFile(file_path, wav_data)
model_settings = models.prepare_model_settings(
4, 16000, 1000, window_length_ms, 20, 40, preprocess)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
audio_processor = input_data.AudioProcessor(
"", wav_dir, 10, 10, ["a", "b"], 10, 10, model_settings, tmp_dir)
result_data, result_labels = audio_processor.get_data(
diff --git a/tensorflow/examples/speech_commands/label_wav_test.py b/tensorflow/examples/speech_commands/label_wav_test.py
index 80ca774706..f0af2a4798 100644
--- a/tensorflow/examples/speech_commands/label_wav_test.py
+++ b/tensorflow/examples/speech_commands/label_wav_test.py
@@ -30,7 +30,7 @@ from tensorflow.python.platform import test
class LabelWavTest(test.TestCase):
def _getWavData(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sample_data = tf.zeros([1000, 2])
wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
wav_data = sess.run(wav_encoder)
diff --git a/tensorflow/examples/speech_commands/models_test.py b/tensorflow/examples/speech_commands/models_test.py
index 0c373967ed..04478c0962 100644
--- a/tensorflow/examples/speech_commands/models_test.py
+++ b/tensorflow/examples/speech_commands/models_test.py
@@ -49,7 +49,7 @@ class ModelsTest(test.TestCase):
def testCreateModelConvTraining(self):
model_settings = self._modelSettings()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
logits, dropout_prob = models.create_model(fingerprint_input,
model_settings, "conv", True)
@@ -60,7 +60,7 @@ class ModelsTest(test.TestCase):
def testCreateModelConvInference(self):
model_settings = self._modelSettings()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
logits = models.create_model(fingerprint_input, model_settings, "conv",
False)
@@ -69,7 +69,7 @@ class ModelsTest(test.TestCase):
def testCreateModelLowLatencyConvTraining(self):
model_settings = self._modelSettings()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
logits, dropout_prob = models.create_model(
fingerprint_input, model_settings, "low_latency_conv", True)
@@ -80,7 +80,7 @@ class ModelsTest(test.TestCase):
def testCreateModelFullyConnectedTraining(self):
model_settings = self._modelSettings()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
logits, dropout_prob = models.create_model(
fingerprint_input, model_settings, "single_fc", True)
@@ -91,7 +91,7 @@ class ModelsTest(test.TestCase):
def testCreateModelBadArchitecture(self):
model_settings = self._modelSettings()
- with self.test_session():
+ with self.cached_session():
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
with self.assertRaises(Exception) as e:
models.create_model(fingerprint_input, model_settings,
@@ -100,7 +100,7 @@ class ModelsTest(test.TestCase):
def testCreateModelTinyConvTraining(self):
model_settings = self._modelSettings()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
logits, dropout_prob = models.create_model(
fingerprint_input, model_settings, "tiny_conv", True)
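The `test_session` → `cached_session` swaps above all follow the same idiom; a minimal, hypothetical test showing it in isolation:

```
import tensorflow as tf

class ExampleTest(tf.test.TestCase):

  def testZeros(self):
    # cached_session() reuses one session for the duration of the test
    # method, replacing the deprecated test_session() pattern.
    with self.cached_session() as sess:
      x = tf.zeros([1, 4])
      self.assertAllEqual(sess.run(x), [[0.0, 0.0, 0.0, 0.0]])

if __name__ == "__main__":
  tf.test.main()
```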
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index e755c37039..322b35dd91 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -3456,6 +3456,36 @@ func BoostedTreesSerializeEnsemble(scope *Scope, tree_ensemble_handle tf.Output)
return op.Output(0), op.Output(1)
}
+// Debugging/model interpretability outputs for each example.
+//
+// It traverses all the trees and computes debug metrics for individual examples,
+// such as getting split feature ids and logits after each split along the decision
+// path used to compute directional feature contributions.
+//
+// Arguments:
+//
+// bucketized_features: A list of rank 1 Tensors containing bucket id for each
+// feature.
+// logits_dimension: scalar, dimension of the logits, to be used for constructing the protos in
+// examples_debug_outputs_serialized.
+//
+// Returns Output rank 1 Tensor containing a proto serialized as a string for each example.
+func BoostedTreesExampleDebugOutputs(scope *Scope, tree_ensemble_handle tf.Output, bucketized_features []tf.Output, logits_dimension int64) (examples_debug_outputs_serialized tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"logits_dimension": logits_dimension}
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesExampleDebugOutputs",
+ Input: []tf.Input{
+ tree_ensemble_handle, tf.OutputList(bucketized_features),
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes the sum along sparse segments of a tensor.
//
// Like `SparseSegmentSum`, but allows missing ids in `segment_ids`. If an id is
@@ -13892,34 +13922,6 @@ func SparseSoftmaxCrossEntropyWithLogits(scope *Scope, features tf.Output, label
return op.Output(0), op.Output(1)
}
-// Fast Fourier transform.
-//
-// Computes the 1-dimensional discrete Fourier transform over the inner-most
-// dimension of `input`.
-//
-// Arguments:
-// input: A complex64 tensor.
-//
-// Returns A complex64 tensor of the same shape as `input`. The inner-most
-// dimension of `input` is replaced with its 1D Fourier transform.
-//
-// @compatibility(numpy)
-// Equivalent to np.fft.fft
-// @end_compatibility
-func FFT(scope *Scope, input tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "FFT",
- Input: []tf.Input{
- input,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Transforms a serialized tensorflow.TensorProto proto into a Tensor.
//
// Arguments:
@@ -26636,36 +26638,6 @@ func ConcatenateDataset(scope *Scope, input_dataset tf.Output, another_dataset t
return op.Output(0)
}
-// Debugging/model interpretability outputs for each example.
-//
-// It traverses all the trees and computes debug metrics for individual examples,
-// such as getting split feature ids and logits after each split along the decision
-// path used to compute directional feature contributions.
-//
-// Arguments:
-//
-// bucketized_features: A list of rank 1 Tensors containing bucket id for each
-// feature.
-// logits_dimension: scalar, dimension of the logits, to be used for constructing the protos in
-// examples_debug_outputs_serialized.
-//
-// Returns Output rank 1 Tensor containing a proto serialized as a string for each example.
-func BoostedTreesExampleDebugOutputs(scope *Scope, tree_ensemble_handle tf.Output, bucketized_features []tf.Output, logits_dimension int64) (examples_debug_outputs_serialized tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"logits_dimension": logits_dimension}
- opspec := tf.OpSpec{
- Type: "BoostedTreesExampleDebugOutputs",
- Input: []tf.Input{
- tree_ensemble_handle, tf.OutputList(bucketized_features),
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Adds a value to the current value of a variable.
//
// Any ReadVariableOp with a control dependency on this op is guaranteed to
@@ -28153,6 +28125,34 @@ func IteratorGetNextAsOptional(scope *Scope, iterator tf.Output, output_types []
return op.Output(0)
}
+// Fast Fourier transform.
+//
+// Computes the 1-dimensional discrete Fourier transform over the inner-most
+// dimension of `input`.
+//
+// Arguments:
+// input: A complex64 tensor.
+//
+// Returns A complex64 tensor of the same shape as `input`. The inner-most
+// dimension of `input` is replaced with its 1D Fourier transform.
+//
+// @compatibility(numpy)
+// Equivalent to np.fft.fft
+// @end_compatibility
+func FFT(scope *Scope, input tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "FFT",
+ Input: []tf.Input{
+ input,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Performs a padding as a preprocess during a convolution.
//
// Similar to FusedResizeAndPadConv2d, this op allows for an optimized
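The relocated `FFT` wrapper's comment above claims equivalence with `np.fft.fft`; a quick Python cross-check of that claim (arbitrary input, assuming the TF 1.x `tf.fft` alias):

```
import numpy as np
import tensorflow as tf

x = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.complex64)

with tf.Session() as sess:
  tf_fft = sess.run(tf.fft(tf.constant(x)))

# The kernel computes in complex64, so compare with a loose tolerance.
np.testing.assert_allclose(tf_fft, np.fft.fft(x), rtol=1e-4)
```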
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 19729813a1..2dc2808152 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3058,6 +3058,7 @@ cuda_py_test(
":functional_ops",
":gradients",
":layers",
+ ":list_ops",
":math_grad",
":math_ops",
":nn_grad",
diff --git a/tensorflow/python/autograph/BUILD b/tensorflow/python/autograph/BUILD
new file mode 100644
index 0000000000..3289b447e7
--- /dev/null
+++ b/tensorflow/python/autograph/BUILD
@@ -0,0 +1,31 @@
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+py_library(
+ name = "autograph",
+ srcs = [
+ "__init__.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/python:util",
+ "//tensorflow/python/autograph/impl",
+ "//tensorflow/python/autograph/lang",
+ "//tensorflow/python/autograph/pyct",
+ "//tensorflow/python/autograph/utils",
+ ],
+)
diff --git a/tensorflow/contrib/autograph/CONTRIBUTING.md b/tensorflow/python/autograph/CONTRIBUTING.md
index 06fb7b03d5..1ded5ba5f6 100644
--- a/tensorflow/contrib/autograph/CONTRIBUTING.md
+++ b/tensorflow/python/autograph/CONTRIBUTING.md
@@ -2,6 +2,15 @@
We'd love to have your patches and contributions! Here are some guidelines. In general, we follow the [TensorFlow contributing guidelines](../../CONTRIBUTING.md), but have some [AutoGraph-specific style guidelines](STYLE_GUIDE.md). More details below.
+### Note to active contributors
+
+In preparation for TF 2.0, we moved the code base of AutoGraph from
+`tensorflow/contrib/autograph` to `tensorflow/python/autograph`. The move
+does not impact functionality, and AutoGraph will remain accessible under
+`tensorflow.contrib.autograph` until `tensorflow.contrib` is retired.
+
+When
+
## TensorFlow Code of Conduct
Please review and follow the [TensorFlow Code of Conduct](../../CODE_OF_CONDUCT.md).
diff --git a/tensorflow/contrib/autograph/LIMITATIONS.md b/tensorflow/python/autograph/LIMITATIONS.md
index d8b1cb7616..d8b1cb7616 100644
--- a/tensorflow/contrib/autograph/LIMITATIONS.md
+++ b/tensorflow/python/autograph/LIMITATIONS.md
diff --git a/tensorflow/python/autograph/README.md b/tensorflow/python/autograph/README.md
new file mode 100644
index 0000000000..cc54da4daa
--- /dev/null
+++ b/tensorflow/python/autograph/README.md
@@ -0,0 +1,143 @@
+# AutoGraph
+
+IMPORTANT: AutoGraph is beta software, and under active development. Expect rough edges and bugs, but if you try it, we appreciate early feedback! We'd also love contributions ([please see our contributing guidelines](CONTRIBUTING.md) and our [style guide](STYLE_GUIDE.md)).
+
+AutoGraph is a Python to TensorFlow compiler.
+
+With AutoGraph, you can write [Eager style](https://www.tensorflow.org/guide/eager) code concisely and run it as a TensorFlow graph. AutoGraph uses source code transformation and partial evaluation to generate Python code that builds an equivalent TensorFlow subgraph. The result is code that behaves like ops and can be freely combined with other TensorFlow ops. See [LIMITATIONS.md](LIMITATIONS.md) for the parts of the Python language we currently support.
+
+For example, this Python function:
+
+```
+def f(x):
+ if x < 0:
+ x = -x
+ return x
+```
+
+would be converted to this:
+
+```
+def graph_mode_f(x):
+ with tf.name_scope('f'):
+
+ def if_true():
+ with tf.name_scope('if_true'):
+ x_1, = x,
+ x_1 = tf.negative(x_1)
+ return x_1,
+
+ def if_false():
+ with tf.name_scope('if_false'):
+ x_1, = x,
+ return x_1,
+ x = ag__.utils.run_cond(tf.less(x, 0), if_true, if_false)
+ return x
+```
+
+so you can use it like an op:
+
+```
+with tf.Graph().as_default():
+ x = tf.constant(-1.0)
+
+ converted_f = autograph.to_graph(f)
+ y = converted_f(x)
+
+ with tf.Session() as sess:
+ print(sess.run(y))
+ # Output: 1
+```
+
+# Getting started
+
+Use AutoGraph in one of the following ways, described below:
+
+ 1. Annotations (simpler)
+ 2. Functional API (more flexible)
+
+To get started, install the latest nightly TensorFlow build:
+
+```shell
+pip install -U tf-nightly
+```
+
+Then import the `autograph` module from `tf.contrib`:
+
+```
+from tensorflow.contrib import autograph as ag
+```
+
+### Related links
+
+Articles:
+
+ * [TensorFlow blog post](https://medium.com/tensorflow/autograph-converts-python-into-tensorflow-graphs-b2a871f87ec7)
+
+Interactive notebooks:
+
+ * [Quick guide](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/guide/autograph.ipynb)
+ * [RNN trained using Keras and Estimators](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb)
+ * [Demo from the TF Dev Summit 2018](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb)
+ * [Basic control flow speed test](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/ag_vs_eager_collatz_speed_test.ipynb)
+ * [MNIST training speed test](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/ag_vs_eager_mnist_speed_test.ipynb)
+ * [Basic algorithm samples](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/algorithms.ipynb)
+ * [Introductory workshop support notebook](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/workshop.ipynb)
+
+## Using with annotations
+
+Annotating a function or class with `@convert` converts it in place:
+
+```
+@ag.convert()
+def f(x):
+ if x < 0:
+ x = -x
+ return x
+```
+
+... so that it always outputs TensorFlow code:
+
+```
+with tf.Graph().as_default():
+ x = tf.constant(-1)
+
+ y = f(x)
+
+ with tf.Session() as sess:
+ print(sess.run(y))
+ # Output: 1
+```
+
+## Using the functional API
+
+The functional API allows you to convert an existing function, class or object after it was defined:
+
+```
+converted_f = ag.to_graph(f)
+
+print(converted_f(tf.constant(-1)))
+# Output: Tensor
+
+print(f(-1))
+# Output: 1
+```
+
+You can use the functional API to inspect the generated code as well:
+
+```
+print(ag.to_code(f))
+# Output: <Python and TensorFlow code>
+```
+
+## Filing bugs and feature requests
+
+### Reporting a bug
+
+ - If AutoGraph-generated code is compiling and running, but producing an incorrect result, send us a minimal reproduction case that includes the original Eager code, the inputs and if possible, the outputs or the error message.
+ - If AutoGraph-generated code is compiling, but not running, send us a minimal reproduction case that includes the original Eager code, the inputs and if possible, the outputs or the error message.
+ - If AutoGraph-generated code is not compiling, send us two minimal pieces of code. First, the Eager code that you would like to write, and second, the Graph code that you would like AutoGraph to have generated for you.
+
+### Requesting a feature
+
+If you’d like AutoGraph to convert a feature of Python or TF that we currently don’t handle, please let us know by filing a bug. We’ll make it as easy as possible to interact with us there.
diff --git a/tensorflow/contrib/autograph/STYLE_GUIDE.md b/tensorflow/python/autograph/STYLE_GUIDE.md
index 7e6b0cc27d..7e6b0cc27d 100644
--- a/tensorflow/contrib/autograph/STYLE_GUIDE.md
+++ b/tensorflow/python/autograph/STYLE_GUIDE.md
diff --git a/tensorflow/python/autograph/__init__.py b/tensorflow/python/autograph/__init__.py
new file mode 100644
index 0000000000..c3448e6e58
--- /dev/null
+++ b/tensorflow/python/autograph/__init__.py
@@ -0,0 +1,68 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Autograph compiles Python code into equivalent TensorFlow code.
+
+Equivalent here means that they have the same effect when executed.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# TODO(mdan): Bring only the relevant symbols to the top level.
+from tensorflow.python.autograph import operators
+from tensorflow.python.autograph import utils
+from tensorflow.python.autograph.core.errors import GraphConstructionError
+from tensorflow.python.autograph.core.errors import TfRuntimeError
+from tensorflow.python.autograph.core.errors import improved_errors
+from tensorflow.python.autograph.impl.api import RunMode
+from tensorflow.python.autograph.impl.api import convert
+from tensorflow.python.autograph.impl.api import converted_call
+from tensorflow.python.autograph.impl.api import do_not_convert
+from tensorflow.python.autograph.impl.api import to_code
+from tensorflow.python.autograph.impl.api import to_graph
+from tensorflow.python.autograph.lang.directives import set_element_type
+from tensorflow.python.autograph.lang.directives import set_loop_options
+from tensorflow.python.autograph.lang.special_functions import stack
+from tensorflow.python.autograph.lang.special_functions import tensor_list
+from tensorflow.python.autograph.pyct.transformer import AutographParseError
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = [
+ # Main API
+ 'RunMode',
+ 'convert',
+ 'converted_call',
+ 'do_not_convert',
+ 'to_code',
+ 'to_graph',
+ # Overloaded operators
+ 'operators',
+ # Errors
+ 'improved_errors',
+ 'GraphConstructionError',
+ 'TfRuntimeError',
+ # Python language "extensions"
+ 'set_element_type',
+ 'set_loop_options',
+ 'stack',
+ 'tensor_list',
+ # Exceptions
+ 'AutographParseError',
+ # Utilities: to be removed
+ 'utils',
+]
+
+remove_undocumented(__name__, _allowed_symbols)
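With the symbols re-exported above, user code can target the new package location directly. A small sanity sketch of the moved import path, using a hypothetical function (not part of this change):

```
import tensorflow as tf
from tensorflow.python import autograph as ag

def square_if_positive(x):
  if x > 0:
    x = x * x
  return x

converted = ag.to_graph(square_if_positive)

with tf.Graph().as_default(), tf.Session() as sess:
  print(sess.run(converted(tf.constant(3))))   # 9
  print(sess.run(converted(tf.constant(-3))))  # -3
```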
diff --git a/tensorflow/contrib/autograph/converters/BUILD b/tensorflow/python/autograph/converters/BUILD
index 2d2ab7040a..7b029de8ed 100644
--- a/tensorflow/contrib/autograph/converters/BUILD
+++ b/tensorflow/python/autograph/converters/BUILD
@@ -38,11 +38,11 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
deps = [
- "//tensorflow/contrib/autograph/core",
- "//tensorflow/contrib/autograph/lang",
- "//tensorflow/contrib/autograph/pyct",
- "//tensorflow/contrib/autograph/pyct/static_analysis",
"//tensorflow/python:util",
+ "//tensorflow/python/autograph/core",
+ "//tensorflow/python/autograph/lang",
+ "//tensorflow/python/autograph/pyct",
+ "//tensorflow/python/autograph/pyct/static_analysis",
"@gast_archive//:gast",
],
)
@@ -54,8 +54,8 @@ py_test(
tags = ["no_windows"],
deps = [
":converters",
- "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/core:test_lib",
],
)
@@ -65,8 +65,8 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":converters",
- "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/core:test_lib",
],
)
@@ -77,8 +77,8 @@ py_test(
tags = ["no_windows"],
deps = [
":converters",
- "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/core:test_lib",
],
)
@@ -90,9 +90,9 @@ py_test(
tags = ["no_windows"],
deps = [
":converters",
- "//tensorflow/contrib/autograph/core:test_lib",
- "//tensorflow/contrib/autograph/impl",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/core:test_lib",
+ "//tensorflow/python/autograph/impl",
],
)
@@ -102,8 +102,8 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":converters",
- "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/core:test_lib",
],
)
@@ -113,8 +113,8 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":converters",
- "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/core:test_lib",
],
)
@@ -124,8 +124,8 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":converters",
- "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/core:test_lib",
],
)
@@ -139,8 +139,8 @@ py_test(
],
deps = [
":converters",
- "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/core:test_lib",
],
)
@@ -150,9 +150,9 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":converters",
- "//tensorflow/contrib/autograph/core:test_lib",
- "//tensorflow/contrib/autograph/lang",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/core:test_lib",
+ "//tensorflow/python/autograph/lang",
],
)
@@ -161,9 +161,9 @@ py_test(
srcs = ["name_scopes_test.py"],
deps = [
":converters",
- "//tensorflow/contrib/autograph/core:test_lib",
- "//tensorflow/contrib/autograph/pyct",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/core:test_lib",
+ "//tensorflow/python/autograph/pyct",
],
)
@@ -173,8 +173,8 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":converters",
- "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/core:test_lib",
],
)
@@ -184,8 +184,8 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":converters",
- "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/core:test_lib",
],
)
@@ -195,8 +195,8 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":converters",
- "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/core:test_lib",
],
)
@@ -207,8 +207,8 @@ py_test(
tags = ["notsan"],
deps = [
":converters",
- "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/core:test_lib",
],
)
@@ -218,9 +218,9 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":converters",
- "//tensorflow/contrib/autograph/core:test_lib",
- "//tensorflow/contrib/autograph/pyct",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/core:test_lib",
+ "//tensorflow/python/autograph/pyct",
],
)
@@ -230,9 +230,9 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":converters",
- "//tensorflow/contrib/autograph/core:test_lib",
- "//tensorflow/contrib/autograph/pyct",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/core:test_lib",
+ "//tensorflow/python/autograph/pyct",
],
)
@@ -242,8 +242,8 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":converters",
- "//tensorflow/contrib/autograph/core:test_lib",
- "//tensorflow/contrib/autograph/pyct",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/core:test_lib",
+ "//tensorflow/python/autograph/pyct",
],
)
diff --git a/tensorflow/contrib/autograph/converters/__init__.py b/tensorflow/python/autograph/converters/__init__.py
index 6325ac78dc..6325ac78dc 100644
--- a/tensorflow/contrib/autograph/converters/__init__.py
+++ b/tensorflow/python/autograph/converters/__init__.py
diff --git a/tensorflow/contrib/autograph/converters/asserts.py b/tensorflow/python/autograph/converters/asserts.py
index af2f20f267..56a97534c4 100644
--- a/tensorflow/contrib/autograph/converters/asserts.py
+++ b/tensorflow/python/autograph/converters/asserts.py
@@ -20,8 +20,8 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import templates
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import templates
class AssertTransformer(converter.Base):
diff --git a/tensorflow/contrib/autograph/converters/asserts_test.py b/tensorflow/python/autograph/converters/asserts_test.py
index 38faba45df..01282f9e62 100644
--- a/tensorflow/contrib/autograph/converters/asserts_test.py
+++ b/tensorflow/python/autograph/converters/asserts_test.py
@@ -20,8 +20,8 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.converters import asserts
-from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.autograph.converters import asserts
+from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/converters/break_statements.py b/tensorflow/python/autograph/converters/break_statements.py
index 180779670d..bd6b0b248c 100644
--- a/tensorflow/contrib/autograph/converters/break_statements.py
+++ b/tensorflow/python/autograph/converters/break_statements.py
@@ -18,10 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
class _Break(object):
diff --git a/tensorflow/contrib/autograph/converters/break_statements_test.py b/tensorflow/python/autograph/converters/break_statements_test.py
index fcae7d68c0..39406a969d 100644
--- a/tensorflow/contrib/autograph/converters/break_statements_test.py
+++ b/tensorflow/python/autograph/converters/break_statements_test.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.converters import break_statements
-from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.autograph.converters import break_statements
+from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.eager import context as tfe_ctx
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/converters/builtin_functions.py b/tensorflow/python/autograph/converters/builtin_functions.py
index 29dce13999..b8b268d8ce 100644
--- a/tensorflow/contrib/autograph/converters/builtin_functions.py
+++ b/tensorflow/python/autograph/converters/builtin_functions.py
@@ -20,10 +20,10 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.operators import py_builtins
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import templates
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.operators import py_builtins
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import templates
class BuiltinFunctionTransformer(converter.Base):
diff --git a/tensorflow/contrib/autograph/converters/builtin_functions_test.py b/tensorflow/python/autograph/converters/builtin_functions_test.py
index 3e3a04f38b..c87c304cdb 100644
--- a/tensorflow/contrib/autograph/converters/builtin_functions_test.py
+++ b/tensorflow/python/autograph/converters/builtin_functions_test.py
@@ -20,8 +20,8 @@ from __future__ import print_function
import six
-from tensorflow.contrib.autograph.converters import builtin_functions
-from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.autograph.converters import builtin_functions
+from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
diff --git a/tensorflow/contrib/autograph/converters/call_trees.py b/tensorflow/python/autograph/converters/call_trees.py
index 2d1bed3367..6a606c450d 100644
--- a/tensorflow/contrib/autograph/converters/call_trees.py
+++ b/tensorflow/python/autograph/converters/call_trees.py
@@ -26,12 +26,12 @@ from collections import namedtuple
import gast
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import ast_util
-from tensorflow.contrib.autograph.pyct import inspect_utils
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import templates
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import inspect_utils
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import templates
from tensorflow.python.util import tf_inspect
diff --git a/tensorflow/contrib/autograph/converters/call_trees_test.py b/tensorflow/python/autograph/converters/call_trees_test.py
index ca4d1f2932..0e50f42c6a 100644
--- a/tensorflow/contrib/autograph/converters/call_trees_test.py
+++ b/tensorflow/python/autograph/converters/call_trees_test.py
@@ -20,8 +20,8 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.autograph.converters import call_trees
-from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.autograph.converters import call_trees
+from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
diff --git a/tensorflow/contrib/autograph/converters/conditional_expressions.py b/tensorflow/python/autograph/converters/conditional_expressions.py
index 63f649dfdf..40728f555d 100644
--- a/tensorflow/contrib/autograph/converters/conditional_expressions.py
+++ b/tensorflow/python/autograph/converters/conditional_expressions.py
@@ -18,10 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
class _FunctionDefs(object):
diff --git a/tensorflow/contrib/autograph/converters/conditional_expressions_test.py b/tensorflow/python/autograph/converters/conditional_expressions_test.py
index 95a3108741..dd1f8d485c 100644
--- a/tensorflow/contrib/autograph/converters/conditional_expressions_test.py
+++ b/tensorflow/python/autograph/converters/conditional_expressions_test.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.converters import conditional_expressions
-from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.autograph.converters import conditional_expressions
+from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/converters/continue_statements.py b/tensorflow/python/autograph/converters/continue_statements.py
index 0476e97c15..584cdc1efd 100644
--- a/tensorflow/contrib/autograph/converters/continue_statements.py
+++ b/tensorflow/python/autograph/converters/continue_statements.py
@@ -18,10 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
# Tags for local state.
diff --git a/tensorflow/contrib/autograph/converters/continue_statements_test.py b/tensorflow/python/autograph/converters/continue_statements_test.py
index 37c15211b4..d6aaa50443 100644
--- a/tensorflow/contrib/autograph/converters/continue_statements_test.py
+++ b/tensorflow/python/autograph/converters/continue_statements_test.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.converters import continue_statements
-from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.autograph.converters import continue_statements
+from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.eager import context as tfe_ctx
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/converters/control_flow.py b/tensorflow/python/autograph/converters/control_flow.py
index 3530fbb2ec..416a60d2ee 100644
--- a/tensorflow/contrib/autograph/converters/control_flow.py
+++ b/tensorflow/python/autograph/converters/control_flow.py
@@ -20,12 +20,12 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import ast_util
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct.static_analysis import annos
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct.static_analysis import annos
class SymbolNamer(object):
diff --git a/tensorflow/contrib/autograph/converters/control_flow_test.py b/tensorflow/python/autograph/converters/control_flow_test.py
index 1d04ba3ba6..cfa0ea920c 100644
--- a/tensorflow/contrib/autograph/converters/control_flow_test.py
+++ b/tensorflow/python/autograph/converters/control_flow_test.py
@@ -18,9 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.converters import control_flow
-from tensorflow.contrib.autograph.core import converter_testing
-from tensorflow.contrib.autograph.pyct import transformer
+from tensorflow.python.autograph.converters import control_flow
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/converters/decorators.py b/tensorflow/python/autograph/converters/decorators.py
index 3471bd11d6..724f0fe5ed 100644
--- a/tensorflow/contrib/autograph/converters/decorators.py
+++ b/tensorflow/python/autograph/converters/decorators.py
@@ -24,8 +24,8 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
from tensorflow.python.util import tf_inspect
diff --git a/tensorflow/contrib/autograph/converters/decorators_test.py b/tensorflow/python/autograph/converters/decorators_test.py
index 095abc5edc..fb31c8d583 100644
--- a/tensorflow/contrib/autograph/converters/decorators_test.py
+++ b/tensorflow/python/autograph/converters/decorators_test.py
@@ -19,11 +19,13 @@ from __future__ import division
from __future__ import print_function
from functools import wraps
+import imp
-from tensorflow.contrib.autograph.converters import decorators
-from tensorflow.contrib.autograph.core import converter_testing
-from tensorflow.contrib.autograph.pyct import compiler
-from tensorflow.contrib.autograph.pyct import transformer
+from tensorflow.python import autograph
+from tensorflow.python.autograph.converters import decorators
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.platform import test
@@ -136,6 +138,12 @@ class DecoratorsTest(converter_testing.TestCase):
return inner_fn(a)
+ # Work around TensorFlow's symbol suppression mechanism that causes core to
+ # be invisible in the generated code.
+ core_mod = imp.new_module('core')
+ core_mod.converter_testing = converter_testing
+ autograph.core = core_mod
+
# 14 = 1 (a) + 1 (simple_decorator) + 11 (inner_fn)
self.assertEqual(14, test_fn(1))
diff --git a/tensorflow/contrib/autograph/converters/directives.py b/tensorflow/python/autograph/converters/directives.py
index 77f625bac7..fc646348ef 100644
--- a/tensorflow/contrib/autograph/converters/directives.py
+++ b/tensorflow/python/autograph/converters/directives.py
@@ -25,9 +25,9 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.lang import directives
-from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.lang import directives
+from tensorflow.python.autograph.pyct import anno
from tensorflow.python.util import tf_inspect
ENCLOSING_LOOP = 'enclosing_loop'
diff --git a/tensorflow/contrib/autograph/converters/directives_test.py b/tensorflow/python/autograph/converters/directives_test.py
index a2d083b891..570fb8e379 100644
--- a/tensorflow/contrib/autograph/converters/directives_test.py
+++ b/tensorflow/python/autograph/converters/directives_test.py
@@ -18,12 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.converters import directives as directives_converter
-from tensorflow.contrib.autograph.core import converter_testing
-from tensorflow.contrib.autograph.core.converter import AgAnno
-from tensorflow.contrib.autograph.lang import directives
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.python.autograph.converters import directives as directives_converter
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.autograph.core.converter import AgAnno
+from tensorflow.python.autograph.lang import directives
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/converters/error_handlers.py b/tensorflow/python/autograph/converters/error_handlers.py
index 1936821394..de46c0c830 100644
--- a/tensorflow/contrib/autograph/converters/error_handlers.py
+++ b/tensorflow/python/autograph/converters/error_handlers.py
@@ -22,9 +22,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import templates
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import templates
class ErrorRewritingTransformer(converter.Base):
diff --git a/tensorflow/contrib/autograph/converters/error_handlers_test.py b/tensorflow/python/autograph/converters/error_handlers_test.py
index 5d61b220af..676ff9e02b 100644
--- a/tensorflow/contrib/autograph/converters/error_handlers_test.py
+++ b/tensorflow/python/autograph/converters/error_handlers_test.py
@@ -18,11 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.converters import error_handlers
-from tensorflow.contrib.autograph.core import converter_testing
-from tensorflow.contrib.autograph.core import errors
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import origin_info
+from tensorflow.python.autograph.converters import error_handlers
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.autograph.core import errors
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import origin_info
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/converters/list_comprehensions.py b/tensorflow/python/autograph/converters/list_comprehensions.py
index ecf4628816..5be6cb9a98 100644
--- a/tensorflow/contrib/autograph/converters/list_comprehensions.py
+++ b/tensorflow/python/autograph/converters/list_comprehensions.py
@@ -32,8 +32,8 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import templates
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import templates
# TODO(mdan): This should convert directly to operator calls.
diff --git a/tensorflow/contrib/autograph/converters/list_comprehensions_test.py b/tensorflow/python/autograph/converters/list_comprehensions_test.py
index 59b5ce9ca0..1e66139af6 100644
--- a/tensorflow/contrib/autograph/converters/list_comprehensions_test.py
+++ b/tensorflow/python/autograph/converters/list_comprehensions_test.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.converters import list_comprehensions
-from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.autograph.converters import list_comprehensions
+from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/converters/lists.py b/tensorflow/python/autograph/converters/lists.py
index a02fc827b8..8180801753 100644
--- a/tensorflow/contrib/autograph/converters/lists.py
+++ b/tensorflow/python/autograph/converters/lists.py
@@ -32,12 +32,12 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.lang import directives
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.lang import directives
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
# Tags for local state.
diff --git a/tensorflow/contrib/autograph/converters/lists_test.py b/tensorflow/python/autograph/converters/lists_test.py
index c5e2dcf75e..f6da845fcc 100644
--- a/tensorflow/contrib/autograph/converters/lists_test.py
+++ b/tensorflow/python/autograph/converters/lists_test.py
@@ -18,12 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.converters import lists
-from tensorflow.contrib.autograph.core import converter_testing
-from tensorflow.contrib.autograph.lang import directives
-from tensorflow.contrib.autograph.lang import special_functions
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.python.autograph.converters import lists
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.autograph.lang import directives
+from tensorflow.python.autograph.lang import special_functions
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
diff --git a/tensorflow/contrib/autograph/converters/logical_expressions.py b/tensorflow/python/autograph/converters/logical_expressions.py
index 41c3424fa3..8c4d53f9a8 100644
--- a/tensorflow/contrib/autograph/converters/logical_expressions.py
+++ b/tensorflow/python/autograph/converters/logical_expressions.py
@@ -23,10 +23,10 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import templates
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import templates
# TODO(mdan): Properly extract boolean ops according to lazy eval rules.
@@ -57,8 +57,6 @@ class LogicalExpressionTransformer(converter.Base):
gast.NotEq: 'tf.not_equal',
gast.Or: 'tf.logical_or',
gast.USub: 'tf.negative',
- gast.Is: 'ag__.utils.dynamic_is',
- gast.IsNot: 'ag__.utils.dynamic_is_not'
}
def _expect_simple_symbol(self, operand):
@@ -72,12 +70,13 @@ class LogicalExpressionTransformer(converter.Base):
'"a.x or b"; for a workaround, assign the expression to a local '
'variable and use that instead, for example "tmp = a.x", "tmp or b"')
+ def _has_matching_func(self, operator):
+ op_type = type(operator)
+ return op_type in self.op_mapping
+
def _matching_func(self, operator):
op_type = type(operator)
- mapped_op = self.op_mapping.get(op_type)
- if not mapped_op:
- raise NotImplementedError('operator %s is not yet supported' % op_type)
- return mapped_op
+ return self.op_mapping[op_type]
def _as_function(self, func_name, args):
template = """
@@ -90,6 +89,16 @@ class LogicalExpressionTransformer(converter.Base):
def visit_Compare(self, node):
node = self.generic_visit(node)
+
+ if not all(self._has_matching_func(op) for op in node.ops):
+ if len(node.ops) == 1:
+ # Basic expressions are safe to leave as they are.
+ return node
+ else:
+ raise NotImplementedError(
+ 'compound expression with at least one unsupported '
+ 'operator: {}'.format(node.ops))
+
ops_and_comps = list(zip(node.ops, node.comparators))
left = node.left
op_tree = None
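With the ag__.utils.dynamic_is mappings removed, visit_Compare now checks every operator in a comparison against op_mapping: if all are supported it converts as before, a single unsupported operator (e.g. is, is not, in) leaves the node untouched, and only a compound comparison mixing supported and unsupported operators raises NotImplementedError. A rough standalone sketch of that dispatch rule, using the standard ast module (which gast mirrors) and an illustrative subset of the mapping:

import ast

# Illustrative subset of the converter's op_mapping; not the full table.
OP_MAPPING = {ast.Eq: 'tf.equal', ast.Gt: 'tf.greater'}

def dispatch_compare(ops):
    """Decide what visit_Compare does with a Compare node's operator list."""
    if all(type(op) in OP_MAPPING for op in ops):
        return 'convert'        # every operator has a TF equivalent
    if len(ops) == 1:
        return 'passthrough'    # basic expressions like `a in b` are left as-is
    raise NotImplementedError(
        'compound expression with at least one unsupported operator: {}'.format(ops))

print(dispatch_compare([ast.Eq()]))   # -> convert
print(dispatch_compare([ast.In()]))   # -> passthrough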
diff --git a/tensorflow/contrib/autograph/converters/logical_expressions_test.py b/tensorflow/python/autograph/converters/logical_expressions_test.py
index 409a73afba..b78b4d3a6a 100644
--- a/tensorflow/contrib/autograph/converters/logical_expressions_test.py
+++ b/tensorflow/python/autograph/converters/logical_expressions_test.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.converters import logical_expressions
-from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.autograph.converters import logical_expressions
+from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@@ -47,14 +47,12 @@ class GradientsFunctionTest(converter_testing.TestCase):
with self.cached_session() as sess:
self.assertTrue(sess.run(result.test_fn(True, False, True)))
- def test_ag_utils_lookup(self):
+ def test_unsupported_ops(self):
def test_fn(a, b):
- return a is b or a is not b
+ return a in b
- with self.converted(test_fn, logical_expressions, {}, math_ops.logical_or
- ) as result:
- with self.cached_session() as sess:
- self.assertTrue(sess.run(result.test_fn(True, False)))
+ with self.converted(test_fn, logical_expressions, {}) as result:
+ self.assertTrue(result.test_fn('a', ('a',)))
if __name__ == '__main__':
diff --git a/tensorflow/contrib/autograph/converters/name_scopes.py b/tensorflow/python/autograph/converters/name_scopes.py
index dd6c6bf960..a9c55ccff0 100644
--- a/tensorflow/contrib/autograph/converters/name_scopes.py
+++ b/tensorflow/python/autograph/converters/name_scopes.py
@@ -20,8 +20,8 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import templates
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import templates
class FunctionNameScopeTransformer(converter.Base):
diff --git a/tensorflow/contrib/autograph/converters/name_scopes_test.py b/tensorflow/python/autograph/converters/name_scopes_test.py
index a329b0db70..73933c1c4f 100644
--- a/tensorflow/contrib/autograph/converters/name_scopes_test.py
+++ b/tensorflow/python/autograph/converters/name_scopes_test.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.converters import name_scopes
-from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.autograph.converters import name_scopes
+from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/converters/return_statements.py b/tensorflow/python/autograph/converters/return_statements.py
index a351cd81b8..62da045d6a 100644
--- a/tensorflow/contrib/autograph/converters/return_statements.py
+++ b/tensorflow/python/autograph/converters/return_statements.py
@@ -20,11 +20,11 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import ast_util
-from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
# TODO(mdan): Move this logic into transformer_base.
diff --git a/tensorflow/contrib/autograph/converters/return_statements_test.py b/tensorflow/python/autograph/converters/return_statements_test.py
index 3c7c8c8a25..01dd03da0b 100644
--- a/tensorflow/contrib/autograph/converters/return_statements_test.py
+++ b/tensorflow/python/autograph/converters/return_statements_test.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.converters import return_statements
-from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.autograph.converters import return_statements
+from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/converters/side_effect_guards.py b/tensorflow/python/autograph/converters/side_effect_guards.py
index b808604f0a..6e48e57bde 100644
--- a/tensorflow/contrib/autograph/converters/side_effect_guards.py
+++ b/tensorflow/python/autograph/converters/side_effect_guards.py
@@ -36,12 +36,12 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import ast_util
-from tensorflow.contrib.autograph.pyct import qual_names
-from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
class SymbolNamer(object):
diff --git a/tensorflow/contrib/autograph/converters/side_effect_guards_test.py b/tensorflow/python/autograph/converters/side_effect_guards_test.py
index 5fe5114d4b..cef3199169 100644
--- a/tensorflow/contrib/autograph/converters/side_effect_guards_test.py
+++ b/tensorflow/python/autograph/converters/side_effect_guards_test.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.converters import side_effect_guards
-from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.autograph.converters import side_effect_guards
+from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
diff --git a/tensorflow/contrib/autograph/converters/slices.py b/tensorflow/python/autograph/converters/slices.py
index c527f98613..11cea6de5b 100644
--- a/tensorflow/contrib/autograph/converters/slices.py
+++ b/tensorflow/python/autograph/converters/slices.py
@@ -20,9 +20,9 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.lang import directives
-from tensorflow.contrib.autograph.pyct import templates
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.lang import directives
+from tensorflow.python.autograph.pyct import templates
class SliceTransformer(converter.Base):
diff --git a/tensorflow/contrib/autograph/converters/slices_test.py b/tensorflow/python/autograph/converters/slices_test.py
index d74b2e025e..e190a7cfe8 100644
--- a/tensorflow/contrib/autograph/converters/slices_test.py
+++ b/tensorflow/python/autograph/converters/slices_test.py
@@ -18,12 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.converters import slices
-from tensorflow.contrib.autograph.core import converter_testing
-from tensorflow.contrib.autograph.lang import directives
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import transformer
+from tensorflow.python.autograph.converters import slices
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.autograph.lang import directives
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import list_ops
diff --git a/tensorflow/contrib/autograph/core/BUILD b/tensorflow/python/autograph/core/BUILD
index 1873045a92..85fecf084d 100644
--- a/tensorflow/contrib/autograph/core/BUILD
+++ b/tensorflow/python/autograph/core/BUILD
@@ -25,9 +25,9 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
deps = [
- "//tensorflow/contrib/autograph/pyct",
- "//tensorflow/contrib/autograph/pyct/static_analysis",
- "//tensorflow/contrib/autograph/utils",
+ "//tensorflow/python/autograph/pyct",
+ "//tensorflow/python/autograph/pyct/static_analysis",
+ "//tensorflow/python/autograph/utils",
],
)
@@ -65,10 +65,10 @@ py_library(
visibility = ["//tensorflow:__subpackages__"],
deps = [
":core",
- "//tensorflow/contrib/autograph/operators",
- "//tensorflow/contrib/autograph/pyct",
- "//tensorflow/contrib/autograph/pyct/static_analysis",
- "//tensorflow/contrib/autograph/utils",
+ "//tensorflow/python/autograph/operators",
+ "//tensorflow/python/autograph/pyct",
+ "//tensorflow/python/autograph/pyct/static_analysis",
+ "//tensorflow/python/autograph/utils",
"@gast_archive//:gast",
"@six_archive//:six",
],
diff --git a/tensorflow/contrib/autograph/core/config.py b/tensorflow/python/autograph/core/config.py
index 878bb7e12f..4fa8489af5 100644
--- a/tensorflow/contrib/autograph/core/config.py
+++ b/tensorflow/python/autograph/core/config.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph import utils
+from tensorflow.python.autograph import utils
PYTHON_LITERALS = {
@@ -36,7 +36,7 @@ DEFAULT_UNCOMPILED_MODULES = set((
# have well-known names. Not referring to the module directly to avoid
# circular imports.
(
- utils.__name__[:-len('.contrib.autograph.utils')],),
+ utils.__name__[:-len('.python.autograph.utils')],),
))
NO_SIDE_EFFECT_CONSTRUCTORS = set(('tensorflow',))
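The only functional change in config.py is the suffix stripped from utils.__name__ to recover the top-level package prefix for DEFAULT_UNCOMPILED_MODULES, matching the module's new location under tensorflow.python. A small illustration of the slice, with the module names hard-coded so the snippet does not need TensorFlow installed:

# Module names hard-coded for illustration only.
new_name = 'tensorflow.python.autograph.utils'
old_name = 'tensorflow.contrib.autograph.utils'

# Strip the fixed suffix to recover the top-level package name.
print(new_name[:-len('.python.autograph.utils')])   # -> tensorflow
print(old_name[:-len('.contrib.autograph.utils')])  # -> tensorflow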
diff --git a/tensorflow/contrib/autograph/core/converter.py b/tensorflow/python/autograph/core/converter.py
index 83a80c1f52..7b3905fdee 100644
--- a/tensorflow/contrib/autograph/core/converter.py
+++ b/tensorflow/python/autograph/core/converter.py
@@ -67,19 +67,19 @@ import collections
from enum import Enum
-from tensorflow.contrib.autograph.core import config
-from tensorflow.contrib.autograph.core import naming
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import ast_util
-from tensorflow.contrib.autograph.pyct import cfg
-from tensorflow.contrib.autograph.pyct import compiler
-from tensorflow.contrib.autograph.pyct import qual_names
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.static_analysis import activity
-from tensorflow.contrib.autograph.pyct.static_analysis import live_values
-from tensorflow.contrib.autograph.pyct.static_analysis import liveness
-from tensorflow.contrib.autograph.pyct.static_analysis import reaching_definitions
-from tensorflow.contrib.autograph.pyct.static_analysis import type_info
+from tensorflow.python.autograph.core import config
+from tensorflow.python.autograph.core import naming
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import cfg
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis import activity
+from tensorflow.python.autograph.pyct.static_analysis import live_values
+from tensorflow.python.autograph.pyct.static_analysis import liveness
+from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
+from tensorflow.python.autograph.pyct.static_analysis import type_info
# TODO(mdan): These contexts can be refactored into first class objects.
# For example, we could define Program and Entity abstractions that hold on
diff --git a/tensorflow/contrib/autograph/core/converter_testing.py b/tensorflow/python/autograph/core/converter_testing.py
index 5ee2c3fffd..0a0c6f9002 100644
--- a/tensorflow/contrib/autograph/core/converter_testing.py
+++ b/tensorflow/python/autograph/core/converter_testing.py
@@ -24,15 +24,15 @@ import sys
import six
-from tensorflow.contrib.autograph import operators
-from tensorflow.contrib.autograph import utils
-from tensorflow.contrib.autograph.core import config
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.core import errors
-from tensorflow.contrib.autograph.pyct import compiler
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import pretty_printer
-from tensorflow.contrib.autograph.pyct import transformer
+from tensorflow.python.autograph import operators
+from tensorflow.python.autograph import utils
+from tensorflow.python.autograph.core import config
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.core import errors
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import pretty_printer
+from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/core/errors.py b/tensorflow/python/autograph/core/errors.py
index 5a57d57e7d..0750353423 100644
--- a/tensorflow/contrib/autograph/core/errors.py
+++ b/tensorflow/python/autograph/core/errors.py
@@ -31,7 +31,7 @@ import logging
import sys
import traceback
-from tensorflow.contrib.autograph.pyct import origin_info
+from tensorflow.python.autograph.pyct import origin_info
from tensorflow.python.framework import errors_impl
# TODO(mdan): Add a superclass common to all errors.
diff --git a/tensorflow/contrib/autograph/core/errors_test.py b/tensorflow/python/autograph/core/errors_test.py
index 404c1f5456..0444ed7eab 100644
--- a/tensorflow/contrib/autograph/core/errors_test.py
+++ b/tensorflow/python/autograph/core/errors_test.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.core import errors
-from tensorflow.contrib.autograph.pyct import origin_info
+from tensorflow.python.autograph.core import errors
+from tensorflow.python.autograph.pyct import origin_info
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors as tf_errors
from tensorflow.python.ops import array_ops
diff --git a/tensorflow/contrib/autograph/core/naming.py b/tensorflow/python/autograph/core/naming.py
index b1d3f76be7..aecc9e33ca 100644
--- a/tensorflow/contrib/autograph/core/naming.py
+++ b/tensorflow/python/autograph/core/naming.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import qual_names
class Namer(object):
diff --git a/tensorflow/contrib/autograph/core/naming_test.py b/tensorflow/python/autograph/core/naming_test.py
index d2bebd0478..2db98836d1 100644
--- a/tensorflow/contrib/autograph/core/naming_test.py
+++ b/tensorflow/python/autograph/core/naming_test.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.core import naming
+from tensorflow.python.autograph.core import naming
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/docs/pyfunc_dtypes.md b/tensorflow/python/autograph/docs/pyfunc_dtypes.md
index bcbb920cc5..c2427f5f4f 100644
--- a/tensorflow/contrib/autograph/docs/pyfunc_dtypes.md
+++ b/tensorflow/python/autograph/docs/pyfunc_dtypes.md
@@ -4,7 +4,7 @@ The `py_func` op requires specifying a
[data type](https://www.tensorflow.org/guide/tensors#data_types).
When wrapping a function with `py_func`, for instance using
-`@autograph.do_not_convert(run_mode=autograph.RunMode.PY_FUNC)`, you have two
+`@autograph.do_not_convert(run_as=autograph.RunMode.PY_FUNC)`, you have two
options to specify the returned data type:
* explicitly, with a specified `tf.DType` value
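The doc fix above corrects the decorator keyword from run_mode to run_as. For context, a hedged usage sketch of wrapping a plain-Python function this way; the return_dtypes keyword is an assumed spelling for the explicit tf.DType option described in the bullet list and may not match the exact signature:

import tensorflow as tf
from tensorflow.contrib import autograph

# return_dtypes is an assumption about how the explicit dtype is supplied.
@autograph.do_not_convert(run_as=autograph.RunMode.PY_FUNC,
                          return_dtypes=[tf.float32])
def add_noise(x):
    # Executed as a py_func, so arbitrary Python is allowed here.
    return x + 0.5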
diff --git a/tensorflow/contrib/autograph/impl/BUILD b/tensorflow/python/autograph/impl/BUILD
index a5438592c3..bef62a6403 100644
--- a/tensorflow/contrib/autograph/impl/BUILD
+++ b/tensorflow/python/autograph/impl/BUILD
@@ -23,14 +23,14 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
deps = [
- "//tensorflow/contrib/autograph/converters",
- "//tensorflow/contrib/autograph/core",
- "//tensorflow/contrib/autograph/operators",
- "//tensorflow/contrib/autograph/pyct",
- "//tensorflow/contrib/autograph/pyct/static_analysis",
- "//tensorflow/contrib/autograph/utils",
"//tensorflow/python:platform",
"//tensorflow/python:util",
+ "//tensorflow/python/autograph/converters",
+ "//tensorflow/python/autograph/core",
+ "//tensorflow/python/autograph/operators",
+ "//tensorflow/python/autograph/pyct",
+ "//tensorflow/python/autograph/pyct/static_analysis",
+ "//tensorflow/python/autograph/utils",
"@gast_archive//:gast",
"@six_archive//:six",
],
@@ -43,8 +43,8 @@ py_test(
tags = ["no_windows"],
deps = [
":impl",
- "//tensorflow/contrib/autograph/utils",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/utils",
"//third_party/py/numpy",
],
)
diff --git a/tensorflow/contrib/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py
index 8b38d5d080..669d36bd28 100644
--- a/tensorflow/contrib/autograph/impl/api.py
+++ b/tensorflow/python/autograph/impl/api.py
@@ -22,17 +22,13 @@ from functools import wraps
from enum import Enum
-# pylint:disable=g-bad-import-order
-import six
-# pylint:enable=g-bad-import-order
-
-from tensorflow.contrib.autograph.core import config
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.impl import conversion
-from tensorflow.contrib.autograph.operators import py_builtins
-from tensorflow.contrib.autograph.pyct import compiler
-from tensorflow.contrib.autograph.pyct import inspect_utils
-from tensorflow.contrib.autograph.utils import py_func
+from tensorflow.python.autograph.core import config
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.impl import conversion
+from tensorflow.python.autograph.operators import py_builtins
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import inspect_utils
+from tensorflow.python.autograph.utils import py_func
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
@@ -257,7 +253,7 @@ def to_graph(e,
arg_types)
nodes = []
- for dep in reversed(program_ctx.dependency_cache.values()):
+ for dep in reversed(tuple(program_ctx.dependency_cache.values())):
nodes.extend(dep)
compiled_module, compiled_src = compiler.ast_to_object(
nodes,
@@ -327,6 +323,6 @@ def to_code(e,
code = '\n'.join(
compiler.ast_to_source(dep, indentation)
- for dep in reversed(tuple(six.itervalues(program_ctx.dependency_cache))))
+ for dep in reversed(tuple(program_ctx.dependency_cache.values())))
return program_ctx.required_imports + '\n\n' + code
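Both api.py hunks materialize program_ctx.dependency_cache.values() with tuple() before reversing it: on Python 3 (before 3.8) reversed() rejects a dict view, while Python 2 returned a plain list, so the wrapper keeps the code portable and lets the six.itervalues import go. A quick demonstration of the failure mode and the fix, using a toy cache dict:

# Toy stand-in for program_ctx.dependency_cache.
cache = {'mod_a': ['node1'], 'mod_b': ['node2']}

try:
    list(reversed(cache.values()))
except TypeError as err:
    # Raised on Python 3 interpreters older than 3.8.
    print('reversed() on a dict view failed:', err)

# Materializing the view first works everywhere.
nodes = []
for dep in reversed(tuple(cache.values())):
    nodes.extend(dep)
print(nodes)  # -> ['node2', 'node1'] (insertion order reversed)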
diff --git a/tensorflow/contrib/autograph/impl/api_test.py b/tensorflow/python/autograph/impl/api_test.py
index a4c6fed265..54e12f0223 100644
--- a/tensorflow/contrib/autograph/impl/api_test.py
+++ b/tensorflow/python/autograph/impl/api_test.py
@@ -20,11 +20,11 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.autograph import utils
-from tensorflow.contrib.autograph.core import config
-from tensorflow.contrib.autograph.impl import api
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.utils import py_func
+from tensorflow.python.autograph import utils
+from tensorflow.python.autograph.core import config
+from tensorflow.python.autograph.impl import api
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.utils import py_func
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test
from tensorflow.python.util import tf_inspect
diff --git a/tensorflow/contrib/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py
index fc8a976d3f..928ff9e7ea 100644
--- a/tensorflow/contrib/autograph/impl/conversion.py
+++ b/tensorflow/python/autograph/impl/conversion.py
@@ -22,34 +22,34 @@ import imp
import gast
-from tensorflow.contrib.autograph import operators
-from tensorflow.contrib.autograph import utils
-from tensorflow.contrib.autograph.converters import asserts
-from tensorflow.contrib.autograph.converters import break_statements
-from tensorflow.contrib.autograph.converters import builtin_functions
-from tensorflow.contrib.autograph.converters import call_trees
-from tensorflow.contrib.autograph.converters import conditional_expressions
-from tensorflow.contrib.autograph.converters import continue_statements
-from tensorflow.contrib.autograph.converters import control_flow
-from tensorflow.contrib.autograph.converters import decorators
-from tensorflow.contrib.autograph.converters import directives
-from tensorflow.contrib.autograph.converters import error_handlers
-from tensorflow.contrib.autograph.converters import lists
-from tensorflow.contrib.autograph.converters import logical_expressions
-from tensorflow.contrib.autograph.converters import name_scopes
-from tensorflow.contrib.autograph.converters import return_statements
-from tensorflow.contrib.autograph.converters import side_effect_guards
-from tensorflow.contrib.autograph.converters import slices
-from tensorflow.contrib.autograph.core import config
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.core import errors
-from tensorflow.contrib.autograph.pyct import ast_util
-from tensorflow.contrib.autograph.pyct import inspect_utils
-from tensorflow.contrib.autograph.pyct import origin_info
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import qual_names
-from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct import transformer
+from tensorflow.python.autograph import operators
+from tensorflow.python.autograph import utils
+from tensorflow.python.autograph.converters import asserts
+from tensorflow.python.autograph.converters import break_statements
+from tensorflow.python.autograph.converters import builtin_functions
+from tensorflow.python.autograph.converters import call_trees
+from tensorflow.python.autograph.converters import conditional_expressions
+from tensorflow.python.autograph.converters import continue_statements
+from tensorflow.python.autograph.converters import control_flow
+from tensorflow.python.autograph.converters import decorators
+from tensorflow.python.autograph.converters import directives
+from tensorflow.python.autograph.converters import error_handlers
+from tensorflow.python.autograph.converters import lists
+from tensorflow.python.autograph.converters import logical_expressions
+from tensorflow.python.autograph.converters import name_scopes
+from tensorflow.python.autograph.converters import return_statements
+from tensorflow.python.autograph.converters import side_effect_guards
+from tensorflow.python.autograph.converters import slices
+from tensorflow.python.autograph.core import config
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.core import errors
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import inspect_utils
+from tensorflow.python.autograph.pyct import origin_info
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.util import tf_inspect
diff --git a/tensorflow/contrib/autograph/impl/conversion_test.py b/tensorflow/python/autograph/impl/conversion_test.py
index 86432573a7..07d0f75129 100644
--- a/tensorflow/contrib/autograph/impl/conversion_test.py
+++ b/tensorflow/python/autograph/impl/conversion_test.py
@@ -20,11 +20,11 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph import utils
-from tensorflow.contrib.autograph.core import config
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.impl import api
-from tensorflow.contrib.autograph.impl import conversion
+from tensorflow.python.autograph import utils
+from tensorflow.python.autograph.core import config
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.impl import api
+from tensorflow.python.autograph.impl import conversion
from tensorflow.python.framework import constant_op
from tensorflow.python.keras.engine import training
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/lang/BUILD b/tensorflow/python/autograph/lang/BUILD
index 77a2184e22..462349cc10 100644
--- a/tensorflow/contrib/autograph/lang/BUILD
+++ b/tensorflow/python/autograph/lang/BUILD
@@ -25,7 +25,7 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
deps = [
- "//tensorflow/contrib/autograph/operators",
+ "//tensorflow/python/autograph/operators",
],
)
diff --git a/tensorflow/contrib/autograph/lang/directives.py b/tensorflow/python/autograph/lang/directives.py
index aabe5d9939..aabe5d9939 100644
--- a/tensorflow/contrib/autograph/lang/directives.py
+++ b/tensorflow/python/autograph/lang/directives.py
diff --git a/tensorflow/contrib/autograph/lang/special_functions.py b/tensorflow/python/autograph/lang/special_functions.py
index 6149cbbd6c..e4838d1b6d 100644
--- a/tensorflow/contrib/autograph/lang/special_functions.py
+++ b/tensorflow/python/autograph/lang/special_functions.py
@@ -23,7 +23,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.operators import data_structures
+from tensorflow.python.autograph.operators import data_structures
def tensor_list(elements,
diff --git a/tensorflow/contrib/autograph/lang/special_functions_test.py b/tensorflow/python/autograph/lang/special_functions_test.py
index db492cc5c6..1f1cec18f7 100644
--- a/tensorflow/contrib/autograph/lang/special_functions_test.py
+++ b/tensorflow/python/autograph/lang/special_functions_test.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.lang import special_functions
+from tensorflow.python.autograph.lang import special_functions
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
diff --git a/tensorflow/contrib/autograph/operators/BUILD b/tensorflow/python/autograph/operators/BUILD
index 29759bad79..a116611b64 100644
--- a/tensorflow/contrib/autograph/operators/BUILD
+++ b/tensorflow/python/autograph/operators/BUILD
@@ -28,7 +28,6 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
deps = [
- "//tensorflow/contrib/autograph/utils",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
@@ -38,6 +37,7 @@ py_library(
"//tensorflow/python:tensor_array_ops",
"//tensorflow/python:tensor_util",
"//tensorflow/python:variables",
+ "//tensorflow/python/autograph/utils",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -66,6 +66,7 @@ py_test(
name = "py_builtins_test",
srcs = ["py_builtins_test.py"],
srcs_version = "PY2AND3",
+ tags = ["no_windows"],
deps = [
":operators",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/contrib/autograph/operators/__init__.py b/tensorflow/python/autograph/operators/__init__.py
index c4fbc260a2..0d3b44b6c4 100644
--- a/tensorflow/contrib/autograph/operators/__init__.py
+++ b/tensorflow/python/autograph/operators/__init__.py
@@ -37,19 +37,19 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.operators.control_flow import for_stmt
-from tensorflow.contrib.autograph.operators.control_flow import while_stmt
-from tensorflow.contrib.autograph.operators.data_structures import list_append
-from tensorflow.contrib.autograph.operators.data_structures import list_pop
-from tensorflow.contrib.autograph.operators.data_structures import list_stack
-from tensorflow.contrib.autograph.operators.data_structures import ListPopOpts
-from tensorflow.contrib.autograph.operators.data_structures import ListStackOpts
-from tensorflow.contrib.autograph.operators.data_structures import new_list
-from tensorflow.contrib.autograph.operators.py_builtins import float_
-from tensorflow.contrib.autograph.operators.py_builtins import int_
-from tensorflow.contrib.autograph.operators.py_builtins import len_
-from tensorflow.contrib.autograph.operators.py_builtins import print_
-from tensorflow.contrib.autograph.operators.py_builtins import range_
-from tensorflow.contrib.autograph.operators.slices import get_item
-from tensorflow.contrib.autograph.operators.slices import GetItemOpts
-from tensorflow.contrib.autograph.operators.slices import set_item
+from tensorflow.python.autograph.operators.control_flow import for_stmt
+from tensorflow.python.autograph.operators.control_flow import while_stmt
+from tensorflow.python.autograph.operators.data_structures import list_append
+from tensorflow.python.autograph.operators.data_structures import list_pop
+from tensorflow.python.autograph.operators.data_structures import list_stack
+from tensorflow.python.autograph.operators.data_structures import ListPopOpts
+from tensorflow.python.autograph.operators.data_structures import ListStackOpts
+from tensorflow.python.autograph.operators.data_structures import new_list
+from tensorflow.python.autograph.operators.py_builtins import float_
+from tensorflow.python.autograph.operators.py_builtins import int_
+from tensorflow.python.autograph.operators.py_builtins import len_
+from tensorflow.python.autograph.operators.py_builtins import print_
+from tensorflow.python.autograph.operators.py_builtins import range_
+from tensorflow.python.autograph.operators.slices import get_item
+from tensorflow.python.autograph.operators.slices import GetItemOpts
+from tensorflow.python.autograph.operators.slices import set_item
diff --git a/tensorflow/contrib/autograph/operators/control_flow.py b/tensorflow/python/autograph/operators/control_flow.py
index 9a66a6bb60..6eedd695a7 100644
--- a/tensorflow/contrib/autograph/operators/control_flow.py
+++ b/tensorflow/python/autograph/operators/control_flow.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.operators import py_builtins
+from tensorflow.python.autograph.operators import py_builtins
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
diff --git a/tensorflow/contrib/autograph/operators/control_flow_test.py b/tensorflow/python/autograph/operators/control_flow_test.py
index 677b7f8f62..bb214b6f16 100644
--- a/tensorflow/contrib/autograph/operators/control_flow_test.py
+++ b/tensorflow/python/autograph/operators/control_flow_test.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.operators import control_flow
+from tensorflow.python.autograph.operators import control_flow
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
diff --git a/tensorflow/contrib/autograph/operators/data_structures.py b/tensorflow/python/autograph/operators/data_structures.py
index cc0a3c3544..cc0a3c3544 100644
--- a/tensorflow/contrib/autograph/operators/data_structures.py
+++ b/tensorflow/python/autograph/operators/data_structures.py
diff --git a/tensorflow/contrib/autograph/operators/data_structures_test.py b/tensorflow/python/autograph/operators/data_structures_test.py
index 4b1e835d44..8532dbe466 100644
--- a/tensorflow/contrib/autograph/operators/data_structures_test.py
+++ b/tensorflow/python/autograph/operators/data_structures_test.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.operators import data_structures
+from tensorflow.python.autograph.operators import data_structures
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
diff --git a/tensorflow/contrib/autograph/operators/dispatch_context.py b/tensorflow/python/autograph/operators/dispatch_context.py
index 097002465b..097002465b 100644
--- a/tensorflow/contrib/autograph/operators/dispatch_context.py
+++ b/tensorflow/python/autograph/operators/dispatch_context.py
diff --git a/tensorflow/contrib/autograph/operators/py_builtins.py b/tensorflow/python/autograph/operators/py_builtins.py
index c5730934e7..1d37ae72d3 100644
--- a/tensorflow/contrib/autograph/operators/py_builtins.py
+++ b/tensorflow/python/autograph/operators/py_builtins.py
@@ -23,8 +23,8 @@ from __future__ import print_function
import six
-from tensorflow.contrib.autograph.utils import py_func
-from tensorflow.contrib.autograph.utils import tensors
+from tensorflow.python.autograph.utils import py_func
+from tensorflow.python.autograph.utils import tensors
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
diff --git a/tensorflow/contrib/autograph/operators/py_builtins_test.py b/tensorflow/python/autograph/operators/py_builtins_test.py
index 4073c51785..a021263ffa 100644
--- a/tensorflow/contrib/autograph/operators/py_builtins_test.py
+++ b/tensorflow/python/autograph/operators/py_builtins_test.py
@@ -22,8 +22,8 @@ import sys
import six
-from tensorflow.contrib.autograph.operators import data_structures
-from tensorflow.contrib.autograph.operators import py_builtins
+from tensorflow.python.autograph.operators import data_structures
+from tensorflow.python.autograph.operators import py_builtins
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
diff --git a/tensorflow/contrib/autograph/operators/slices.py b/tensorflow/python/autograph/operators/slices.py
index 04fbeb2f6e..2b7f5ad922 100644
--- a/tensorflow/contrib/autograph/operators/slices.py
+++ b/tensorflow/python/autograph/operators/slices.py
@@ -22,6 +22,7 @@ import collections
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import gen_string_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import tensor_array_ops
@@ -57,6 +58,8 @@ def get_item(target, i, opts):
elif tensor_util.is_tensor(target):
if target.dtype == dtypes.variant:
return _tf_tensor_list_get_item(target, i, opts)
+ elif target.dtype == dtypes.string and target.shape.ndims == 0:
+ return _tf_tensor_string_get_item(target, i)
else:
return _tf_tensor_get_item(target, i)
else:
@@ -82,6 +85,12 @@ def _tf_tensor_get_item(target, i):
return target[i]
+def _tf_tensor_string_get_item(target, i):
+ """Overload of get_item that stages a Tensor string read."""
+ x = gen_string_ops.substr(target, i, 1)
+ return x
+
+
def _py_get_item(target, i):
"""Overload of get_item that executes a Python list modification."""
return target[i]
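The new branch in get_item routes scalar string tensors through gen_string_ops.substr, returning a length-1 substring in place of ordinary indexing. A hedged sketch of the equivalent call through the public API, assuming a TF 1.x build where tf.strings.substr is available as a wrapper over the same op:

import tensorflow as tf

target = tf.constant('abcd')       # rank-0 string tensor
i = 1

# A length-1 substring starting at i plays the role of target[i].
char = tf.strings.substr(target, i, 1)

with tf.Session() as sess:
    print(sess.run(char))          # -> b'b'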
diff --git a/tensorflow/contrib/autograph/operators/slices_test.py b/tensorflow/python/autograph/operators/slices_test.py
index 56aafe07c8..d8b8418750 100644
--- a/tensorflow/contrib/autograph/operators/slices_test.py
+++ b/tensorflow/python/autograph/operators/slices_test.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.operators import slices
+from tensorflow.python.autograph.operators import slices
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import list_ops
from tensorflow.python.platform import test
@@ -46,6 +46,21 @@ class SlicesTest(test.TestCase):
with self.cached_session() as sess:
self.assertAllEqual(sess.run(t), [3, 4])
+ def test_get_item_tensor_string(self):
+ initial_str = constant_op.constant('abcd')
+ t = slices.get_item(initial_str, 1,
+ slices.GetItemOpts(element_dtype=initial_str.dtype))
+
+ with self.test_session() as sess:
+ self.assertEqual(sess.run(t), b'b')
+
+ initial_list_str = constant_op.constant(['abcd', 'bcde'])
+ t = slices.get_item(initial_list_str, 1,
+ slices.GetItemOpts(element_dtype=initial_str.dtype))
+
+ with self.test_session() as sess:
+ self.assertEqual(sess.run(t), b'bcde')
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/autograph/pyct/BUILD b/tensorflow/python/autograph/pyct/BUILD
index ddadc6b96e..ddadc6b96e 100644
--- a/tensorflow/contrib/autograph/pyct/BUILD
+++ b/tensorflow/python/autograph/pyct/BUILD
diff --git a/tensorflow/contrib/autograph/pyct/__init__.py b/tensorflow/python/autograph/pyct/__init__.py
index d787e56bbe..d787e56bbe 100644
--- a/tensorflow/contrib/autograph/pyct/__init__.py
+++ b/tensorflow/python/autograph/pyct/__init__.py
diff --git a/tensorflow/contrib/autograph/pyct/anno.py b/tensorflow/python/autograph/pyct/anno.py
index 1a52110ef3..1a52110ef3 100644
--- a/tensorflow/contrib/autograph/pyct/anno.py
+++ b/tensorflow/python/autograph/pyct/anno.py
diff --git a/tensorflow/contrib/autograph/pyct/anno_test.py b/tensorflow/python/autograph/pyct/anno_test.py
index 5ef4da61a3..1f873871c6 100644
--- a/tensorflow/contrib/autograph/pyct/anno_test.py
+++ b/tensorflow/python/autograph/pyct/anno_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import ast
-from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import anno
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/pyct/ast_util.py b/tensorflow/python/autograph/pyct/ast_util.py
index d7453b0781..7df3b8858c 100644
--- a/tensorflow/contrib/autograph/pyct/ast_util.py
+++ b/tensorflow/python/autograph/pyct/ast_util.py
@@ -22,8 +22,8 @@ import ast
import gast
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
class CleanCopier(object):
diff --git a/tensorflow/contrib/autograph/pyct/ast_util_test.py b/tensorflow/python/autograph/pyct/ast_util_test.py
index 2293c89720..b1577c466e 100644
--- a/tensorflow/contrib/autograph/pyct/ast_util_test.py
+++ b/tensorflow/python/autograph/pyct/ast_util_test.py
@@ -22,11 +22,11 @@ import ast
import collections
import textwrap
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import ast_util
-from tensorflow.contrib.autograph.pyct import compiler
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/pyct/cfg.py b/tensorflow/python/autograph/pyct/cfg.py
index ba51dcf285..1433f9ac83 100644
--- a/tensorflow/contrib/autograph/pyct/cfg.py
+++ b/tensorflow/python/autograph/pyct/cfg.py
@@ -33,7 +33,7 @@ from enum import Enum
import gast
# pylint:enable=g-bad-import-order
-from tensorflow.contrib.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import compiler
class Node(object):
diff --git a/tensorflow/contrib/autograph/pyct/cfg_test.py b/tensorflow/python/autograph/pyct/cfg_test.py
index 9d0a85d615..bd82e70f7d 100644
--- a/tensorflow/contrib/autograph/pyct/cfg_test.py
+++ b/tensorflow/python/autograph/pyct/cfg_test.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.pyct import cfg
-from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import cfg
+from tensorflow.python.autograph.pyct import parser
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/BUILD b/tensorflow/python/autograph/pyct/common_transformers/BUILD
index fe630ef852..5e2f8f3ac0 100644
--- a/tensorflow/contrib/autograph/pyct/common_transformers/BUILD
+++ b/tensorflow/python/autograph/pyct/common_transformers/BUILD
@@ -26,7 +26,7 @@ py_library(
"@six_archive//:six",
# TODO(aqj) Revisit this dependency direction when pyct is more
# modularized
- "//tensorflow/contrib/autograph/pyct",
+ "//tensorflow/python/autograph/pyct",
],
)
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/__init__.py b/tensorflow/python/autograph/pyct/common_transformers/__init__.py
index e69de29bb2..e69de29bb2 100644
--- a/tensorflow/contrib/autograph/pyct/common_transformers/__init__.py
+++ b/tensorflow/python/autograph/pyct/common_transformers/__init__.py
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf.py b/tensorflow/python/autograph/pyct/common_transformers/anf.py
index d77c15915b..192621b1cd 100644
--- a/tensorflow/contrib/autograph/pyct/common_transformers/anf.py
+++ b/tensorflow/python/autograph/pyct/common_transformers/anf.py
@@ -29,8 +29,8 @@ from __future__ import print_function
import gast
import six
-from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct import transformer
class DummyGensym(object):
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py b/tensorflow/python/autograph/pyct/common_transformers/anf_test.py
index 1ffd4bbe55..ccc7e4ca8f 100644
--- a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py
+++ b/tensorflow/python/autograph/pyct/common_transformers/anf_test.py
@@ -20,10 +20,10 @@ from __future__ import print_function
import textwrap
-from tensorflow.contrib.autograph.pyct import compiler
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.common_transformers import anf
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.common_transformers import anf
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/pyct/compiler.py b/tensorflow/python/autograph/pyct/compiler.py
index f9cee10962..9e1b6bdbe8 100644
--- a/tensorflow/contrib/autograph/pyct/compiler.py
+++ b/tensorflow/python/autograph/pyct/compiler.py
@@ -30,7 +30,7 @@ import tempfile
import astor
import gast
-from tensorflow.contrib.autograph.pyct import origin_info
+from tensorflow.python.autograph.pyct import origin_info
def ast_to_source(node, indentation=' '):
diff --git a/tensorflow/contrib/autograph/pyct/compiler_test.py b/tensorflow/python/autograph/pyct/compiler_test.py
index cf783da6a3..6fa289d3cc 100644
--- a/tensorflow/contrib/autograph/pyct/compiler_test.py
+++ b/tensorflow/python/autograph/pyct/compiler_test.py
@@ -22,8 +22,8 @@ import textwrap
import gast
-from tensorflow.contrib.autograph.pyct import compiler
-from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import parser
from tensorflow.python.platform import test
from tensorflow.python.util import tf_inspect
diff --git a/tensorflow/contrib/autograph/pyct/inspect_utils.py b/tensorflow/python/autograph/pyct/inspect_utils.py
index eef74599a7..eef74599a7 100644
--- a/tensorflow/contrib/autograph/pyct/inspect_utils.py
+++ b/tensorflow/python/autograph/pyct/inspect_utils.py
diff --git a/tensorflow/contrib/autograph/pyct/inspect_utils_test.py b/tensorflow/python/autograph/pyct/inspect_utils_test.py
index 1a212f676a..f3eb027822 100644
--- a/tensorflow/contrib/autograph/pyct/inspect_utils_test.py
+++ b/tensorflow/python/autograph/pyct/inspect_utils_test.py
@@ -22,7 +22,7 @@ from functools import wraps
import six
-from tensorflow.contrib.autograph.pyct import inspect_utils
+from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/pyct/origin_info.py b/tensorflow/python/autograph/pyct/origin_info.py
index b60651a30e..4c7c4165ef 100644
--- a/tensorflow/contrib/autograph/pyct/origin_info.py
+++ b/tensorflow/python/autograph/pyct/origin_info.py
@@ -23,9 +23,9 @@ import tokenize
import gast
import six
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import ast_util
-from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import parser
from tensorflow.python.util import tf_inspect
diff --git a/tensorflow/contrib/autograph/pyct/origin_info_test.py b/tensorflow/python/autograph/pyct/origin_info_test.py
index eeaa13007e..6b9c30dbd0 100644
--- a/tensorflow/contrib/autograph/pyct/origin_info_test.py
+++ b/tensorflow/python/autograph/pyct/origin_info_test.py
@@ -18,10 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import compiler
-from tensorflow.contrib.autograph.pyct import origin_info
-from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import origin_info
+from tensorflow.python.autograph.pyct import parser
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/pyct/parser.py b/tensorflow/python/autograph/pyct/parser.py
index 112ed46a1e..112ed46a1e 100644
--- a/tensorflow/contrib/autograph/pyct/parser.py
+++ b/tensorflow/python/autograph/pyct/parser.py
diff --git a/tensorflow/contrib/autograph/pyct/parser_test.py b/tensorflow/python/autograph/pyct/parser_test.py
index 007a4c6fb0..d0b465eb73 100644
--- a/tensorflow/contrib/autograph/pyct/parser_test.py
+++ b/tensorflow/python/autograph/pyct/parser_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import textwrap
-from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import parser
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/pyct/pretty_printer.py b/tensorflow/python/autograph/pyct/pretty_printer.py
index bacc1e4a77..bacc1e4a77 100644
--- a/tensorflow/contrib/autograph/pyct/pretty_printer.py
+++ b/tensorflow/python/autograph/pyct/pretty_printer.py
diff --git a/tensorflow/contrib/autograph/pyct/pretty_printer_test.py b/tensorflow/python/autograph/pyct/pretty_printer_test.py
index 0cb48f3576..1c76744547 100644
--- a/tensorflow/contrib/autograph/pyct/pretty_printer_test.py
+++ b/tensorflow/python/autograph/pyct/pretty_printer_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import ast
-from tensorflow.contrib.autograph.pyct import pretty_printer
+from tensorflow.python.autograph.pyct import pretty_printer
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/pyct/qual_names.py b/tensorflow/python/autograph/pyct/qual_names.py
index fb81404edc..334cbd7d38 100644
--- a/tensorflow/contrib/autograph/pyct/qual_names.py
+++ b/tensorflow/python/autograph/pyct/qual_names.py
@@ -29,8 +29,8 @@ import collections
import gast
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
class Symbol(collections.namedtuple('Symbol', ['name'])):
diff --git a/tensorflow/contrib/autograph/pyct/qual_names_test.py b/tensorflow/python/autograph/pyct/qual_names_test.py
index c793c2bb39..2da4dfd787 100644
--- a/tensorflow/contrib/autograph/pyct/qual_names_test.py
+++ b/tensorflow/python/autograph/pyct/qual_names_test.py
@@ -20,11 +20,11 @@ from __future__ import print_function
import textwrap
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import qual_names
-from tensorflow.contrib.autograph.pyct.qual_names import QN
-from tensorflow.contrib.autograph.pyct.qual_names import resolve
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct.qual_names import QN
+from tensorflow.python.autograph.pyct.qual_names import resolve
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD b/tensorflow/python/autograph/pyct/static_analysis/BUILD
index 92eacba3fd..4a4ccdcbd1 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD
+++ b/tensorflow/python/autograph/pyct/static_analysis/BUILD
@@ -27,9 +27,9 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
- "//tensorflow/contrib/autograph/pyct",
- "//tensorflow/contrib/autograph/utils",
"//tensorflow/python:util",
+ "//tensorflow/python/autograph/pyct",
+ "//tensorflow/python/autograph/utils",
"@gast_archive//:gast",
],
)
@@ -41,8 +41,8 @@ py_test(
tags = ["no_windows"],
deps = [
":static_analysis",
- "//tensorflow/contrib/autograph/pyct",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/pyct",
"@gast_archive//:gast",
],
)
@@ -54,8 +54,8 @@ py_test(
tags = ["no_windows"],
deps = [
":static_analysis",
- "//tensorflow/contrib/autograph/pyct",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/pyct",
],
)
@@ -65,8 +65,8 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":static_analysis",
- "//tensorflow/contrib/autograph/pyct",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/pyct",
],
)
@@ -76,8 +76,8 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":static_analysis",
- "//tensorflow/contrib/autograph/pyct",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/pyct",
],
)
@@ -87,8 +87,8 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":static_analysis",
- "//tensorflow/contrib/autograph/pyct",
- "//tensorflow/contrib/autograph/utils",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/pyct",
+ "//tensorflow/python/autograph/utils",
],
)
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/__init__.py b/tensorflow/python/autograph/pyct/static_analysis/__init__.py
index 9a82de735d..9a82de735d 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/__init__.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/__init__.py
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/activity.py b/tensorflow/python/autograph/pyct/static_analysis/activity.py
index a0182da9d1..9cb5991322 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/activity.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/activity.py
@@ -25,10 +25,10 @@ import copy
import gast
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import qual_names
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
# TODO(mdan): Add support for PY3 (e.g. Param vs arg).
# TODO(alexbw): Ignore named literals (e.g. None)
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py b/tensorflow/python/autograph/pyct/static_analysis/activity_test.py
index e940516190..d4a6ce8ac3 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/activity_test.py
@@ -20,13 +20,13 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import qual_names
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.qual_names import QN
-from tensorflow.contrib.autograph.pyct.static_analysis import activity
-from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.qual_names import QN
+from tensorflow.python.autograph.pyct.static_analysis import activity
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/annos.py b/tensorflow/python/autograph/pyct/static_analysis/annos.py
index 5eefecf278..5eefecf278 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/annos.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/annos.py
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py b/tensorflow/python/autograph/pyct/static_analysis/live_values.py
index e7baa244b2..48b442f3bd 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/live_values.py
@@ -25,9 +25,9 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
# TODO(aqj): Do we need this? Do other builtins fail in similar ways
# See b/114389775 for a related bug in pyct
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py b/tensorflow/python/autograph/pyct/static_analysis/live_values_test.py
index fe3051179c..882c380b78 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/live_values_test.py
@@ -20,15 +20,15 @@ from __future__ import print_function
import six
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import cfg
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import qual_names
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.static_analysis import activity
-from tensorflow.contrib.autograph.pyct.static_analysis import live_values
-from tensorflow.contrib.autograph.pyct.static_analysis import reaching_definitions
-from tensorflow.contrib.autograph.pyct.static_analysis import type_info
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import cfg
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis import activity
+from tensorflow.python.autograph.pyct.static_analysis import live_values
+from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
+from tensorflow.python.autograph.pyct.static_analysis import type_info
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/liveness.py b/tensorflow/python/autograph/pyct/static_analysis/liveness.py
index bf29d868a2..41c903beb9 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/liveness.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/liveness.py
@@ -26,10 +26,10 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import cfg
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.static_analysis import annos
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import cfg
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis import annos
class Analyzer(cfg.GraphVisitor):
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/liveness_test.py b/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py
index d53adb28af..0d5f369e92 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/liveness_test.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py
@@ -18,13 +18,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import cfg
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import qual_names
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.static_analysis import activity
-from tensorflow.contrib.autograph.pyct.static_analysis import liveness
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import cfg
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis import activity
+from tensorflow.python.autograph.pyct.static_analysis import liveness
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions.py
index 7f2b379d3d..9aaf318a9f 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions.py
@@ -30,10 +30,10 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import cfg
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.static_analysis import annos
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import cfg
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis import annos
class Definition(object):
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions_test.py b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py
index 243fe804b2..373a2cb38f 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions_test.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py
@@ -18,13 +18,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import cfg
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import qual_names
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.static_analysis import activity
-from tensorflow.contrib.autograph.pyct.static_analysis import reaching_definitions
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import cfg
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis import activity
+from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py b/tensorflow/python/autograph/pyct/static_analysis/type_info.py
index 835d5199fa..edb2ef0e27 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/type_info.py
@@ -43,9 +43,9 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import ast_util
-from tensorflow.contrib.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.util import tf_inspect
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py b/tensorflow/python/autograph/pyct/static_analysis/type_info_test.py
index 404311ba24..34ba3d2f13 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/type_info_test.py
@@ -18,15 +18,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import cfg
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import qual_names
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.static_analysis import activity
-from tensorflow.contrib.autograph.pyct.static_analysis import live_values
-from tensorflow.contrib.autograph.pyct.static_analysis import reaching_definitions
-from tensorflow.contrib.autograph.pyct.static_analysis import type_info
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import cfg
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis import activity
+from tensorflow.python.autograph.pyct.static_analysis import live_values
+from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
+from tensorflow.python.autograph.pyct.static_analysis import type_info
from tensorflow.python.client import session
from tensorflow.python.platform import test
from tensorflow.python.training import training
diff --git a/tensorflow/contrib/autograph/pyct/templates.py b/tensorflow/python/autograph/pyct/templates.py
index 5831d57ceb..68c2a35fac 100644
--- a/tensorflow/contrib/autograph/pyct/templates.py
+++ b/tensorflow/python/autograph/pyct/templates.py
@@ -26,10 +26,10 @@ import textwrap
import gast
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import ast_util
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
class ReplaceTransformer(gast.NodeTransformer):
@@ -113,7 +113,7 @@ class ReplaceTransformer(gast.NodeTransformer):
if isinstance(node, gast.Attribute):
self._check_inner_children_have_context(node.value)
self._check_has_context(node)
- elif isinstance(node, gast.Tuple):
+ elif isinstance(node, (gast.Tuple, gast.List)):
for e in node.elts:
self._check_inner_children_have_context(e)
self._check_has_context(node)
@@ -142,7 +142,7 @@ class ReplaceTransformer(gast.NodeTransformer):
if isinstance(node, gast.Attribute):
self._set_inner_child_context(node.value, gast.Load())
node.ctx = ctx
- elif isinstance(node, gast.Tuple):
+ elif isinstance(node, (gast.Tuple, gast.List)):
for e in node.elts:
self._set_inner_child_context(e, ctx)
node.ctx = ctx
@@ -191,7 +191,7 @@ class ReplaceTransformer(gast.NodeTransformer):
# Preserve the target context.
for n in new_nodes:
- if isinstance(n, gast.Tuple):
+ if isinstance(n, (gast.Tuple, gast.List)):
for e in n.elts:
self._set_inner_child_context(e, node.ctx)
if isinstance(n, gast.Attribute):
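Context for the templates.py hunks above: gast mirrors the standard ast node layout, so a list literal parsed as an expression carries Load contexts on its elements; when such a node is substituted as an assignment target, ReplaceTransformer must rewrite those contexts to Store, which is why gast.List is now handled alongside gast.Tuple. A minimal standard-library sketch of the underlying issue (ast only, names illustrative):

    import ast

    # Parse '[a, b]' as an expression: the List node and its elements get Load
    # contexts, exactly as gast would report them.
    target = ast.parse('[a, b]', mode='eval').body
    print(type(target).__name__)              # List
    print(type(target.elts[0].ctx).__name__)  # Load
    # When this node is spliced in as the left-hand side of an assignment, every
    # element context must be flipped to Store, mirroring the gast.Tuple path.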
diff --git a/tensorflow/contrib/autograph/pyct/templates_test.py b/tensorflow/python/autograph/pyct/templates_test.py
index 77e8ff62fd..66268cfaad 100644
--- a/tensorflow/contrib/autograph/pyct/templates_test.py
+++ b/tensorflow/python/autograph/pyct/templates_test.py
@@ -22,9 +22,9 @@ import imp
import gast
-from tensorflow.contrib.autograph.pyct import compiler
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import templates
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import templates
from tensorflow.python.platform import test
@@ -110,6 +110,42 @@ class TemplatesTest(test.TestCase):
self.assertIsInstance(node.body[0].targets[0].value.ctx, gast.Load)
self.assertIsInstance(node.body[0].targets[0].value.value.ctx, gast.Load)
+ def test_replace_list_context(self):
+ template = """
+ def test_fn(foo):
+ foo = 0
+ """
+
+ node = templates.replace(template, foo=parser.parse_expression('[a, b]'))[0]
+ self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store)
+ self.assertIsInstance(node.body[0].targets[0].elts[0].ctx, gast.Store)
+ self.assertIsInstance(node.body[0].targets[0].elts[1].ctx, gast.Store)
+
+ def test_replace_tuple_context(self):
+ template = """
+ def test_fn(foo):
+ foo = 0
+ """
+
+ node = templates.replace(template, foo=parser.parse_expression('(a, b)'))[0]
+ self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store)
+ self.assertIsInstance(node.body[0].targets[0].elts[0].ctx, gast.Store)
+ self.assertIsInstance(node.body[0].targets[0].elts[1].ctx, gast.Store)
+
+ def test_replace_complex_context(self):
+ template = """
+ def test_fn(foo):
+ foo = 0
+ """
+
+ node = templates.replace(
+ template, foo=parser.parse_expression('bar(([a, b],)).baz'))[0]
+ self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store)
+ function_call_arg = node.body[0].targets[0].value.args[0]
+ self.assertIsInstance(function_call_arg.elts[0].ctx, gast.Load)
+ self.assertIsInstance(function_call_arg.elts[0].elts[0].ctx, gast.Load)
+ self.assertIsInstance(function_call_arg.elts[0].elts[1].ctx, gast.Load)
+
def test_replace_call_keyword(self):
template = """
def test_fn():
diff --git a/tensorflow/contrib/autograph/pyct/testing/BUILD b/tensorflow/python/autograph/pyct/testing/BUILD
index 29a92444bb..c244cbd747 100644
--- a/tensorflow/contrib/autograph/pyct/testing/BUILD
+++ b/tensorflow/python/autograph/pyct/testing/BUILD
@@ -22,8 +22,8 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
- "//tensorflow/contrib/autograph/pyct",
- "//tensorflow/contrib/autograph/utils",
+ "//tensorflow/python/autograph/pyct",
+ "//tensorflow/python/autograph/utils",
"@gast_archive//:gast",
],
)
@@ -41,8 +41,8 @@ py_test(
],
deps = [
":testing",
- "//tensorflow/contrib/autograph/pyct",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/pyct",
"@gast_archive//:gast",
],
)
diff --git a/tensorflow/contrib/autograph/pyct/testing/codegen.py b/tensorflow/python/autograph/pyct/testing/codegen.py
index 279e7c09dc..78b24390c3 100644
--- a/tensorflow/contrib/autograph/pyct/testing/codegen.py
+++ b/tensorflow/python/autograph/pyct/testing/codegen.py
@@ -24,7 +24,7 @@ import string
import gast
import numpy as np
-from tensorflow.contrib.autograph.pyct import templates
+from tensorflow.python.autograph.pyct import templates
class NodeSampler(object):
diff --git a/tensorflow/contrib/autograph/pyct/testing/codegen_test.py b/tensorflow/python/autograph/pyct/testing/codegen_test.py
index 255c3b2a2e..71665be039 100644
--- a/tensorflow/contrib/autograph/pyct/testing/codegen_test.py
+++ b/tensorflow/python/autograph/pyct/testing/codegen_test.py
@@ -20,8 +20,8 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.autograph.pyct import compiler
-from tensorflow.contrib.autograph.pyct.testing import codegen
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct.testing import codegen
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/pyct/transformer.py b/tensorflow/python/autograph/pyct/transformer.py
index 969ca12244..520f5038da 100644
--- a/tensorflow/contrib/autograph/pyct/transformer.py
+++ b/tensorflow/python/autograph/pyct/transformer.py
@@ -23,9 +23,9 @@ import sys
import gast
import six
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import compiler
-from tensorflow.contrib.autograph.pyct import pretty_printer
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import pretty_printer
class AutographParseError(SyntaxError):
diff --git a/tensorflow/contrib/autograph/pyct/transformer_test.py b/tensorflow/python/autograph/pyct/transformer_test.py
index a37e922a1d..23bf9a8e16 100644
--- a/tensorflow/contrib/autograph/pyct/transformer_test.py
+++ b/tensorflow/python/autograph/pyct/transformer_test.py
@@ -20,9 +20,9 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/utils/BUILD b/tensorflow/python/autograph/utils/BUILD
index 4504a5c7a3..22451d4f3f 100644
--- a/tensorflow/contrib/autograph/utils/BUILD
+++ b/tensorflow/python/autograph/utils/BUILD
@@ -32,10 +32,10 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
deps = [
- "//tensorflow/contrib/autograph/pyct",
"//tensorflow/python:dtypes",
"//tensorflow/python:list_ops",
"//tensorflow/python:script_ops",
+ "//tensorflow/python/autograph/pyct",
"//tensorflow/python/data/ops:dataset_ops",
"@six_archive//:six",
],
diff --git a/tensorflow/contrib/autograph/utils/__init__.py b/tensorflow/python/autograph/utils/__init__.py
index 38e0a0a8f0..c781958481 100644
--- a/tensorflow/contrib/autograph/utils/__init__.py
+++ b/tensorflow/python/autograph/utils/__init__.py
@@ -18,12 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.utils.context_managers import control_dependency_on_returns
-from tensorflow.contrib.autograph.utils.misc import alias_tensors
-from tensorflow.contrib.autograph.utils.multiple_dispatch import dynamic_is
-from tensorflow.contrib.autograph.utils.multiple_dispatch import dynamic_is_not
-from tensorflow.contrib.autograph.utils.multiple_dispatch import run_cond
-from tensorflow.contrib.autograph.utils.py_func import wrap_py_func
-from tensorflow.contrib.autograph.utils.tensor_list import dynamic_list_append
-from tensorflow.contrib.autograph.utils.testing import fake_tf
-from tensorflow.contrib.autograph.utils.type_check import is_tensor
+from tensorflow.python.autograph.utils.context_managers import control_dependency_on_returns
+from tensorflow.python.autograph.utils.misc import alias_tensors
+from tensorflow.python.autograph.utils.multiple_dispatch import run_cond
+from tensorflow.python.autograph.utils.py_func import wrap_py_func
+from tensorflow.python.autograph.utils.tensor_list import dynamic_list_append
+from tensorflow.python.autograph.utils.testing import fake_tf
+from tensorflow.python.autograph.utils.type_check import is_tensor
diff --git a/tensorflow/contrib/autograph/utils/context_managers.py b/tensorflow/python/autograph/utils/context_managers.py
index 3d150a9581..3d150a9581 100644
--- a/tensorflow/contrib/autograph/utils/context_managers.py
+++ b/tensorflow/python/autograph/utils/context_managers.py
diff --git a/tensorflow/contrib/autograph/utils/context_managers_test.py b/tensorflow/python/autograph/utils/context_managers_test.py
index 42e27724b9..7f0a15b076 100644
--- a/tensorflow/contrib/autograph/utils/context_managers_test.py
+++ b/tensorflow/python/autograph/utils/context_managers_test.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.utils import context_managers
+from tensorflow.python.autograph.utils import context_managers
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import tensor_array_ops
diff --git a/tensorflow/contrib/autograph/utils/misc.py b/tensorflow/python/autograph/utils/misc.py
index 1b06caf0bd..1b06caf0bd 100644
--- a/tensorflow/contrib/autograph/utils/misc.py
+++ b/tensorflow/python/autograph/utils/misc.py
diff --git a/tensorflow/contrib/autograph/utils/misc_test.py b/tensorflow/python/autograph/utils/misc_test.py
index 71e358c33e..8d2b0d6e13 100644
--- a/tensorflow/contrib/autograph/utils/misc_test.py
+++ b/tensorflow/python/autograph/utils/misc_test.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.utils.misc import alias_tensors
+from tensorflow.python.autograph.utils.misc import alias_tensors
from tensorflow.python.framework.constant_op import constant
from tensorflow.python.ops.variables import Variable
from tensorflow.python.platform import test
@@ -31,7 +31,7 @@ class MiscTest(test.TestCase):
new_a = alias_tensors(a)
self.assertFalse(new_a is a)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(1, sess.run(new_a))
def test_alias_tensors(self):
@@ -46,7 +46,7 @@ class MiscTest(test.TestCase):
self.assertTrue(new_v is v)
self.assertTrue(new_s is s)
self.assertTrue(new_l is l)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(1, sess.run(new_a))
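The self.test_session() -> self.cached_session() substitutions in this and the following test files all follow the same pattern: cached_session() returns a session that is reused for the duration of the test case instead of constructing a fresh one per with-block. A minimal sketch of the resulting test shape (class and tensor names are illustrative, not taken from this diff):

    from tensorflow.python.framework import constant_op
    from tensorflow.python.platform import test


    class ExampleTest(test.TestCase):

      def test_constant_value(self):
        x = constant_op.constant(3)
        # cached_session() replaces test_session(); the session is shared and
        # reused within the test, which keeps graph-mode tests fast.
        with self.cached_session() as sess:
          self.assertEqual(3, sess.run(x))


    if __name__ == '__main__':
      test.main()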
diff --git a/tensorflow/contrib/autograph/utils/multiple_dispatch.py b/tensorflow/python/autograph/utils/multiple_dispatch.py
index 70eef5676f..107c8f7a68 100644
--- a/tensorflow/contrib/autograph/utils/multiple_dispatch.py
+++ b/tensorflow/python/autograph/utils/multiple_dispatch.py
@@ -18,20 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.utils.type_check import is_tensor
+from tensorflow.python.autograph.utils.type_check import is_tensor
from tensorflow.python.ops import control_flow_ops
-def dynamic_is(left, right):
- # TODO(alexbw) if we're sure we should leave 'is' in place,
- # then change the semantics in converters/logical_expressions.py
- return left is right
-
-
-def dynamic_is_not(left, right):
- return left is not right
-
-
def run_cond(condition, true_fn, false_fn):
"""Type-dependent functional conditional.
diff --git a/tensorflow/contrib/autograph/utils/multiple_dispatch_test.py b/tensorflow/python/autograph/utils/multiple_dispatch_test.py
index f72f8e94a0..2a77c895ce 100644
--- a/tensorflow/contrib/autograph/utils/multiple_dispatch_test.py
+++ b/tensorflow/python/autograph/utils/multiple_dispatch_test.py
@@ -18,9 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-
-from tensorflow.contrib.autograph.utils import multiple_dispatch
+from tensorflow.python.autograph.utils import multiple_dispatch
from tensorflow.python.client.session import Session
from tensorflow.python.framework.constant_op import constant
from tensorflow.python.platform import test
@@ -28,33 +26,6 @@ from tensorflow.python.platform import test
class MultipleDispatchTest(test.TestCase):
- def test_dynamic_is_python(self):
- a = np.eye(3)
- also_a = a
- not_actually_a = np.eye(3)
- should_be_true1 = multiple_dispatch.dynamic_is(a, also_a)
- should_be_false1 = multiple_dispatch.dynamic_is_not(a, also_a)
- should_be_true2 = multiple_dispatch.dynamic_is_not(a, not_actually_a)
- should_be_false2 = multiple_dispatch.dynamic_is(a, not_actually_a)
- self.assertTrue(should_be_true1)
- self.assertTrue(should_be_true2)
- self.assertFalse(should_be_false1)
- self.assertFalse(should_be_false2)
-
- def test_dynamic_is_tf(self):
- with Session().as_default():
- a = constant([2.0])
- also_a = a
- not_actually_a = constant([2.0])
- should_be_true1 = multiple_dispatch.dynamic_is(a, also_a)
- should_be_false1 = multiple_dispatch.dynamic_is_not(a, also_a)
- should_be_true2 = multiple_dispatch.dynamic_is_not(a, not_actually_a)
- should_be_false2 = multiple_dispatch.dynamic_is(a, not_actually_a)
- self.assertTrue(should_be_true1)
- self.assertTrue(should_be_true2)
- self.assertFalse(should_be_false1)
- self.assertFalse(should_be_false2)
-
def test_run_cond_python(self):
true_fn = lambda: (2,)
false_fn = lambda: (3,)
diff --git a/tensorflow/contrib/autograph/utils/py_func.py b/tensorflow/python/autograph/utils/py_func.py
index 11ebfb2e49..11ebfb2e49 100644
--- a/tensorflow/contrib/autograph/utils/py_func.py
+++ b/tensorflow/python/autograph/utils/py_func.py
diff --git a/tensorflow/contrib/autograph/utils/py_func_test.py b/tensorflow/python/autograph/utils/py_func_test.py
index 2468263142..1c220d9492 100644
--- a/tensorflow/contrib/autograph/utils/py_func_test.py
+++ b/tensorflow/python/autograph/utils/py_func_test.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.utils import py_func
+from tensorflow.python.autograph.utils import py_func
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import test
@@ -31,7 +31,7 @@ class PyFuncTest(test.TestCase):
def test_fn(a, b, c):
return a + b + c
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result = py_func.wrap_py_func(test_fn, dtypes.int64,
(1, constant_op.constant(1), 1))
self.assertEqual(3, sess.run(result))
@@ -52,7 +52,7 @@ class PyFuncTest(test.TestCase):
def test_fn(a, b):
return a * b.foo
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result = py_func.wrap_py_func(test_fn, dtypes.int64, (7, TestClass()))
self.assertEqual(35, sess.run(result))
result = py_func.wrap_py_func(test_fn, dtypes.int64,
@@ -69,7 +69,7 @@ class PyFuncTest(test.TestCase):
def test_fn(a, b, c, d):
return a * b.foo + c * d.foo
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result = py_func.wrap_py_func(test_fn, dtypes.int64, (7, TestClass(5)), {
'c': 11,
'd': TestClass(13)
@@ -89,7 +89,7 @@ class PyFuncTest(test.TestCase):
def test_fn(_):
side_counter[0] += 1
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result = py_func.wrap_py_func(test_fn, None, (5,), use_dummy_return=True)
self.assertEqual(1, sess.run(result))
self.assertEqual([1], side_counter)
diff --git a/tensorflow/contrib/autograph/utils/tensor_list.py b/tensorflow/python/autograph/utils/tensor_list.py
index 2556f41289..2556f41289 100644
--- a/tensorflow/contrib/autograph/utils/tensor_list.py
+++ b/tensorflow/python/autograph/utils/tensor_list.py
diff --git a/tensorflow/contrib/autograph/utils/tensor_list_test.py b/tensorflow/python/autograph/utils/tensor_list_test.py
index d58489eb68..697c166eb1 100644
--- a/tensorflow/contrib/autograph/utils/tensor_list_test.py
+++ b/tensorflow/python/autograph/utils/tensor_list_test.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.utils import tensor_list as tl
+from tensorflow.python.autograph.utils import tensor_list as tl
from tensorflow.python.client.session import Session
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
@@ -42,18 +42,18 @@ class TensorListTest(test.TestCase):
l = list_ops.empty_tensor_list(self._shape(()), dtypes.int32)
l = tl.dynamic_list_append(l, 1)
s = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual(sess.run(s), [1])
l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True)
l = tl.dynamic_list_append(l, 1)
s = l.stack()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual(sess.run(s), [1])
l = tl.TensorList(self._shape(()), dtypes.int32)
l = tl.dynamic_list_append(l, 1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual(sess.run(l[0]), 1)
def test_list_append_python(self):
@@ -107,7 +107,7 @@ class TensorListTest(test.TestCase):
l0 = l[0]
l[0] = b
l1 = l[0]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
l0, l1, a, b = sess.run([l0, l1, a, b])
self.assertEqual(l0, a)
self.assertEqual(l1, b)
diff --git a/tensorflow/contrib/autograph/utils/tensors.py b/tensorflow/python/autograph/utils/tensors.py
index fa5db81a71..fa5db81a71 100644
--- a/tensorflow/contrib/autograph/utils/tensors.py
+++ b/tensorflow/python/autograph/utils/tensors.py
diff --git a/tensorflow/contrib/autograph/utils/tensors_test.py b/tensorflow/python/autograph/utils/tensors_test.py
index e855e0b6cb..1e7cfec9e1 100644
--- a/tensorflow/contrib/autograph/utils/tensors_test.py
+++ b/tensorflow/python/autograph/utils/tensors_test.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.utils import tensors
+from tensorflow.python.autograph.utils import tensors
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import list_ops
diff --git a/tensorflow/contrib/autograph/utils/testing.py b/tensorflow/python/autograph/utils/testing.py
index cb4785d0dc..cb4785d0dc 100644
--- a/tensorflow/contrib/autograph/utils/testing.py
+++ b/tensorflow/python/autograph/utils/testing.py
diff --git a/tensorflow/contrib/autograph/utils/type_check.py b/tensorflow/python/autograph/utils/type_check.py
index 8748abc47b..8748abc47b 100644
--- a/tensorflow/contrib/autograph/utils/type_check.py
+++ b/tensorflow/python/autograph/utils/type_check.py
diff --git a/tensorflow/contrib/autograph/utils/type_check_test.py b/tensorflow/python/autograph/utils/type_check_test.py
index 3b67b7194c..b3d1304e16 100644
--- a/tensorflow/contrib/autograph/utils/type_check_test.py
+++ b/tensorflow/python/autograph/utils/type_check_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy
-from tensorflow.contrib.autograph.utils import type_check
+from tensorflow.python.autograph.utils import type_check
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index f87a96e547..4afc6399d5 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -1762,7 +1762,7 @@ class SessionTest(test_util.TensorFlowTestCase):
with self.assertRaises(ValueError):
session.register_session_run_conversion_functions(SquaredTensor, fetch_fn,
feed_fn1, feed_fn2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
np1 = np.array([1.0, 1.5, 2.0, 2.5])
np2 = np.array([3.0, 3.5, 4.0, 4.5])
squared_tensor = SquaredTensor(np2)
@@ -1922,7 +1922,7 @@ class SessionTest(test_util.TensorFlowTestCase):
pass
def testAutoConvertAndCheckData(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a = array_ops.placeholder(dtype=dtypes.string)
with self.assertRaisesRegexp(
TypeError, 'Type of feed value 1 with type <(\w+) \'int\'> is not'):
diff --git a/tensorflow/python/client/timeline_test.py b/tensorflow/python/client/timeline_test.py
index c046e9cfd4..03effde098 100644
--- a/tensorflow/python/client/timeline_test.py
+++ b/tensorflow/python/client/timeline_test.py
@@ -161,7 +161,7 @@ class TimelineTest(test.TestCase):
cpu_max = maximums[
'cuda_host_bfc'] if 'cuda_host_bfc' in maximums else maximums[cpuname]
# At least num1 + num2, both float32s (4 bytes each)
- self.assertGreater(cpu_max.num_bytes, 8)
+ self.assertGreaterEqual(cpu_max.num_bytes, 8)
self.assertGreater(cpu_max.timestamp, 0)
self.assertTrue('num1' in cpu_max.tensors or 'num1/read' in cpu_max.tensors)
self.assertTrue('num2' in cpu_max.tensors or 'num2/read' in cpu_max.tensors)
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 7a3fc27592..8a100fe975 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -26,7 +26,7 @@ import datetime
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 7)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 14)
@tf_export("compat.forward_compatible")
diff --git a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
index 89de55dd4f..c48708a2b9 100644
--- a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
@@ -82,7 +82,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
self.assertEqual([[dim0] + list(c.shape[1:]) for c in components],
[t.shape.as_list() for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -111,7 +111,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = (dataset_ops.Dataset.range(10).batch(0).make_one_shot_iterator())
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
@@ -131,7 +131,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(2):
actual = sess.run(get_next)
@@ -158,7 +158,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(2):
actual = sess.run(get_next)
@@ -188,7 +188,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
actual = sess.run(get_next)
expected = sparse_tensor.SparseTensorValue(
@@ -214,7 +214,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
.make_initializable_iterator())
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -262,7 +262,7 @@ class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -307,7 +307,7 @@ class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase):
batch_size=4, padded_shapes=[5]).make_one_shot_iterator())
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.DataLossError):
sess.run(get_next)
@@ -318,7 +318,7 @@ class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase):
batch_size=4, padded_shapes=[-1]).make_one_shot_iterator())
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result = sess.run(get_next)
self.assertAllEqual([[], [], [], []], result)
with self.assertRaises(errors.OutOfRangeError):
@@ -342,7 +342,7 @@ class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Test with random sequence lengths, and max padding.
random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32)
sess.run(
@@ -381,7 +381,7 @@ class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase):
(tensor_shape.TensorShape([None]), tensor_shape.TensorShape([None])))
padded_dataset = dataset.padded_batch(
2, padded_shapes=([None], [None]), padding_values=('', 0))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
next_element = padded_dataset.make_one_shot_iterator().get_next()
sess.run(next_element)
diff --git a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py
index 4f7fd3566e..d5f5b2fe05 100644
--- a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py
@@ -68,7 +68,7 @@ class FileCacheDatasetTest(test.TestCase):
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# First run without caching to collect the "ground truth".
sess.run(init_fifo_op)
elements = []
@@ -132,7 +132,7 @@ class FileCacheDatasetTest(test.TestCase):
get_next1 = iterator1.get_next()
get_next2 = iterator2.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_cache_op1, feed_dict={filename_placeholder: self.cache_prefix})
sess.run(get_next1) # this should succeed
@@ -162,7 +162,7 @@ class FileCacheDatasetTest(test.TestCase):
get_next1 = iterator1.get_next()
get_next2 = iterator2.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_cache_op1, feed_dict={filename_placeholder: self.cache_prefix})
elements = []
@@ -217,7 +217,7 @@ class MemoryCacheDatasetTest(test.TestCase):
uncached_iterator = uncached_dataset.make_initializable_iterator()
uncached_next = uncached_iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(repeat_count.initializer)
sess.run(cached_iterator.initializer)
@@ -261,7 +261,7 @@ class MemoryCacheDatasetTest(test.TestCase):
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Initialize with an empty upstream and a missing cache file (should
# throw errors.OutOfRangeError immediately).
sess.run(init_cache_op, feed_dict={count_placeholder: 0})
@@ -278,7 +278,7 @@ class MemoryCacheDatasetTest(test.TestCase):
i1 = d1.make_initializable_iterator()
i2 = d2.make_initializable_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(i1.initializer)
self.assertEqual(1, sess.run(i1.get_next()))
@@ -304,7 +304,7 @@ class MemoryCacheDatasetTest(test.TestCase):
expected_values = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i, expected in enumerate(expected_values):
self.assertEqual(expected, sess.run(n),
"Unexpected value at index %s" % i)
diff --git a/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py b/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py
index 159218c99b..5dfb84f28e 100644
--- a/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py
@@ -49,7 +49,7 @@ class ConcatenateDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(9):
result = sess.run(get_next)
@@ -83,7 +83,7 @@ class ConcatenateDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(9):
result = sess.run(get_next)
diff --git a/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py b/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py
index ea5b41e5d8..e43564a2eb 100644
--- a/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py
+++ b/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py
@@ -50,7 +50,7 @@ class DatasetConstructorTest(test.TestCase):
self.assertEqual([c.shape for c in components],
[t.shape for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
results = sess.run(get_next)
for component, result_component in zip(components, results):
@@ -84,7 +84,7 @@ class DatasetConstructorTest(test.TestCase):
[tensor_shape.TensorShape(c.dense_shape) for c in components],
[shape for shape in iterator.output_shapes])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
results = sess.run(get_next)
for component, result_component in zip(components, results):
@@ -115,7 +115,7 @@ class DatasetConstructorTest(test.TestCase):
if sparse_tensor.is_sparse(c) else c.shape for c in components
], [shape for shape in iterator.output_shapes])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
results = sess.run(get_next)
for component, result_component in zip(components, results):
@@ -142,7 +142,7 @@ class DatasetConstructorTest(test.TestCase):
self.assertEqual([c.shape[1:] for c in components],
[t.shape for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(4):
results = sess.run(get_next)
@@ -172,7 +172,7 @@ class DatasetConstructorTest(test.TestCase):
[tensor_shape.TensorShape(c.dense_shape[1:]) for c in components],
[shape for shape in iterator.output_shapes])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
expected = [
(sparse_tensor.SparseTensorValue(
@@ -232,7 +232,7 @@ class DatasetConstructorTest(test.TestCase):
if sparse_tensor.is_sparse(c) else c.shape[1:] for c in components
], [shape for shape in iterator.output_shapes])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
expected = [
(sparse_tensor.SparseTensorValue(
@@ -283,7 +283,7 @@ class DatasetConstructorTest(test.TestCase):
self.assertEqual((), iterator.output_shapes["foo"])
self.assertEqual((1,), iterator.output_shapes["bar"])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(3):
results = sess.run(get_next)
@@ -300,7 +300,7 @@ class DatasetConstructorTest(test.TestCase):
init_op = iterator.initializer
get_next = sparse_tensor.SparseTensor(*iterator.get_next())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
slices = [[1., 2., 3.], [1.], [1.], [1., 2.], [], [1., 2.], [], [], []]
# Test with sparse tensor in the appropriate order.
diff --git a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py b/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
index fb55ae1400..cd0c1ddf1e 100644
--- a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
+++ b/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
@@ -44,7 +44,7 @@ class DatasetConstructorTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(2): # Run twice to test reinitialization.
sess.run(init_op)
for _ in range(num_repeats):
@@ -61,7 +61,7 @@ class DatasetConstructorTest(test.TestCase):
.make_one_shot_iterator())
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(num_repeats):
for elem in elem_sequence:
self.assertAllEqual(elem, sess.run(get_next))
@@ -131,7 +131,7 @@ class DatasetConstructorTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for _ in range(num_inner_repeats * num_outer_repeats):
for elem in input_list:
@@ -190,7 +190,7 @@ class DatasetConstructorTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for elem in [0, 1]:
for _ in range(num_parallel_iterators):
@@ -213,7 +213,7 @@ class DatasetConstructorTest(test.TestCase):
self.assertEqual(dtype, get_next.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for expected in [[1], [2], [3]]:
next_val = sess.run(get_next)
@@ -234,7 +234,7 @@ class DatasetConstructorTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for expected in [b"foo", b"bar", b"baz"]:
next_val = sess.run(get_next)
@@ -255,7 +255,7 @@ class DatasetConstructorTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
self.assertAllEqual([1, 2, 3], sess.run(get_next))
self.assertAllEqual([4, 5, 6], sess.run(get_next))
@@ -278,7 +278,7 @@ class DatasetConstructorTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
self.assertAllEqual([1, 2, 3], sess.run(get_next))
self.assertAllEqual([4, 5, 6], sess.run(get_next))
@@ -302,7 +302,7 @@ class DatasetConstructorTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
self.assertEqual((1, 2), sess.run(get_next))
self.assertEqual((3, 4), sess.run(get_next))
@@ -327,7 +327,7 @@ class DatasetConstructorTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
self.assertAllEqual(1, sess.run(get_next))
self.assertAllEqual([2, 3], sess.run(get_next))
@@ -347,7 +347,7 @@ class DatasetConstructorTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
self.assertAllEqual(0, sess.run(get_next))
self.assertAllEqual(1, sess.run(get_next))
@@ -405,7 +405,7 @@ class DatasetConstructorTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
expected = [1, 2, 2, 3, 3, 3, 4, 4, 4, 4]
for x in expected:
@@ -434,7 +434,7 @@ class DatasetConstructorTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
expected = [(0, b"Hi!"),
(0, b"Hi!"), (1, b"Hi!"),
@@ -468,7 +468,7 @@ class DatasetConstructorTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
self.assertAllEqual(37, sess.run(get_next))
self.assertAllEqual(37, sess.run(get_next))
diff --git a/tensorflow/python/data/kernel_tests/dataset_ops_test.py b/tensorflow/python/data/kernel_tests/dataset_ops_test.py
index 2c4c11e132..239aa85175 100644
--- a/tensorflow/python/data/kernel_tests/dataset_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/dataset_ops_test.py
@@ -27,7 +27,7 @@ class DatasetOpsTest(test.TestCase):
def testAsSerializedGraph(self):
dataset = dataset_ops.Dataset.range(10)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
graph = graph_pb2.GraphDef().FromString(
sess.run(dataset._as_serialized_graph()))
self.assertTrue(any([node.op != "RangeDataset" for node in graph.node]))
diff --git a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
index 4f2216f0a3..19944d389f 100644
--- a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
@@ -59,7 +59,7 @@ class FilterDatasetTest(test.TestCase):
self.assertEqual([c.shape[1:] for c in components],
[t.shape for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Test that we can dynamically feed a different modulus value for each
# iterator.
def do_test(count_val, modulus_val):
@@ -84,7 +84,7 @@ class FilterDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(0, sess.run(get_next))
self.assertEqual(1, sess.run(get_next))
self.assertEqual(3, sess.run(get_next))
@@ -98,7 +98,7 @@ class FilterDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(10):
if (i ** 2) % 2 == 0:
@@ -123,7 +123,7 @@ class FilterDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
self.assertAllEqual(input_data[0], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
@@ -151,7 +151,7 @@ class FilterDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(5):
actual = sess.run(get_next)
@@ -169,7 +169,7 @@ class FilterDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(10):
self.assertEqual((i, True), sess.run(get_next))
@@ -181,7 +181,7 @@ class FilterDatasetTest(test.TestCase):
lambda x: math_ops.equal(x % 2, 0))
iterators = [dataset.make_one_shot_iterator() for _ in range(10)]
next_elements = [iterator.get_next() for iterator in iterators]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual([0 for _ in range(10)], sess.run(next_elements))
diff --git a/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py
index 350234a839..1123cbff62 100644
--- a/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py
@@ -43,7 +43,7 @@ class FlatMapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in repeats:
for _ in range(i):
@@ -62,7 +62,7 @@ class FlatMapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for row in repeats:
for i in row:
@@ -113,7 +113,7 @@ class FlatMapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(10):
for _ in range(i ** 2):
@@ -137,7 +137,7 @@ class FlatMapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(10):
for j in range(2):
diff --git a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
index 579096f880..c4b338a58f 100644
--- a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
@@ -44,7 +44,7 @@ class ListFilesDatasetOpTest(test.TestCase):
def testEmptyDirectory(self):
dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*'))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
itr = dataset.make_one_shot_iterator()
next_element = itr.get_next()
with self.assertRaises(errors.OutOfRangeError):
@@ -55,7 +55,7 @@ class ListFilesDatasetOpTest(test.TestCase):
self._touchTempFiles(filenames)
dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*'))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
itr = dataset.make_one_shot_iterator()
next_element = itr.get_next()
@@ -75,7 +75,7 @@ class ListFilesDatasetOpTest(test.TestCase):
dataset = dataset_ops.Dataset.list_files(
path.join(self.tmp_dir, '*'), shuffle=False)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
itr = dataset.make_one_shot_iterator()
next_element = itr.get_next()
@@ -91,7 +91,7 @@ class ListFilesDatasetOpTest(test.TestCase):
dataset = dataset_ops.Dataset.list_files(
path.join(self.tmp_dir, '*'), shuffle=True, seed=37)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
itr = dataset.make_initializable_iterator()
next_element = itr.get_next()
@@ -121,7 +121,7 @@ class ListFilesDatasetOpTest(test.TestCase):
filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
dataset = dataset_ops.Dataset.list_files(filename_placeholder)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
itr = dataset.make_initializable_iterator()
with self.assertRaisesRegexp(
errors.InvalidArgumentError, 'No files matched pattern: '):
@@ -136,7 +136,7 @@ class ListFilesDatasetOpTest(test.TestCase):
filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
dataset = dataset_ops.Dataset.list_files(filename_placeholder)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
itr = dataset.make_initializable_iterator()
next_element = itr.get_next()
sess.run(
@@ -162,7 +162,7 @@ class ListFilesDatasetOpTest(test.TestCase):
filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
dataset = dataset_ops.Dataset.list_files(filename_placeholder)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
itr = dataset.make_initializable_iterator()
next_element = itr.get_next()
sess.run(
@@ -187,7 +187,7 @@ class ListFilesDatasetOpTest(test.TestCase):
filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
dataset = dataset_ops.Dataset.list_files(filename_placeholder)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
itr = dataset.make_initializable_iterator()
next_element = itr.get_next()
sess.run(
@@ -221,7 +221,7 @@ class ListFilesDatasetOpTest(test.TestCase):
# more meaningful.
dataset = dataset_ops.Dataset.list_files(
path.join(self.tmp_dir, '*'), shuffle=False).repeat(2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
itr = dataset.make_one_shot_iterator()
next_element = itr.get_next()
diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
index fde785be6e..7685d8dbdc 100644
--- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
@@ -72,7 +72,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
self.assertEqual([c.shape[1:] for c in components],
[t.shape for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Test single-threaded access to the iterator.
sess.run(init_op, feed_dict={count: 14})
for _ in range(14):
@@ -138,7 +138,8 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
self.assertEqual([c.shape[1:] for c in components],
[t.shape for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
+
def do_test(num_parallel_calls_val, output_buffer_size_val):
# Test single-threaded access to the iterator.
sess.run(init_op, feed_dict={
@@ -203,7 +204,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for _ in range(3):
sess.run(get_next)
@@ -218,7 +219,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for _ in range(3):
sess.run(get_next)
@@ -233,7 +234,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for _ in range(3):
sess.run(get_next)
@@ -254,7 +255,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for _ in range(3):
sess.run(get_next)
@@ -285,7 +286,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(table.init)
sess.run(init_op)
sess.run(get_next)
@@ -303,7 +304,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(enqueue_op)
sess.run(close_op)
sess.run(init_op)
@@ -328,7 +329,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(enqueue_op)
sess.run(close_op)
sess.run(init_op)
@@ -347,7 +348,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(counter_var.initializer)
sess.run(init_op)
for i in range(10):
@@ -367,7 +368,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaises(errors.NotFoundError):
sess.run(get_next)
@@ -379,7 +380,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
random_values = []
with self.assertRaises(errors.OutOfRangeError):
@@ -404,7 +405,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(10):
self.assertEqual(i * 2 + i ** 2, sess.run(get_next))
@@ -436,7 +437,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
next_namedtuple = dataset_namedtuple.make_one_shot_iterator().get_next()
# make sure both datasets contain the same data
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(count):
tuple_, namedtuple_ = sess.run([next_tuple, next_namedtuple])
self.assertEqual(tuple_, namedtuple_)
@@ -454,7 +455,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
self.assertAllEqual(row ** 2, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
@@ -485,7 +486,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Simple test that prefetch yields the expected values in the
# expected order.
for buffer_size in [1, 10, 100, 1000]:
@@ -523,7 +524,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(10):
self.assertEqual((i, 37.0), sess.run(get_next))
@@ -544,7 +545,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(10):
self.assertEqual((i, 37.0), sess.run(get_next))
@@ -570,7 +571,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(10):
actual = sess.run(get_next)
@@ -597,7 +598,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(10):
actual = sess.run(get_next)
@@ -621,7 +622,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(100):
self.assertEqual(i, sess.run(get_next))
@@ -635,7 +636,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(10):
self.assertEqual((i, b"hello", 10), sess.run(get_next))
@@ -702,7 +703,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
dataset = dataset.map(broken_function)
iterator = dataset.make_initializable_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesRegexp(errors.InvalidArgumentError, "BrokenConst"):
sess.run(iterator.initializer)
diff --git a/tensorflow/python/data/kernel_tests/optional_ops_test.py b/tensorflow/python/data/kernel_tests/optional_ops_test.py
index a32527af8d..c344513e71 100644
--- a/tensorflow/python/data/kernel_tests/optional_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/optional_ops_test.py
@@ -158,7 +158,7 @@ class OptionalTest(test.TestCase):
self.assertEqual(ds.output_classes, next_elem.output_classes)
elem_has_value_t = next_elem.has_value()
elem_value_t = next_elem.get_value()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Before initializing the iterator, evaluating the optional fails with
# a FailedPreconditionError.
with self.assertRaises(errors.FailedPreconditionError):
diff --git a/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py
index 63a0830272..cc97bac609 100644
--- a/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py
@@ -36,7 +36,7 @@ class PrefetchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, feed_dict={buffer_size_t: buffer_size})
for m in range(10):
self.assertEqual(m, sess.run(get_next))
@@ -51,7 +51,7 @@ class PrefetchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
with self.assertRaisesRegexp(errors.InvalidArgumentError, "buffer_size"):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, feed_dict={buffer_size_t: buffer_size})
diff --git a/tensorflow/python/data/kernel_tests/range_dataset_op_test.py b/tensorflow/python/data/kernel_tests/range_dataset_op_test.py
index ad87f31b01..51e90785e7 100644
--- a/tensorflow/python/data/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/range_dataset_op_test.py
@@ -49,7 +49,7 @@ class RangeDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, feed_dict={stop: 5})
for i in range(5):
self.assertEqual(i, sess.run(get_next))
@@ -64,7 +64,7 @@ class RangeDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, feed_dict={start: 2, stop: 5})
for i in range(2, 5):
self.assertEqual(i, sess.run(get_next))
@@ -80,7 +80,7 @@ class RangeDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, feed_dict={start: 2, stop: 10, step: 2})
for i in range(2, 10, 2):
self.assertEqual(i, sess.run(get_next))
@@ -95,7 +95,7 @@ class RangeDatasetTest(test.TestCase):
step).make_initializable_iterator()
init_op = iterator.initializer
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(init_op, feed_dict={start: 2, stop: 10, step: 0})
@@ -108,7 +108,7 @@ class RangeDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, feed_dict={start: 2, stop: 10, step: -1})
# This for loop is a no-op but will ensure that the implementation is
# consistent with range if it ever changes.
@@ -125,7 +125,7 @@ class RangeDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, feed_dict={start: 10, stop: 2})
# This for loop is a no-op but will ensure that the implementation is
# consistent with range if it ever changes.
@@ -143,7 +143,7 @@ class RangeDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, feed_dict={start: 10, stop: 2, step: 2})
# This for loop is a no-op but will ensure that the implementation is
# consistent with range if it ever changes.
@@ -161,7 +161,7 @@ class RangeDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, feed_dict={start: 10, stop: 2, step: -1})
for i in range(10, 2, -1):
self.assertEqual(i, sess.run(get_next))
diff --git a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
index 431362aa9a..aa3636364d 100644
--- a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
@@ -100,7 +100,7 @@ class TextLineDatasetTest(test.TestCase):
init_batch_op = iterator.make_initializer(batch_dataset)
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Basic test: read from file 0.
sess.run(
init_op, feed_dict={filenames: [test_filenames[0]],
@@ -163,7 +163,7 @@ class TextLineDatasetTest(test.TestCase):
repeat_dataset = readers.TextLineDataset(test_filenames, buffer_size=10)
iterator = repeat_dataset.make_one_shot_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for j in range(2):
for i in range(5):
self.assertEqual(self._lineText(j, i), sess.run(iterator.get_next()))
@@ -240,7 +240,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
init_batch_op = iterator.make_initializer(batch_dataset)
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Basic test: read from file 0.
sess.run(
init_op, feed_dict={filenames: [test_filenames[0]],
@@ -302,7 +302,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
buffer_size=10)
iterator = dataset.make_one_shot_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for j in range(self._num_files):
for i in range(self._num_records):
self.assertEqual(self._record(j, i), sess.run(iterator.get_next()))
@@ -319,7 +319,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
buffer_size=10)
iterator = dataset.make_one_shot_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
r"Excluding the header \(5 bytes\) and footer \(2 bytes\), input "
@@ -661,7 +661,7 @@ class TFRecordDatasetTest(test.TestCase):
return filenames
def testReadOneEpoch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Basic test: read from file 0.
sess.run(
self.init_op,
@@ -698,7 +698,7 @@ class TFRecordDatasetTest(test.TestCase):
sess.run(self.get_next)
def testReadTenEpochs(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
self.init_op,
feed_dict={self.filenames: self.test_filenames,
@@ -711,7 +711,7 @@ class TFRecordDatasetTest(test.TestCase):
sess.run(self.get_next)
def testReadTenEpochsOfBatches(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
self.init_batch_op,
feed_dict={
@@ -738,7 +738,7 @@ class TFRecordDatasetTest(test.TestCase):
f.write(cdata)
zlib_files.append(zfn)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
self.init_op,
feed_dict={self.filenames: zlib_files,
@@ -758,7 +758,7 @@ class TFRecordDatasetTest(test.TestCase):
gzf.write(f.read())
gzip_files.append(gzfn)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
self.init_op,
feed_dict={self.filenames: gzip_files,
@@ -774,7 +774,7 @@ class TFRecordDatasetTest(test.TestCase):
d = readers.TFRecordDataset(self.test_filenames, buffer_size=one_mebibyte)
iterator = d.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for j in range(self._num_files):
for i in range(self._num_records):
self.assertAllEqual(self._record(j, i), sess.run(next_element))
@@ -786,7 +786,7 @@ class TFRecordDatasetTest(test.TestCase):
d = readers.TFRecordDataset(files)
iterator = d.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for j in range(self._num_files):
for i in range(self._num_records):
self.assertAllEqual(self._record(j, i), sess.run(next_element))
@@ -801,7 +801,7 @@ class TFRecordDatasetTest(test.TestCase):
next_element = iterator.get_next()
expected = []
actual = []
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(10):
for j in range(self._num_files):
for i in range(self._num_records):
diff --git a/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py b/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py
index 1d27b036eb..37e2333560 100644
--- a/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py
@@ -44,7 +44,7 @@ class SequenceDatasetTest(test.TestCase):
self.assertEqual([c.shape for c in components],
[t.shape for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Test a finite repetition.
sess.run(init_op, feed_dict={count_placeholder: 3})
for _ in range(3):
@@ -90,7 +90,7 @@ class SequenceDatasetTest(test.TestCase):
self.assertEqual([c.shape[1:] for c in components],
[t.shape for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Take fewer than input size
sess.run(init_op, feed_dict={count_placeholder: 4})
for i in range(4):
@@ -136,7 +136,7 @@ class SequenceDatasetTest(test.TestCase):
self.assertEqual([c.shape[1:] for c in components],
[t.shape for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Skip fewer than input size, we should skip
# the first 4 elements and then read the rest.
sess.run(init_op, feed_dict={count_placeholder: 4})
@@ -183,7 +183,7 @@ class SequenceDatasetTest(test.TestCase):
self.assertEqual([c.shape for c in components],
[t.shape for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, feed_dict={inner_count: 7, outer_count: 14})
for _ in range(7 * 14):
results = sess.run(get_next)
@@ -199,7 +199,7 @@ class SequenceDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
diff --git a/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py
index cefe872d0f..137f6341ce 100644
--- a/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py
@@ -28,7 +28,7 @@ class ShardDatasetOpTest(test.TestCase):
dataset = dataset_ops.Dataset.range(10).shard(5, 2)
iterator = dataset.make_one_shot_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(2, sess.run(iterator.get_next()))
self.assertEqual(7, sess.run(iterator.get_next()))
with self.assertRaises(errors.OutOfRangeError):
@@ -40,7 +40,7 @@ class ShardDatasetOpTest(test.TestCase):
dataset = dataset_ops.Dataset.zip((dataset_a, dataset_b)).shard(5, 2)
iterator = dataset.make_one_shot_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual((2, 8), sess.run(iterator.get_next()))
self.assertEqual((7, 3), sess.run(iterator.get_next()))
with self.assertRaises(errors.OutOfRangeError):
@@ -50,7 +50,7 @@ class ShardDatasetOpTest(test.TestCase):
dataset = dataset_ops.Dataset.range(10).shard(5, 0)
iterator = dataset.make_one_shot_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(0, sess.run(iterator.get_next()))
self.assertEqual(5, sess.run(iterator.get_next()))
with self.assertRaises(errors.OutOfRangeError):
@@ -76,14 +76,14 @@ class ShardDatasetOpTest(test.TestCase):
dataset = dataset_ops.Dataset.range(1).shard(5, 2)
iterator = dataset.make_one_shot_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.OutOfRangeError):
sess.run(iterator.get_next())
def testLargerWorkerPool(self):
dataset = dataset_ops.Dataset.range(10).shard(7, 5)
iterator = dataset.make_one_shot_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(5, sess.run(iterator.get_next()))
with self.assertRaises(errors.OutOfRangeError):
sess.run(iterator.get_next())
@@ -91,7 +91,7 @@ class ShardDatasetOpTest(test.TestCase):
def testIndexEqualsNumShards(self):
dataset = dataset_ops.Dataset.range(10).shard(5, 4)
iterator = dataset.make_one_shot_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(4, sess.run(iterator.get_next()))
self.assertEqual(9, sess.run(iterator.get_next()))
with self.assertRaises(errors.OutOfRangeError):
@@ -100,7 +100,7 @@ class ShardDatasetOpTest(test.TestCase):
def testIndexEqualsNumShards2(self):
dataset = dataset_ops.Dataset.range(10).shard(4, 3)
iterator = dataset.make_one_shot_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(3, sess.run(iterator.get_next()))
self.assertEqual(7, sess.run(iterator.get_next()))
with self.assertRaises(errors.OutOfRangeError):
diff --git a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
index 5fcc48831f..f294840706 100644
--- a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
@@ -60,7 +60,7 @@ class ShuffleDatasetTest(test.TestCase):
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# First run without shuffling to collect the "ground truth".
sess.run(init_fifo_op)
unshuffled_elements = []
@@ -140,7 +140,7 @@ class ShuffleDatasetTest(test.TestCase):
get_next = iterator.get_next()
elems = []
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(10):
elems.append(sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
@@ -152,7 +152,7 @@ class ShuffleDatasetTest(test.TestCase):
.make_initializable_iterator())
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer, feed_dict={seed_placeholder: 0})
for elem in elems:
self.assertEqual(elem, sess.run(get_next))
@@ -166,7 +166,7 @@ class ShuffleDatasetTest(test.TestCase):
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
counts = collections.defaultdict(lambda: 0)
for _ in range(10):
for _ in range(5):
@@ -183,7 +183,7 @@ class ShuffleDatasetTest(test.TestCase):
.make_one_shot_iterator())
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
initial_permutation = sess.run(next_element)
self.assertAllEqual(initial_permutation, sess.run(next_element))
self.assertAllEqual(initial_permutation, sess.run(next_element))
@@ -198,7 +198,7 @@ class ShuffleDatasetTest(test.TestCase):
.make_one_shot_iterator())
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
initial_permutation = list(sess.run(next_element))
for _ in range(2):
next_permutation = list(sess.run(next_element))
diff --git a/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py b/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py
index 55933118b9..3106effbd3 100644
--- a/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py
@@ -45,7 +45,7 @@ class ZipDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
equal_length_components = [
np.tile(np.array([[1], [2], [3], [4]]), 20),
np.tile(np.array([[12], [13], [14], [15]]), 22),
@@ -93,7 +93,7 @@ class ZipDatasetTest(test.TestCase):
self.assertEqual([22], get_next[1][0].shape)
self.assertEqual([], get_next[1][1].shape)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
equal_length_components = [
np.tile(np.array([[1], [2], [3], [4]]), 20),
np.tile(np.array([[12], [13], [14], [15]]), 22),
diff --git a/tensorflow/python/data/util/convert_test.py b/tensorflow/python/data/util/convert_test.py
index 6a67093e48..89c3afb296 100644
--- a/tensorflow/python/data/util/convert_test.py
+++ b/tensorflow/python/data/util/convert_test.py
@@ -30,28 +30,28 @@ class ConvertTest(test.TestCase):
def testInteger(self):
resp = convert.optional_param_to_tensor("foo", 3)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(3, sess.run(resp))
def testIntegerDefault(self):
resp = convert.optional_param_to_tensor("foo", None)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(0, sess.run(resp))
def testStringDefault(self):
resp = convert.optional_param_to_tensor("bar", None, "default",
dtypes.string)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(compat.as_bytes("default"), sess.run(resp))
def testString(self):
resp = convert.optional_param_to_tensor("bar", "value", "default",
dtypes.string)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(compat.as_bytes("value"), sess.run(resp))
def testPartialShapeToTensorKnownDimension(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor(
tensor_shape.TensorShape([1]))))
self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor((1,))))
@@ -60,7 +60,7 @@ class ConvertTest(test.TestCase):
constant_op.constant([1], dtype=dtypes.int64))))
def testPartialShapeToTensorUnknownDimension(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor(
tensor_shape.TensorShape([None]))))
self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor(
@@ -84,7 +84,7 @@ class ConvertTest(test.TestCase):
convert.partial_shape_to_tensor(constant_op.constant([1., 1.]))
def testPartialShapeToTensorMultipleDimensions(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor(
tensor_shape.TensorShape([3, 6]))))
self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor(
@@ -113,7 +113,7 @@ class ConvertTest(test.TestCase):
constant_op.constant([-1, -1], dtype=dtypes.int64))))
def testPartialShapeToTensorScalar(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor(
tensor_shape.TensorShape([]))))
self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor(())))
diff --git a/tensorflow/python/data/util/nest.py b/tensorflow/python/data/util/nest.py
index 3a5d1f0adf..e5abc654da 100644
--- a/tensorflow/python/data/util/nest.py
+++ b/tensorflow/python/data/util/nest.py
@@ -99,7 +99,6 @@ def _yield_value(iterable):
# See the swig file (../../util/util.i) for documentation.
is_sequence = _pywrap_tensorflow.IsSequenceForData
-
# See the swig file (../../util/util.i) for documentation.
flatten = _pywrap_tensorflow.FlattenForData
diff --git a/tensorflow/python/data/util/sparse_test.py b/tensorflow/python/data/util/sparse_test.py
index d49b3ff34b..056b32480f 100644
--- a/tensorflow/python/data/util/sparse_test.py
+++ b/tensorflow/python/data/util/sparse_test.py
@@ -291,7 +291,7 @@ class SparseTest(test.TestCase):
self.assertEqual(a, b)
return
self.assertTrue(isinstance(b, sparse_tensor.SparseTensor))
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(a.eval().indices, b.eval().indices)
self.assertAllEqual(a.eval().values, b.eval().values)
self.assertAllEqual(a.eval().dense_shape, b.eval().dense_shape)
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index 85da1baaf0..c1bc27d443 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -345,6 +345,7 @@ py_test(
deps = [
":backprop",
":context",
+ ":core",
":test",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py
index 3bdaf0b214..3fe79ef244 100644
--- a/tensorflow/python/eager/benchmarks_test.py
+++ b/tensorflow/python/eager/benchmarks_test.py
@@ -42,6 +42,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
+from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
@@ -717,6 +718,25 @@ class MicroBenchmarks(test.Benchmark):
assert np.equal(func(), make_keras_model()(data)).all()
self._run(func, 30000)
+ def benchmarkScan(self):
+ elems = math_ops.range(1600)
+
+ def scan():
+ return functional_ops.scan(
+ lambda a, x: a + x, elems, parallel_iterations=1)
+
+ self._run(scan, 100)
+
+ def benchmarkScanDefun(self):
+ elems = math_ops.range(1600)
+
+ @function.defun
+ def scan():
+ return functional_ops.scan(
+ lambda a, x: a + x, elems, parallel_iterations=1)
+
+ self._run(scan, 100)
+
if __name__ == "__main__":
test.main()
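The two new benchmarks time `functional_ops.scan` both directly and wrapped in `function.defun`, with `parallel_iterations=1` to serialize the loop. For reference, `scan` threads an accumulator across the leading dimension of `elems`, so `lambda a, x: a + x` produces a running sum; a tiny sketch of the semantics being measured (using the public `tf.scan` alias and a small input for illustration):

    import tensorflow as tf

    elems = tf.range(5)  # [0, 1, 2, 3, 4]
    # Running sum: each output element is the accumulator after adding x.
    running_sum = tf.scan(lambda a, x: a + x, elems)

    with tf.Session() as sess:
      print(sess.run(running_sum))  # [0, 1, 3, 6, 10]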
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 03f12139f6..962e334b27 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -27,6 +27,7 @@ import threading
import numpy as np
import six
+from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
@@ -34,6 +35,7 @@ from tensorflow.python.eager import execute
from tensorflow.python.eager import tape
from tensorflow.python.eager.graph_only_ops import graph_placeholder
from tensorflow.python.framework import c_api_util
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import ops
@@ -59,6 +61,10 @@ cond_v2_impl._function = sys.modules[__name__] # pylint: disable=protected-acce
gradients_impl._function = sys.modules[__name__] # pylint: disable=protected-access
+# TODO(scottzhu): Update this to allow arbitrary attribute names in the future.
+WHITELIST_FUNCTION_ATTRIBUTE_PREFIX = "experimental_"
+
+
def _create_substitute_placeholder(value, name, dtype=None):
"""Creates a placeholder for `value` and propagates shape info to it."""
# Note: setting ops.control_dependencies(None) ensures we always put
@@ -99,6 +105,44 @@ def _get_device_functions(ctx, graph):
return tuple(graph._device_functions_outer_to_inner) # pylint: disable=protected-access
+def _parse_func_attrs(attributes):
+ """Convert the keyword arguments into function_def attributes.
+
+ Currently only supports primitive types: bool, int, float, and string.
+
+ Args:
+ attributes: the dictionary of attributes.
+ Returns:
+ A dict of attributes where the key is the name of the attribute and the
+ value is the AttrValue proto.
+ Raises:
+ ValueError: If the kwargs contain an unwhitelisted name or an unsupported
+ value type.
+ """
+ attrs = {}
+ for key, value in attributes.items():
+ if not key.startswith(WHITELIST_FUNCTION_ATTRIBUTE_PREFIX):
+ raise ValueError("Attribute name is not whitelisted. "
+ "Whitelisted: prefix %s, got: %s" %
+ (WHITELIST_FUNCTION_ATTRIBUTE_PREFIX, key))
+
+ if isinstance(value, attr_value_pb2.AttrValue):
+ attrs[key] = value
+ # bool type check has to happen before int since bool is a subclass of int.
+ elif isinstance(value, bool):
+ attrs[key] = attr_value_pb2.AttrValue(b=value)
+ elif isinstance(value, int):
+ attrs[key] = attr_value_pb2.AttrValue(i=value)
+ elif isinstance(value, float):
+ attrs[key] = attr_value_pb2.AttrValue(f=value)
+ elif isinstance(value, str):
+ attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(value))
+ else:
+ raise ValueError("Unsupported attribute type for %s with type %s" %
+ (key, type(value)))
+ return attrs
+
+
class FuncGraph(ops.Graph):
"""Graph representing a function body.
@@ -485,7 +529,7 @@ class Function(object):
self._num_outputs = len(self._func_graph.outputs)
self._output_shapes = tuple(
output.shape for output in self._func_graph.outputs)
- self._attrs = attrs or {}
+ self._attrs = _parse_func_attrs(attrs or {})
self._device_functions = tuple(
self._func_graph._device_functions_outer_to_inner) # pylint: disable=protected-access
@@ -879,9 +923,6 @@ def _encode_arg(arg):
_TensorType(arg.values.dtype, arg.values._shape_tuple()),
_TensorType(arg.indices.dtype, arg.indices._shape_tuple()),
])
- elif isinstance(arg, np.ndarray):
- tensor = ops.convert_to_tensor(arg)
- return _TensorType(tensor.dtype, tensor._shape_tuple())
# pylint: enable=protected-access
elif isinstance(arg, (list, tuple)):
return tuple([_encode_arg(elem) for elem in arg])
@@ -911,7 +952,8 @@ class PolymorphicFunction(object):
def __init__(self,
python_function,
name,
- input_signature=None):
+ input_signature=None,
+ attributes=None):
"""Initializes a polymorphic function.
Args:
@@ -920,6 +962,8 @@ class PolymorphicFunction(object):
input_signature: a possibly nested sequence of `TensorSpec` objects
specifying the input signature of this function. If `None`, a separate
function is instantiated for each inferred input signature.
+ attributes: dict, extra keyword arguments that will be added as attributes
+ of the function.
Raises:
ValueError: if `input_signature` is not None and the `python_function`'s
@@ -937,6 +981,7 @@ class PolymorphicFunction(object):
self._name = name
self._function_cache = collections.OrderedDict()
self._variables = []
+ self._function_attributes = attributes or {}
self._lock = threading.Lock()
@@ -1089,6 +1134,17 @@ class PolymorphicFunction(object):
# opposed to named arguments called in a keyword-like fashion.
kwds.pop(arg)
inputs = args + _deterministic_dict_values(arg_indices_to_values)
+ flat_inputs = nest.flatten(inputs)
+
+ # Check for NumPy arrays in arguments and convert them to Tensors.
+ need_packing = False
+ for index, value in enumerate(flat_inputs):
+ if isinstance(value, np.ndarray):
+ flat_inputs[index] = constant_op.constant(value)
+ need_packing = True
+ if need_packing:
+ inputs = nest.pack_sequence_as(structure=inputs,
+ flat_sequence=flat_inputs)
if self._input_signature is None:
return inputs, kwds
else:
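The block above converts any NumPy arrays in the flattened inputs to constant Tensors before the cache key is computed, which is what `testDefunNumpyArraysConvertedToTensors` in function_test.py below exercises. A short sketch of the effect, assuming `function` and `numpy` are imported as in the tests:

    @function.defun
    def f(x):
      return x              # x arrives as a tf.Tensor, never an ndarray

    f(numpy.ones([2, 2]))   # converted via constant_op.constant(...)
    f(numpy.zeros([2, 2]))  # same shape/dtype, so the cached trace is reused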
@@ -1098,7 +1154,6 @@ class PolymorphicFunction(object):
except (ValueError, TypeError):
raise ValueError("Structure of Python function inputs does not match "
"input_signature.")
- flat_inputs = nest.flatten(inputs)
if any(not isinstance(arg, ops.Tensor) for arg in flat_inputs):
raise ValueError("When input_signature is provided, all inputs to "
"the Python function must be Tensors.")
@@ -1141,13 +1196,42 @@ class PolymorphicFunction(object):
if graph_function is None:
graph_function = Function(
func_graph_from_py_func(self._name, self._python_function, args,
- kwds, self._input_signature))
+ kwds, self._input_signature),
+ self._function_attributes)
self._variables.extend(
[v for v in graph_function.variables if v not in self._variables])
self._function_cache[cache_key] = graph_function
return graph_function, (args, kwds)
+def register(func, *args, **kwargs):
+ """Register the defun function into the graph.
+
+ This won't actually call the function with the inputs, and only put the
+ function definition into graph. Register function with different input param
+ will result into multiple version of functions registered in graph.
+
+ Args:
+ func: the PolymorphicFunction instance that generated by a @defun
+ *args: input arguments for the Python function.
+ **kwargs: input keyword arguments for the Python function.
+
+ Returns:
+ A `Function` object specialized to the inputs and execution context.
+
+ Raises:
+ ValueError: When the input function is not a defun-wrapped Python function.
+ """
+ if not isinstance(func, PolymorphicFunction):
+ raise ValueError("Only defun function is allowed to be registered. "
+ "Got type: %s" % type(func))
+ concrete_func = func.get_concrete_function(*args, **kwargs)
+ graph = ops.get_default_graph()
+ concrete_func._inference_function.add_to_graph(graph) # pylint: disable=protected-access
+ # TODO(scottzhu): support concrete_func._backward_graph_function in future.
+ return concrete_func
+
+
def _validate_signature(signature):
if any(not isinstance(arg, tensor_spec.TensorSpec)
for arg in nest.flatten(signature)):
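`register` traces the callable with the given arguments and adds the resulting function definition to the default graph without executing it; the `testRegisterFunction*` cases added in function_test.py below exercise this. A minimal sketch of the intended usage, assuming the same imports as those tests:

    def matmul(x, y):
      return math_ops.matmul(x, y)
    defun_matmul = function.defun(matmul)

    with context.graph_mode(), ops.Graph().as_default():
      t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
      # Puts the traced FunctionDef into the default graph; nothing runs yet.
      function.register(defun_matmul, t, t)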
@@ -1271,6 +1355,11 @@ def defun(func=None, input_signature=None):
tracing the execution of `f(*args, **kwargs)`; this graph is bound to an
input signature inferred from `(*args, **kwargs)` and cached for future reuse.
+ NumPy arrays passed as inputs to `F` are converted to `tf.Tensor` objects
+ before being passed to `f`, and are treated as Tensors for caching. This
+ allows a function to be called multiple times with NumPy arrays having
+ different values but the same shape and dtype without re-tracing each time.
+
`tf.contrib.eager.defun` caches graphs for your convenience, letting you
define TensorFlow functions without explicitly specifying their signatures.
However, this policy is conservative and potentially expensive; for example,
@@ -1470,7 +1559,29 @@ def defun(func=None, input_signature=None):
TypeError: If `input_signature` is neither `None` nor a sequence of
`tf.contrib.eager.TensorSpec` objects.
"""
+ return defun_with_attributes(func=func, input_signature=input_signature)
+
+def defun_with_attributes(func=None, input_signature=None, attributes=None):
+ """Compiles a Python function into a callable TensorFlow graph.
+
+ This function supports adding extra function attributes. See detailed
+ documentation in defun(). Currently this is not exposed in the public API
+ since we do not expect users to set attributes directly, and an attribute has
+ no effect on its own. This assumption might change in the future.
+
+ Args:
+ func: function to be compiled.
+ input_signature: same as defun()'s input_signature.
+ attributes: A dictionary of arguments which will be added to the function
+ def as attributes. Currently only primitive types are supported as values,
+ and only whitelisted attribute names are allowed. An unwhitelisted attribute
+ name or unsupported value results in a ValueError.
+
+ Returns:
+ Same as the return value of defun, with the attributes added to the
+ function in the graph.
+ """
if input_signature is not None:
_validate_signature(input_signature)
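`defun_with_attributes` is the attribute-aware variant exercised by `testFunctionWithExtraAttributes` below; attribute names must carry the whitelisted `experimental_` prefix and values must be primitive types. A minimal sketch mirroring that test (the attribute names are illustrative):

    @function.defun_with_attributes(
        attributes={"experimental_tag": "value1", "experimental_priority": 2})
    def matmul(x, y):
      return math_ops.matmul(x, y)
    # The traced FunctionDef carries both entries in its `attr` map; a name
    # without the "experimental_" prefix raises ValueError at trace time.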
@@ -1482,7 +1593,8 @@ def defun(func=None, input_signature=None):
name = "function"
return tf_decorator.make_decorator(
function,
- PolymorphicFunction(function, name, input_signature=input_signature))
+ PolymorphicFunction(function, name, input_signature=input_signature,
+ attributes=attributes))
# This code path is for the `foo = tfe.defun(foo, ...)` use case
if func is not None:
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 37a9957cea..a0abefe666 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -22,7 +22,10 @@ import functools
from multiprocessing.pool import ThreadPool
import sys
+import numpy
+
from tensorflow.core.protobuf import config_pb2
+from tensorflow.python import keras
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
@@ -36,6 +39,7 @@ from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import training as keras_training
from tensorflow.python.layers import convolutional
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
@@ -55,6 +59,21 @@ from tensorflow.python.util import compat
from tensorflow.python.util import nest
+class MiniModel(keras_training.Model):
+ """Minimal model for mnist.
+
+ Useful for testing and debugging on slow TPU simulators.
+ """
+
+ def __init__(self):
+ super(MiniModel, self).__init__(name='')
+ self.fc = keras.layers.Dense(1, name='fc', kernel_initializer='ones',
+ bias_initializer='ones')
+
+ def call(self, inputs, training=True):
+ return self.fc(inputs)
+
+
@test_util.with_c_shapes
class FunctionTest(test.TestCase):
@@ -104,7 +123,7 @@ class FunctionTest(test.TestCase):
self.assertAllEqual(step(), 2.0)
def testGraphGradientVariable(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
@function.defun
@@ -211,7 +230,7 @@ class FunctionTest(test.TestCase):
self.assertAllEqual(f(), x)
def testSymGradGatherNd(self):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
@function.defun
def f(x):
@@ -314,6 +333,7 @@ class FunctionTest(test.TestCase):
def testDefunNumpyArraysConvertedToTensors(self):
def f(x):
+ self.assertIsInstance(x, ops.Tensor)
return x
x = random_ops.random_uniform([2, 2]).numpy()
@@ -327,6 +347,12 @@ class FunctionTest(test.TestCase):
# shouldn't trigger another function definition.
self.assertEqual(len(defined._function_cache), 1)
+ # Test that the numpy array is properly an argument to the graph function.
+ self.assertEqual(1., defined(numpy.ones([])).numpy())
+ self.assertEqual(0., defined(numpy.zeros([])).numpy())
+ self.assertEqual(1., defined(array_ops.ones([])).numpy())
+ self.assertEqual(0., defined(array_ops.zeros([])).numpy())
+
def testDefunCapturedInt32(self):
x = constant_op.constant(1, dtype=dtypes.int32)
@@ -481,7 +507,7 @@ class FunctionTest(test.TestCase):
self.assertAllEqual(backprop.implicit_grad(f)()[0][0], 2.0)
def testGraphModeCaptureVariable(self):
- with context.graph_mode(), self.test_session() as sess:
+ with context.graph_mode(), self.cached_session() as sess:
class HasAVar(object):
@@ -509,12 +535,12 @@ class FunctionTest(test.TestCase):
x = constant_op.constant(1.0)
l = f(x, v)
_, dv = gradients_impl.gradients(l, [x, v])
- with self.test_session():
+ with self.cached_session():
v.initializer.run()
self.assertAllEqual(dv.eval(), 0.0)
def testGraphModeManyFunctions(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
@function.defun
def f(x):
@@ -934,7 +960,7 @@ class FunctionTest(test.TestCase):
self.assertEqual(1, int(read()))
def testReturnCapturedGraphTensor(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
t = constant_op.constant(1)
@function.defun
@@ -996,6 +1022,7 @@ class FunctionTest(test.TestCase):
with ops.get_default_graph().as_default():
create_variable()
+ @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def testLayerInDefun(self):
conv = convolutional.Conv2D(
filters=1,
@@ -1009,7 +1036,34 @@ class FunctionTest(test.TestCase):
x = array_ops.ones([1, 2, 2, 1])
y = model(x)
- self.assertAllEqual([[[[4.0]]]], y.numpy())
+
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+
+ self.assertAllEqual([[[[4.0]]]], self.evaluate(y))
+
+ # Remove reference cycles in model
+ test_util.dismantle_polymorphic_function(model)
+
+ @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
+ def testDefunKerasModelCall(self):
+ model = MiniModel()
+ model.call = function.defun(model.call)
+
+ x = array_ops.ones([1, 2])
+ y = model(x)
+
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+
+ self.assertAllEqual([[3.0]], self.evaluate(y))
+
+ # Remove reference cycles in defun.
+ test_util.dismantle_polymorphic_function(model.call)
+ # Break the reference cycle between the MiniModel and the defun:
+ # MiniModel --(through its `call` method)--> PolymorphicFunction
+ # PolymorphicFunction --(instancemethod on MiniModel)--> MiniModel
+ del model.call
# Note: The ConfigProto below unfortunately only configures graph
# construction. Eager's configuration is controlled in `__main__`.
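The reworked `testLayerInDefun` (and the new `testDefunKerasModelCall`) rely on `test_util.run_in_graph_and_eager_modes` together with `self.evaluate`, which dispatches to `sess.run` in graph mode and to eager evaluation otherwise. A minimal sketch of that test shape, with the op under test left hypothetical:

    @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
    def testSomething(self):
      y = some_op_under_test()  # hypothetical op producing a Tensor
      if not context.executing_eagerly():
        self.evaluate(variables.global_variables_initializer())
      self.assertAllEqual(expected_value, self.evaluate(y))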
@@ -1492,12 +1546,151 @@ class FunctionTest(test.TestCase):
side_effecting_function.python_function()
self.assertAllEqual(state, [0, 0])
+ def testFunctionWithExtraAttributes(self):
+ @function.defun_with_attributes(attributes={'experimental_1': 'value1',
+ 'experimental_2': 2})
+ def matmul(x, y):
+ return math_ops.matmul(x, y)
+
+ def add(x, y):
+ return math_ops.add(x, y)
+ defun_add = function.defun_with_attributes(
+ add, attributes={'experimental_3': True, 'experimental_4': 1.0})
+
+ with context.graph_mode(), self.test_session():
+ with ops.get_default_graph().as_default():
+ t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ sq = matmul(t, t)
+ double = defun_add(t, t)
+ self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22])
+ self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8])
+
+ graph = ops.get_default_graph()
+ # pylint: disable=protected-access
+ self.assertEqual(len(graph._functions), 2)
+ functions = list(graph._functions.values())
+ self.assertRegexpMatches(
+ functions[0].definition.signature.name, '.*matmul.*')
+ attrs = functions[0].definition.attr
+ self.assertEqual(len(attrs), 2)
+ self.assertEqual(attrs['experimental_1'].s, b'value1')
+ self.assertEqual(attrs['experimental_2'].i, 2)
+
+ self.assertRegexpMatches(
+ functions[1].definition.signature.name, '.*add.*')
+ attrs = functions[1].definition.attr
+ self.assertEqual(len(attrs), 2)
+ self.assertEqual(attrs['experimental_3'].b, True)
+ self.assertEqual(attrs['experimental_4'].f, 1.0)
+ # pylint: enable=protected-access
+
+ def testFunctionWithInvalidAttribute(self):
+ @function.defun_with_attributes(attributes={'attr1': 'value1'})
+ def matmul(x, y):
+ return math_ops.matmul(x, y)
+
+ with self.assertRaisesRegexp(ValueError,
+ '.*Attribute name is not whitelisted.*'):
+ with context.graph_mode(), self.test_session():
+ with ops.get_default_graph().as_default():
+ t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ matmul(t, t)
+
+ @function.defun_with_attributes(attributes={'experimental_1': ['value1']})
+ def add(x, y):
+ return math_ops.add(x, y)
+
+ with self.assertRaisesRegexp(ValueError,
+ '.*Unsupported attribute type.*'):
+ with context.graph_mode(), self.test_session():
+ with ops.get_default_graph().as_default():
+ t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ add(t, t)
+
+ def testRegisterFunction(self):
+ @function.defun
+ def add(x, y):
+ return math_ops.add(x, y)
+
+ def matmul(x, y):
+ return math_ops.matmul(x, y)
+ defun_matmul = function.defun(matmul)
+
+ with context.graph_mode(), self.cached_session():
+ with ops.get_default_graph().as_default():
+ t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ function.register(defun_matmul, t, t)
+ function.register(add, t, t)
+
+ graph = ops.get_default_graph()
+ # pylint: disable=protected-access
+ self.assertEqual(len(graph._functions), 2)
+ functions = list(graph._functions.values())
+ pre_register_matmul_func_name = functions[0].definition.signature.name
+ self.assertRegexpMatches(pre_register_matmul_func_name, '.*matmul.*')
+ pre_register_add_func_name = functions[1].definition.signature.name
+ self.assertRegexpMatches(pre_register_add_func_name, '.*add.*')
+
+ sq = defun_matmul(t, t)
+ double = add(t, t)
+ self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22])
+ self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8])
+ # Make sure the pre-registered function is used, and no other function
+ # is added.
+ self.assertEqual(len(graph._functions), 2)
+ functions = list(graph._functions.values())
+ called_func_name = functions[0].definition.signature.name
+ self.assertEqual(pre_register_matmul_func_name, called_func_name)
+ called_func_name = functions[1].definition.signature.name
+ self.assertEqual(pre_register_add_func_name, called_func_name)
+
+ def testRegisterFunctionWithInputSignature(self):
+ def matmul(x, y):
+ return math_ops.matmul(x, y)
+ defun_matmul = function.defun(
+ matmul,
+ input_signature=[
+ tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32),
+ tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32)
+ ])
+ with context.graph_mode(), self.cached_session():
+ with ops.get_default_graph().as_default():
+ t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ function.register(defun_matmul, t, t)
+
+ graph = ops.get_default_graph()
+ # pylint: disable=protected-access
+ self.assertEqual(len(graph._functions), 1)
+
+ # Test input param shape mismatch
+ t2 = constant_op.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+ with self.assertRaisesRegexp(
+ ValueError, 'Python inputs incompatible with input_signature'):
+ function.register(defun_matmul, t2, t2)
+
+ def testRegisterFunctionWithCache(self):
+ def matmul(x, y):
+ return math_ops.matmul(x, y)
+ defun_matmul = function.defun(matmul)
+
+ with context.graph_mode(), self.cached_session():
+ with ops.get_default_graph().as_default():
+ t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ t2 = constant_op.constant([[2.0, 3.0], [4.0, 5.0]])
+ function.register(defun_matmul, t, t)
+ function.register(defun_matmul, t2, t2)
+
+ graph = ops.get_default_graph()
+ # Only one function is registered since the input parameters share the same type and shape.
+ # pylint: disable=protected-access
+ self.assertEqual(len(graph._functions), 1)
+
@test_util.with_c_shapes
class AutomaticControlDependenciesTest(test.TestCase):
def testBasic(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
with function.AutomaticControlDependencies() as c:
@@ -1508,7 +1701,7 @@ class AutomaticControlDependenciesTest(test.TestCase):
self.assertAllEqual(val.eval(), 4.0)
def testCondMustRun(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
p = array_ops.placeholder(dtype=dtypes.bool)
@@ -1529,7 +1722,7 @@ class AutomaticControlDependenciesTest(test.TestCase):
self.assertAllEqual(val.eval(feed_dict={p: True}), 6.0)
def testCondMustRunSeparateRead(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
p = array_ops.placeholder(dtype=dtypes.bool)
@@ -1552,7 +1745,7 @@ class AutomaticControlDependenciesTest(test.TestCase):
self.assertAllEqual(v.read_value().eval(), 6.0)
def testCondNested(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
p = array_ops.placeholder(dtype=dtypes.bool)
@@ -1586,7 +1779,7 @@ class AutomaticControlDependenciesTest(test.TestCase):
self.assertAllEqual(val.eval(feed_dict={p: True, q: False}), 8.0)
def testCondOneBranch(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
p = array_ops.placeholder(dtype=dtypes.bool)
@@ -1606,7 +1799,7 @@ class AutomaticControlDependenciesTest(test.TestCase):
self.assertAllEqual(val.eval(feed_dict={p: True}), 5.0)
def testCondOneBranchUpdateBefore(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
p = array_ops.placeholder(dtype=dtypes.bool)
@@ -1627,7 +1820,7 @@ class AutomaticControlDependenciesTest(test.TestCase):
self.assertAllEqual(val.eval(feed_dict={p: True}), 12.0)
def testCondOneBranchUpdateAfter(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
p = array_ops.placeholder(dtype=dtypes.bool)
@@ -1663,7 +1856,7 @@ class AutomaticControlDependenciesTest(test.TestCase):
self.assertAllEqual(out, [3, 4, 5])
def testDecorator(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
diff --git a/tensorflow/python/eager/graph_only_ops_test.py b/tensorflow/python/eager/graph_only_ops_test.py
index d2a2b4e223..3cf3a61a62 100644
--- a/tensorflow/python/eager/graph_only_ops_test.py
+++ b/tensorflow/python/eager/graph_only_ops_test.py
@@ -32,13 +32,13 @@ class GraphOnlyOpsTest(test_util.TensorFlowTestCase):
def testGraphZerosLike(self):
x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
z_tf = graph_only_ops.graph_zeros_like(x)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(np.zeros((2, 3)), z_tf.eval())
def testGraphPlaceholder(self):
x_tf = graph_only_ops.graph_placeholder(dtypes.int32, shape=(1,))
y_tf = math_ops.square(x_tf)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = np.array([42])
y = sess.run(y_tf, feed_dict={x_tf: np.array([42])})
self.assertAllClose(np.square(x), y)
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 1ed814258b..9f2f4e06ad 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -1403,9 +1403,13 @@ class PyVSpace
PyObject* arglist =
Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
PyObject* result = PyEval_CallObject(num_elements_, arglist);
+ Py_DECREF(arglist);
+ if (result == nullptr) {
+ // The caller detects whether a Python exception has been raised.
+ return -1;
+ }
tensorflow::int64 r = MakeInt(result);
Py_DECREF(result);
- Py_DECREF(arglist);
return r;
}
@@ -1740,117 +1744,167 @@ PyObject* MaybeGetDTypeForAttr(const string& attr,
Py_RETURN_NONE;
}
-bool OpDoesntRequireOutput(const string& op_name) {
- static tensorflow::gtl::FlatSet<string>* ops_that_dont_require_outputs =
- new tensorflow::gtl::FlatSet<string>({
- "Identity",
- "MatMul",
- "Conv2DBackpropInput",
- "Conv2DBackpropFilter",
- "Conv3D",
- "Conv3DBackpropInputV2",
- "AvgPool3D",
- "AvgPool3DGrad",
- "MaxPool3D",
- "MaxPool3DGrad",
- "MaxPool3DGradGrad",
- "BiasAdd",
- "BiasAddV1",
- "BiasAddGrad",
- "Softplus",
- "SoftplusGrad",
- "Softsign",
- "ReluGrad",
- "Conv2D",
- "DepthwiseConv2dNative",
- "Dilation2D",
- "AvgPool",
- "AvgPoolGrad",
- "BatchNormWithGlobalNormalization",
- "L2Loss",
- "Sum",
- "Prod",
- "SegmentSum",
- "SegmentMean",
- "SparseSegmentSum",
- "SparseSegmentMean",
- "SparseSegmentSqrtN",
- "SegmentMin",
- "SegmentMax",
- "UnsortedSegmentSum",
- "UnsortedSegmentMax",
- "Abs",
- "Neg",
- "ReciprocalGrad",
- "Square",
- "Expm1",
- "Log",
- "Log1p",
- "TanhGrad",
- "SigmoidGrad",
- "Sign",
- "Sin",
- "Cos",
- "Tan",
- "Add",
- "Sub",
- "Mul",
- "Div",
- "RealDiv",
- "Maximum",
- "Minimum",
- "SquaredDifference",
- "Select",
- "SparseMatMul",
- "BatchMatMul",
- "Complex",
- "Real",
- "Imag",
- "Angle",
- "Conj",
- "Cast",
- "Cross",
- "Cumsum",
- "Cumprod",
- "ReadVariableOp",
- "VarHandleOp",
- "Shape",
- "StridedSlice",
+// Looks up whether the gradient of `op_name` needs the op's outputs. Returns
+// true and points `*output` at a pair whose first value indicates whether all
+// outputs are unused; if that value is false, the second value is the set of
+// output indices that are unused. Returns false if `op_name` has no entry.
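+// For example, the "FusedBatchNorm" entry below is {false, {0, 1, 2}}: only
+// output indices 0, 1 and 2 are unused by its gradient.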
+bool OpGradientDoesntRequireOutputIndices(
+ const string& op_name,
+ std::pair<bool, tensorflow::gtl::FlatSet<int>>** output) {
+ static tensorflow::gtl::FlatMap<
+ string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>* m =
+ new tensorflow::gtl::FlatMap<
+ string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>({
+ // Ops that don't require any outputs.
+ {"Identity", {true, {}}},
+ {"MatMul", {true, {}}},
+ {"Conv2DBackpropInput", {true, {}}},
+ {"Conv2DBackpropFilter", {true, {}}},
+ {"Conv3D", {true, {}}},
+ {"Conv3DBackpropInputV2", {true, {}}},
+ {"AvgPool3D", {true, {}}},
+ {"AvgPool3DGrad", {true, {}}},
+ {"MaxPool3D", {true, {}}},
+ {"MaxPool3DGrad", {true, {}}},
+ {"MaxPool3DGradGrad", {true, {}}},
+ {"BiasAdd", {true, {}}},
+ {"BiasAddV1", {true, {}}},
+ {"BiasAddGrad", {true, {}}},
+ {"Softplus", {true, {}}},
+ {"SoftplusGrad", {true, {}}},
+ {"Softsign", {true, {}}},
+ {"ReluGrad", {true, {}}},
+ {"Conv2D", {true, {}}},
+ {"DepthwiseConv2dNative", {true, {}}},
+ {"Dilation2D", {true, {}}},
+ {"AvgPool", {true, {}}},
+ {"AvgPoolGrad", {true, {}}},
+ {"BatchNormWithGlobalNormalization", {true, {}}},
+ {"L2Loss", {true, {}}},
+ {"Sum", {true, {}}},
+ {"Prod", {true, {}}},
+ {"SegmentSum", {true, {}}},
+ {"SegmentMean", {true, {}}},
+ {"SparseSegmentSum", {true, {}}},
+ {"SparseSegmentMean", {true, {}}},
+ {"SparseSegmentSqrtN", {true, {}}},
+ {"SegmentMin", {true, {}}},
+ {"SegmentMax", {true, {}}},
+ {"UnsortedSegmentSum", {true, {}}},
+ {"UnsortedSegmentMax", {true, {}}},
+ {"Abs", {true, {}}},
+ {"Neg", {true, {}}},
+ {"ReciprocalGrad", {true, {}}},
+ {"Square", {true, {}}},
+ {"Expm1", {true, {}}},
+ {"Log", {true, {}}},
+ {"Log1p", {true, {}}},
+ {"TanhGrad", {true, {}}},
+ {"SigmoidGrad", {true, {}}},
+ {"Sign", {true, {}}},
+ {"Sin", {true, {}}},
+ {"Cos", {true, {}}},
+ {"Tan", {true, {}}},
+ {"Add", {true, {}}},
+ {"Sub", {true, {}}},
+ {"Mul", {true, {}}},
+ {"Div", {true, {}}},
+ {"RealDiv", {true, {}}},
+ {"Maximum", {true, {}}},
+ {"Minimum", {true, {}}},
+ {"SquaredDifference", {true, {}}},
+ {"Select", {true, {}}},
+ {"SparseMatMul", {true, {}}},
+ {"BatchMatMul", {true, {}}},
+ {"Complex", {true, {}}},
+ {"Real", {true, {}}},
+ {"Imag", {true, {}}},
+ {"Angle", {true, {}}},
+ {"Conj", {true, {}}},
+ {"Cast", {true, {}}},
+ {"Cross", {true, {}}},
+ {"Cumsum", {true, {}}},
+ {"Cumprod", {true, {}}},
+ {"ReadVariableOp", {true, {}}},
+ {"VarHandleOp", {true, {}}},
+ {"Shape", {true, {}}},
+ {"StridedSlice", {true, {}}},
+ {"Fill", {true, {}}},
+
+ // Ops for which only a subset of the outputs is unused.
+ {"FusedBatchNorm", {false, {0, 1, 2}}},
});
- return ops_that_dont_require_outputs->find(op_name) !=
- ops_that_dont_require_outputs->end();
-}
-
-bool OpDoesntRequireInput(const string& op_name) {
- static tensorflow::gtl::FlatSet<string>* ops_that_dont_require_inputs =
- new tensorflow::gtl::FlatSet<string>({
- "Identity",
- "Softmax",
- "LogSoftmax",
- "BiasAdd",
- "Relu",
- "Relu6",
- "Elu",
- "Selu",
- "SparseSoftmaxCrossEntropyWithLogits",
- "Neg",
- "Inv",
- "Reciprocal",
- "Sqrt",
- "Exp",
- "Tanh",
- "Sigmoid",
- "Real",
- "Imag",
- "Conj",
- "ReadVariableOp",
- "VarHandleOp",
- "Shape",
+ auto it = m->find(op_name);
+
+ if (it == m->end()) return false;
+
+ *output = &it->second;
+ return true;
+}
+
+// Looks up whether the gradient of `op_name` needs the op's inputs. Returns
+// true and points `*output` at a pair whose first value indicates whether all
+// inputs are unused; if that value is false, the second value is the set of
+// input indices that are unused. Returns false if `op_name` has no entry.
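+// For example, the "FusedBatchNorm" entry below is {false, {2}}: only input
+// index 2 is unused by its gradient.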
+bool OpGradientDoesntRequireInputIndices(
+ const string& op_name,
+ std::pair<bool, tensorflow::gtl::FlatSet<int>>** output) {
+ static tensorflow::gtl::FlatMap<
+ string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>* m =
+ new tensorflow::gtl::FlatMap<
+ string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>({
+ // Ops that don't require any inputs.
+ {"Identity", {true, {}}},
+ {"Softmax", {true, {}}},
+ {"LogSoftmax", {true, {}}},
+ {"BiasAdd", {true, {}}},
+ {"Relu", {true, {}}},
+ {"Relu6", {true, {}}},
+ {"Elu", {true, {}}},
+ {"Selu", {true, {}}},
+ {"SparseSoftmaxCrossEntropyWithLogits", {true, {}}},
+ {"Neg", {true, {}}},
+ {"Inv", {true, {}}},
+ {"Reciprocal", {true, {}}},
+ {"Sqrt", {true, {}}},
+ {"Exp", {true, {}}},
+ {"Tanh", {true, {}}},
+ {"Sigmoid", {true, {}}},
+ {"Real", {true, {}}},
+ {"Imag", {true, {}}},
+ {"Conj", {true, {}}},
+ {"ReadVariableOp", {true, {}}},
+ {"VarHandleOp", {true, {}}},
+ {"Shape", {true, {}}},
+ {"Fill", {true, {}}},
+
+ // Ops for which only a subset of the inputs is unused.
+ {"FusedBatchNorm", {false, {2}}},
});
- return ops_that_dont_require_inputs->find(op_name) !=
- ops_that_dont_require_inputs->end();
+ auto it = m->find(op_name);
+
+ if (it == m->end()) return false;
+
+ *output = &it->second;
+ return true;
+}
+
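+// Returns a new tuple mirroring `seq`, with the elements at the given
+// `indices` replaced by Py_None. The caller owns the returned reference.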
+PyObject* CopySequenceSettingIndicesToNull(
+ PyObject* seq, const tensorflow::gtl::FlatSet<int>& indices) {
+ tensorflow::Safe_PyObjectPtr fast_seq(
+ PySequence_Fast(seq, "unable to allocate"));
+ PyObject* result = PyTuple_New(PySequence_Fast_GET_SIZE(fast_seq.get()));
+ for (int i = 0; i < PySequence_Fast_GET_SIZE(fast_seq.get()); i++) {
+ PyObject* item;
+ if (indices.find(i) != indices.end()) {
+ item = Py_None;
+ } else {
+ item = PySequence_Fast_GET_ITEM(fast_seq.get(), i);
+ }
+ Py_INCREF(item);
+ PyTuple_SET_ITEM(result, i, item);
+ }
+ return result;
}
PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
@@ -1870,16 +1924,35 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
if (!should_record) Py_RETURN_NONE;
string c_op_name = TFE_GetPythonString(op_name);
+
PyObject* op_outputs;
- if (OpDoesntRequireOutput(c_op_name)) {
- op_outputs = Py_None;
+ bool op_outputs_tuple_created = false;
+ std::pair<bool, tensorflow::gtl::FlatSet<int>>* outputs_not_required;
+
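+  // When only some outputs are unused, record a copy of `results` in which
+  // those entries are replaced by Py_None.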
+ if (OpGradientDoesntRequireOutputIndices(c_op_name, &outputs_not_required)) {
+ if (outputs_not_required->first) {
+ op_outputs = Py_None;
+ } else {
+ op_outputs_tuple_created = true;
+ op_outputs = CopySequenceSettingIndicesToNull(
+ results, outputs_not_required->second);
+ }
} else {
op_outputs = results;
}
PyObject* op_inputs;
- if (OpDoesntRequireInput(c_op_name)) {
- op_inputs = Py_None;
+ bool op_inputs_tuple_created = false;
+ std::pair<bool, tensorflow::gtl::FlatSet<int>>* inputs_not_required;
+
+ if (OpGradientDoesntRequireInputIndices(c_op_name, &inputs_not_required)) {
+ if (inputs_not_required->first) {
+ op_inputs = Py_None;
+ } else {
+ op_inputs_tuple_created = true;
+ op_inputs =
+ CopySequenceSettingIndicesToNull(inputs, inputs_not_required->second);
+ }
} else {
op_inputs = inputs;
}
@@ -1922,6 +1995,8 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
});
Py_DECREF(num_inputs);
+ if (op_outputs_tuple_created) Py_DECREF(op_outputs);
+ if (op_inputs_tuple_created) Py_DECREF(op_inputs);
Py_RETURN_NONE;
}
@@ -2492,13 +2567,18 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
int num_retvals = 0;
for (int i = 0; i < op_def->output_arg_size(); i++) {
const auto& output_arg = op_def->output_arg(i);
+ int delta = 1;
if (!output_arg.number_attr().empty()) {
- num_retvals += attr_list_sizes[output_arg.number_attr()];
+ delta = attr_list_sizes[output_arg.number_attr()];
} else if (!output_arg.type_list_attr().empty()) {
- num_retvals += attr_list_sizes[output_arg.type_list_attr()];
- } else {
- num_retvals++;
+ delta = attr_list_sizes[output_arg.type_list_attr()];
+ }
+ if (delta < 0) {
+ RaiseFallbackException(
+ "Attributes suggest that the size of an output list is less than 0");
+ return nullptr;
}
+ num_retvals += delta;
}
tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals(num_retvals);
diff --git a/tensorflow/python/eager/pywrap_tfe_test.py b/tensorflow/python/eager/pywrap_tfe_test.py
index fd8ab695b8..669fa08488 100644
--- a/tensorflow/python/eager/pywrap_tfe_test.py
+++ b/tensorflow/python/eager/pywrap_tfe_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
+from tensorflow.python.eager import core
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -123,8 +124,8 @@ class Tests(test.TestCase):
def testFastpathExecute_MixedPrecisionVariableTapeWrite(self):
ctx = context.context()
with backprop.GradientTape(persistent=True) as tape:
- a_2_by_2 = constant_op.constant(
- [[1.0, 2.0], [3.0, 4.0]], dtype=dtypes.float32)
+ a_2_by_2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]],
+ dtype=dtypes.float32)
a_2_by_2_fp16 = math_ops.cast(a_2_by_2, dtype=dtypes.float16)
m1 = resource_variable_ops.ResourceVariable(a_2_by_2)
m2 = resource_variable_ops._MixedPrecisionVariable(
@@ -233,6 +234,26 @@ class Tests(test.TestCase):
pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name,
ctx_handle, None, [], a_2_by_2)
+ @test_util.assert_no_new_tensors
+ @test_util.assert_no_garbage_created
+ def testFastPathExecute_InvalidAttributes(self):
+ split_dim = constant_op.constant(0, dtype=dtypes.int32)
+ value = constant_op.constant([0, 1, 2, 3], dtype=dtypes.float32)
+ ctx = context.context()
+ ctx_handle = ctx._handle
+ with self.assertRaises(core._FallbackException):
+ pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name,
+ "Split", None, None, split_dim,
+ value, "num_split", -1)
+
+ @test_util.assert_no_new_tensors
+ @test_util.assert_no_garbage_created
+ def testInvalidNumOutputs(self):
+ with self.assertRaisesRegexp(
+ Exception,
+ "Value for attr 'num_split' of -1 must be at least minimum 1"):
+ array_ops.split(value=[1, 2, 3], num_or_size_splits=-1)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/eager/tape_test.py b/tensorflow/python/eager/tape_test.py
index 4326d5efa3..acd0e569f1 100644
--- a/tensorflow/python/eager/tape_test.py
+++ b/tensorflow/python/eager/tape_test.py
@@ -72,7 +72,7 @@ class TapeTest(test.TestCase):
a = constant_op.constant([[1., 0.], [0., 1.]])
b = constant_op.constant([[1., 2.], [3., 4.]])
da, db = backprop.gradients_function(fn, [0, 1])(a, b)
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
tf_a = constant_op.constant([[1, 0], [0, 1]], dtype=dtypes.float32)
tf_b = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float32)
tf_c = tf_a + tf_b
@@ -135,7 +135,7 @@ class TapeTest(test.TestCase):
a = constant_op.constant([[1., 0.], [0., 1.]])
b = constant_op.constant([[1., 2.], [3., 4.]])
da, db = backprop.gradients_function(fn, [0, 1])(a, b)
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
tf_a = constant_op.constant([[1, 0], [0, 1]], dtype=dtypes.float32)
tf_b = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float32)
tf_mm = math_ops.matmul(tf_a, tf_b)
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index 00da335fef..bfcc019dd5 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -684,11 +684,7 @@ py_test(
shard_count = 4,
srcs_version = "PY2AND3",
tags = [
- "manual", # b/112769036, b/113907597
- "no_oss", # b/112769036, b/113907597
"no_windows",
- "noasan", # b/114304340
- "nomsan",
"notsan", # b/67510291
],
deps = [
diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py
index 08026a93c5..6e28c72151 100644
--- a/tensorflow/python/estimator/canned/boosted_trees_test.py
+++ b/tensorflow/python/estimator/canned/boosted_trees_test.py
@@ -1560,7 +1560,7 @@ class ModelFnTests(test_util.TensorFlowTestCase):
ops.reset_default_graph()
expected_first, expected_second, expected_third = (
self._get_expected_ensembles_for_classification())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Train with train_in_memory mode.
with sess.graph.as_default():
train_op, ensemble_serialized = self._get_train_op_and_ensemble(
@@ -1593,7 +1593,7 @@ class ModelFnTests(test_util.TensorFlowTestCase):
expected_first, expected_second, expected_third, expected_forth = (
self._get_expected_ensembles_for_classification_with_bias())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with sess.graph.as_default():
train_op, ensemble_serialized = self._get_train_op_and_ensemble(
boosted_trees._create_classification_head(n_classes=2),
@@ -1633,7 +1633,7 @@ class ModelFnTests(test_util.TensorFlowTestCase):
ops.reset_default_graph()
expected_first, expected_second, expected_third = (
self._get_expected_ensembles_for_classification())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Train without train_in_memory mode.
with sess.graph.as_default():
train_op, ensemble_serialized = self._get_train_op_and_ensemble(
@@ -1666,7 +1666,7 @@ class ModelFnTests(test_util.TensorFlowTestCase):
expected_first, expected_second, expected_third, expected_forth = (
self._get_expected_ensembles_for_classification_with_bias())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with sess.graph.as_default():
train_op, ensemble_serialized = self._get_train_op_and_ensemble(
boosted_trees._create_classification_head(n_classes=2),
@@ -1704,7 +1704,7 @@ class ModelFnTests(test_util.TensorFlowTestCase):
ops.reset_default_graph()
expected_first, expected_second, expected_third = (
self._get_expected_ensembles_for_regression())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Train with train_in_memory mode.
with sess.graph.as_default():
train_op, ensemble_serialized = self._get_train_op_and_ensemble(
@@ -1734,7 +1734,7 @@ class ModelFnTests(test_util.TensorFlowTestCase):
ops.reset_default_graph()
expected_first, expected_second, expected_third, expected_forth = (
self._get_expected_ensembles_for_regression_with_bias())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Train with train_in_memory mode.
with sess.graph.as_default():
train_op, ensemble_serialized = self._get_train_op_and_ensemble(
@@ -1774,7 +1774,7 @@ class ModelFnTests(test_util.TensorFlowTestCase):
ops.reset_default_graph()
expected_first, expected_second, expected_third = (
self._get_expected_ensembles_for_regression())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Train without train_in_memory mode.
with sess.graph.as_default():
train_op, ensemble_serialized = self._get_train_op_and_ensemble(
@@ -1804,7 +1804,7 @@ class ModelFnTests(test_util.TensorFlowTestCase):
ops.reset_default_graph()
expected_first, expected_second, expected_third, expected_forth = (
self._get_expected_ensembles_for_regression_with_bias())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Train with train_in_memory mode.
with sess.graph.as_default():
train_op, ensemble_serialized = self._get_train_op_and_ensemble(
diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py
index bd2e0ae943..de9c84d2ef 100644
--- a/tensorflow/python/estimator/canned/head_test.py
+++ b/tensorflow/python/estimator/canned/head_test.py
@@ -260,7 +260,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
features={'x': np.array(((30.,), (42.,),))},
mode=model_fn.ModeKeys.PREDICT,
logits=logits_placeholder)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors.OpError, 'logits shape'):
spec.predictions[prediction_keys.PredictionKeys.PROBABILITIES].eval({
logits_placeholder: logits_2x2
@@ -293,7 +293,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits_placeholder,
labels=labels_placeholder)[0]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
r'\[expected_labels_shape: \] \[2 1\] \[labels_shape: \] \[2 2\]'):
@@ -347,14 +347,14 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits_placeholder,
labels=labels_placeholder)[0]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError('Labels must <= n_classes - 1'):
training_loss.eval({
labels_placeholder: labels_2x1_with_large_id,
logits_placeholder: logits_2x3
})
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError('Labels must >= 0'):
training_loss.eval({
labels_placeholder: labels_2x1_with_negative_id,
@@ -413,7 +413,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits_placeholder,
labels=labels_placeholder)[0]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
r'\[expected_labels_shape: \] \[2 1\] \[labels_shape: \] \[3 1\]'):
@@ -449,7 +449,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
spec.export_outputs.keys())
# Assert predictions and export_outputs.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
predictions = sess.run(spec.predictions)
@@ -484,7 +484,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
mode=model_fn.ModeKeys.PREDICT,
logits=logits)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertAllEqual(
expected_classes,
@@ -510,7 +510,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
mode=model_fn.ModeKeys.PREDICT,
logits=logits)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
predictions = sess.run(spec.predictions)
self.assertAllClose(logits,
@@ -534,7 +534,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2)
@@ -561,7 +561,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits_input,
labels=labels_input)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(np.sum(loss), actual_training_loss.eval())
@@ -581,7 +581,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -632,7 +632,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
# Assert predictions, loss, and metrics.
tol = 1e-2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
@@ -698,7 +698,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
# Assert predictions, loss, and metrics.
tol = 1e-2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
@@ -727,7 +727,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2)
@@ -755,7 +755,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
}
tol = 1e-2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
@@ -804,7 +804,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
# Assert loss, and metrics.
tol = 1e-2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
@@ -837,7 +837,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
logits=logits,
labels=labels)
tol = 1e-2
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_training_loss, training_loss.eval(), rtol=tol, atol=tol)
@@ -866,7 +866,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
logits=logits,
labels=labels)
tol = 1e-2
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_training_loss, training_loss.eval(), rtol=tol, atol=tol)
@@ -921,7 +921,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
# Assert predictions, loss, train_op, and summaries.
tol = 1e-2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
@@ -962,7 +962,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
optimizer=_Optimizer())
tol = 1e-2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
loss, train_result = sess.run((spec.loss, spec.train_op))
self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
@@ -992,7 +992,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
labels=np.array(((1,), (1,)), dtype=np.int64),
train_op_fn=_train_op_fn)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
sess.run(spec.train_op)
w_value, t_value = sess.run([w, t])
@@ -1023,7 +1023,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
# Assert summaries.
tol = 1e-2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
summary_str = sess.run(spec.scaffold.summary_op)
@@ -1064,7 +1064,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
# Assert predictions, loss, train_op, and summaries.
tol = 1e-2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
@@ -1104,7 +1104,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
logits=logits,
labels=labels_rank_1)
tol = 1e-2
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_training_loss, training_loss.eval(), rtol=tol, atol=tol)
@@ -1153,7 +1153,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
# Assert predictions, loss, train_op, and summaries.
tol = 1e-2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
@@ -1183,7 +1183,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
mode=model_fn.ModeKeys.TRAIN,
logits=logits,
labels=labels)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2)
@@ -1211,7 +1211,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
train_op_fn=_train_op_fn)
tol = 1e-2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
loss = sess.run(spec.loss)
self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
@@ -1253,7 +1253,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
# Assert predictions, loss, train_op, and summaries.
tol = 1e-2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
@@ -1292,7 +1292,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
logits=logits,
labels=labels)
tol = 1e-2
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_training_loss, training_loss.eval(), rtol=tol, atol=tol)
@@ -1327,7 +1327,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
# Assert predictions, loss, train_op, and summaries.
tol = 1e-2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
loss, train_result = sess.run((spec.loss, spec.train_op))
self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
@@ -1353,7 +1353,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
logits=logits,
labels=labels,
train_op_fn=_no_op_train_fn)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -1380,7 +1380,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
logits=logits,
labels=labels,
train_op_fn=_no_op_train_fn)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -1413,7 +1413,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
# Assert predictions, loss, and metrics.
tol = 1e-2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
@@ -1506,7 +1506,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
features={'x': np.array(((42.,),))},
mode=model_fn.ModeKeys.PREDICT,
logits=logits_placeholder)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors.OpError, 'logits shape'):
spec.predictions[prediction_keys.PredictionKeys.PROBABILITIES].eval({
logits_placeholder: logits_2x2
@@ -1536,7 +1536,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits_placeholder,
labels=labels_placeholder)[0]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
r'\[expected_labels_shape: \] \[2 1\] \[labels_shape: \] \[2 2\]'):
@@ -1577,7 +1577,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits_placeholder,
labels=labels_placeholder)[0]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
r'\[expected_labels_shape: \] \[3 1\] \[labels_shape: \] \[2 1\]'):
@@ -1585,7 +1585,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
labels_placeholder: values_2x1,
logits_placeholder: values_3x1
})
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
r'\[expected_labels_shape: \] \[2 1\] \[labels_shape: \] \[3 1\]'):
@@ -1624,7 +1624,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
_assert_no_hooks(self, spec)
# Assert predictions.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
predictions = sess.run(spec.predictions)
@@ -1660,7 +1660,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
mode=model_fn.ModeKeys.PREDICT,
logits=logits)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertAllEqual(
expected_classes,
@@ -1680,7 +1680,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2)
@@ -1733,7 +1733,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
_assert_no_hooks(self, spec)
# Assert predictions, loss, and metrics.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
@@ -1808,7 +1808,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
}
# Assert predictions, loss, and metrics.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
@@ -1832,7 +1832,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(41., training_loss.eval())
@@ -1849,7 +1849,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
logits=logits,
labels=labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
@@ -1877,7 +1877,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2)
@@ -1924,7 +1924,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
}
self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())
tol = 1e-2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
@@ -1957,7 +1957,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
mode=model_fn.ModeKeys.TRAIN,
logits=logits,
labels=labels)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(expected_training_loss, training_loss.eval())
self.assertAllClose(expected_unreduced_loss, unreduced_loss.eval())
@@ -1983,7 +1983,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
mode=model_fn.ModeKeys.TRAIN,
logits=logits,
labels=labels)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(expected_training_loss, training_loss.eval())
self.assertAllClose(expected_unreduced_loss, unreduced_loss.eval())
@@ -2011,7 +2011,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits_input,
labels=labels_input)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(np.sum(loss), actual_training_loss.eval())
@@ -2031,7 +2031,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -2086,7 +2086,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
_assert_no_hooks(self, spec)
# Assert predictions, loss, train_op, and summaries.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
@@ -2126,7 +2126,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
labels=labels,
optimizer=_Optimizer())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
loss, train_result = sess.run((spec.loss, spec.train_op))
self.assertAllClose(expected_loss, loss)
@@ -2153,7 +2153,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
labels=np.array(((1,), (1,),), dtype=np.float64),
train_op_fn=_train_op_fn)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
sess.run(spec.train_op)
w_value, t_value = sess.run([w, t])
@@ -2182,7 +2182,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
labels=labels,
train_op_fn=_train_op_fn)
# Assert summaries.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
summary_str = sess.run(spec.scaffold.summary_op)
@@ -2227,7 +2227,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
regularization_losses=regularization_losses)
# Assert predictions, loss, train_op, and summaries.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
@@ -2254,7 +2254,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
r'Labels must <= n_classes - 1'):
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
training_loss.eval()
@@ -2277,7 +2277,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
mode=model_fn.ModeKeys.TRAIN,
logits=logits,
labels=labels)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2)
@@ -2309,7 +2309,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
train_op_fn=_train_op_fn)
# Assert predictions, loss, train_op, and summaries.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
loss, train_result = sess.run((spec.loss, spec.train_op))
self.assertAlmostEqual(expected_loss, loss, delta=1.e-5)
@@ -2334,7 +2334,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2)
@@ -2360,7 +2360,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
expected_loss = 1.2484322
# Assert loss.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
@@ -2385,7 +2385,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
logits=logits)
# Assert predictions, loss, and metrics.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
predictions = sess.run(spec.predictions)
self.assertAllClose(
@@ -2447,7 +2447,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())
# Assert predictions, loss, and metrics.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
@@ -2483,7 +2483,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
mode=model_fn.ModeKeys.TRAIN,
logits=logits,
labels=labels_rank_1)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_training_loss, training_loss.eval(),
@@ -2531,7 +2531,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
self.assertIsNotNone(spec.train_op)
# Assert predictions, loss, and metrics.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
loss, train_result, summary_str = sess.run((
@@ -2577,7 +2577,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
self.assertIsNotNone(spec.train_op)
# Assert predictions, loss, and metrics.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
loss, train_result, summary_str = sess.run((
@@ -2612,7 +2612,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
logits=logits,
labels=labels)
tol = 1e-2
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_training_loss, training_loss.eval(),
@@ -2649,7 +2649,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
# Assert predictions, loss, train_op, and summaries.
tol = 1e-2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
loss, train_result = sess.run((spec.loss, spec.train_op))
self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
@@ -2675,7 +2675,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
logits=logits,
labels=labels,
train_op_fn=_no_op_train_fn)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -2700,7 +2700,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
logits=logits,
labels=labels,
train_op_fn=_no_op_train_fn)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -2744,7 +2744,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
}
tol = 1e-2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
@@ -2825,7 +2825,7 @@ class RegressionHead(test.TestCase):
features={'x': np.array(((42.,),))},
mode=model_fn.ModeKeys.PREDICT,
logits=logits_placeholder)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors.OpError, 'logits shape'):
spec.predictions[prediction_keys.PredictionKeys.PREDICTIONS].eval({
logits_placeholder: logits_1d
@@ -2857,7 +2857,7 @@ class RegressionHead(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits_placeholder,
labels=labels_placeholder)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors.OpError, 'logits shape'):
spec.loss.eval({
labels_placeholder: values_3d,
@@ -2868,7 +2868,7 @@ class RegressionHead(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits_placeholder,
labels=labels_placeholder)[0]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
r'\[expected_labels_shape: \] \[2 3\] \[labels_shape: \] \[2 1\]'):
@@ -2908,7 +2908,7 @@ class RegressionHead(test.TestCase):
logits=logits_placeholder,
labels=labels_placeholder,
train_op_fn=lambda x: x)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors.OpError, 'logits shape'):
spec.loss.eval({
labels_placeholder: values_3d,
@@ -2919,7 +2919,7 @@ class RegressionHead(test.TestCase):
mode=model_fn.ModeKeys.TRAIN,
logits=logits_placeholder,
labels=labels_placeholder)[0]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
r'\[expected_labels_shape: \] \[2 3\] \[labels_shape: \] \[2 1\]'):
@@ -2957,7 +2957,7 @@ class RegressionHead(test.TestCase):
_assert_no_hooks(self, spec)
# Assert predictions.
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, spec.scaffold)
self.assertAllClose(logits, spec.predictions[prediction_key].eval())
self.assertAllClose(
@@ -2992,7 +2992,7 @@ class RegressionHead(test.TestCase):
spec.export_outputs.keys())
# Assert predictions.
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, spec.scaffold)
self.assertAllClose(
expected_predictions, spec.predictions[keys.PREDICTIONS].eval())
@@ -3019,7 +3019,7 @@ class RegressionHead(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
# loss = [(43-45)^2, (44-41)] = [4, 9]
self.assertAllClose(13., training_loss.eval())
@@ -3045,7 +3045,7 @@ class RegressionHead(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits_input,
labels=labels_input)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(np.sum(loss), actual_training_loss.eval())
@@ -3064,7 +3064,7 @@ class RegressionHead(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -3112,7 +3112,7 @@ class RegressionHead(test.TestCase):
_assert_no_hooks(self, spec)
# Assert predictions, loss, and metrics.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
loss_mean_value_op, loss_mean_update_op = spec.eval_metric_ops[
@@ -3180,7 +3180,7 @@ class RegressionHead(test.TestCase):
}
# Assert predictions, loss, and metrics.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
@@ -3212,7 +3212,7 @@ class RegressionHead(test.TestCase):
mode=model_fn.ModeKeys.TRAIN,
logits=logits,
labels=labels)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(expected_training_loss, training_loss.eval())
self.assertAllClose(expected_unreduced_loss, unreduced_loss.eval())
@@ -3237,7 +3237,7 @@ class RegressionHead(test.TestCase):
mode=model_fn.ModeKeys.TRAIN,
logits=logits,
labels=labels)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(expected_training_loss, training_loss.eval())
self.assertAllClose(expected_unreduced_loss, unreduced_loss.eval())
@@ -3294,7 +3294,7 @@ class RegressionHead(test.TestCase):
_assert_no_hooks(self, spec)
# Assert predictions, loss, train_op, and summaries.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
predictions, loss, train_result, summary_str = sess.run((
@@ -3337,7 +3337,7 @@ class RegressionHead(test.TestCase):
labels=labels,
optimizer=_Optimizer())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
loss, train_result = sess.run((spec.loss, spec.train_op))
self.assertAllClose(expected_loss, loss)
@@ -3364,7 +3364,7 @@ class RegressionHead(test.TestCase):
labels=np.array(((43.,), (44.,),), dtype=np.float64),
train_op_fn=_train_op_fn)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
sess.run(spec.train_op)
w_value, t_value = sess.run([w, t])
@@ -3394,7 +3394,7 @@ class RegressionHead(test.TestCase):
train_op_fn=_train_op_fn)
# Assert summaries.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
summary_str = sess.run(spec.scaffold.summary_op)
@@ -3441,7 +3441,7 @@ class RegressionHead(test.TestCase):
regularization_losses=regularization_losses)
# Assert predictions, loss, train_op, and summaries.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
prediction_key = prediction_keys.PredictionKeys.PREDICTIONS
@@ -3487,7 +3487,7 @@ class RegressionHead(test.TestCase):
_assert_no_hooks(self, spec)
# Assert predictions, loss, and metrics.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
loss_mean_value_op, loss_mean_update_op = spec.eval_metric_ops[
@@ -3523,7 +3523,7 @@ class RegressionHead(test.TestCase):
labels=np.array(((35,), (42,), (45,)), dtype=np.int32))
# Assert loss.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
loss = sess.run(spec.loss)
# loss = 1*(35-45)^2 + .1*(42-41)^2 + 1.5*(45-44)^2 = 100+.1+1.5 = 101.6
@@ -3565,7 +3565,7 @@ class RegressionHead(test.TestCase):
_assert_no_hooks(self, spec)
# Assert predictions, loss, train_op, and summaries.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
predictions, loss, train_result, summary_str = sess.run((
@@ -3600,7 +3600,7 @@ class RegressionHead(test.TestCase):
mode=model_fn.ModeKeys.TRAIN,
logits=logits,
labels=labels_rank_1)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(expected_training_loss, training_loss.eval())
self.assertAllClose(expected_unreduced_loss, unreduced_loss.eval())
@@ -3648,7 +3648,7 @@ class RegressionHead(test.TestCase):
_assert_no_hooks(self, spec)
# Assert predictions, loss, train_op, and summaries.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
predictions, loss, train_result, summary_str = sess.run((
@@ -3679,7 +3679,7 @@ class RegressionHead(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
# loss = [(35-45)^2, (42-41)^2, (45-44)^2] = [100, 1, 1].
# weighted sum loss = 1 * 100 + .1 * 1 + 1.5 * 1 = 101.6
@@ -3718,7 +3718,7 @@ class RegressionHead(test.TestCase):
_assert_no_hooks(self, spec)
# Assert predictions, loss, and metrics.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
loss_mean_value_op, loss_mean_update_op = spec.eval_metric_ops[
@@ -3750,7 +3750,7 @@ class RegressionHead(test.TestCase):
mode=model_fn.ModeKeys.TRAIN,
logits=logits,
labels=labels)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
# loss = [(35-45)^2, (42-41)^2, (45-44)^2] = [100, 1, 1].
# weighted sum loss = 1 * 100 + .1 * 1 + 1.5 * 1 = 101.6
@@ -3796,7 +3796,7 @@ class RegressionHead(test.TestCase):
_assert_no_hooks(self, spec)
# Evaluate predictions, loss, train_op, and summaries.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
predictions, loss, train_result, summary_str = sess.run((
@@ -3857,7 +3857,7 @@ class RegressionHead(test.TestCase):
self.assertIsNone(spec.train_op)
_assert_no_hooks(self, spec)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Finalize graph and initialize variables.
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
@@ -3915,7 +3915,7 @@ class RegressionHead(test.TestCase):
self.assertEqual(dtypes.float32, spec.loss.dtype)
self.assertIsNotNone(spec.train_op)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Finalize graph and initialize variables.
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
@@ -3955,7 +3955,7 @@ class RegressionHead(test.TestCase):
mode=model_fn.ModeKeys.TRAIN,
logits=logits,
labels=labels)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(expected_training_loss, training_loss.eval())
self.assertAllClose(expected_unreduced_loss, unreduced_loss.eval())
@@ -3988,7 +3988,7 @@ class RegressionHead(test.TestCase):
logits=logits,
labels=labels,
train_op_fn=_train_op_fn)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(expected_loss, spec.loss.eval())
@@ -4013,7 +4013,7 @@ class RegressionHead(test.TestCase):
logits=logits,
labels=labels,
train_op_fn=_no_op_train_fn)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -4042,7 +4042,7 @@ class RegressionHead(test.TestCase):
logits=logits,
labels=labels,
train_op_fn=_no_op_train_fn)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
diff --git a/tensorflow/python/estimator/inputs/numpy_io_test.py b/tensorflow/python/estimator/inputs/numpy_io_test.py
index 4e7b00b307..632908415f 100644
--- a/tensorflow/python/estimator/inputs/numpy_io_test.py
+++ b/tensorflow/python/estimator/inputs/numpy_io_test.py
@@ -42,7 +42,7 @@ class NumpyIoTest(test.TestCase):
x = {'a': a, 'b': b}
y = np.arange(-32, -28)
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = numpy_io.numpy_input_fn(
x, y, batch_size=2, shuffle=False, num_epochs=1)
features, target = input_fn()
@@ -68,7 +68,7 @@ class NumpyIoTest(test.TestCase):
x = {'a': a, 'b': b}
y = np.arange(-32, -30)
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = numpy_io.numpy_input_fn(
x, y, batch_size=128, shuffle=False, num_epochs=2)
features, target = input_fn()
@@ -93,7 +93,7 @@ class NumpyIoTest(test.TestCase):
x = {'a': a, 'b': b}
y = np.arange(-32, -28)
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = numpy_io.numpy_input_fn(
x, y, batch_size=2, shuffle=False, num_epochs=0)
features, target = input_fn()
@@ -114,7 +114,7 @@ class NumpyIoTest(test.TestCase):
x = {'a': a, 'b': b}
y = np.arange(-32, -27)
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = numpy_io.numpy_input_fn(
x, y, batch_size=batch_size, shuffle=False, num_epochs=1)
features, target = input_fn()
@@ -150,7 +150,7 @@ class NumpyIoTest(test.TestCase):
x = {'a': a, 'b': b}
y = np.arange(-32, -29)
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = numpy_io.numpy_input_fn(
x, y, batch_size=batch_size, shuffle=False, num_epochs=3)
features, target = input_fn()
@@ -196,7 +196,7 @@ class NumpyIoTest(test.TestCase):
x = {'a': a, 'b': b}
y = np.arange(-32, -28)
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = numpy_io.numpy_input_fn(
x, y, batch_size=batch_size, shuffle=False, num_epochs=1)
features, target = input_fn()
@@ -221,7 +221,7 @@ class NumpyIoTest(test.TestCase):
x = {'a': a, 'b': b}
y = np.arange(-32, -30)
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = numpy_io.numpy_input_fn(
x, y, batch_size=2, shuffle=False, num_epochs=1)
features, target = input_fn()
@@ -240,7 +240,7 @@ class NumpyIoTest(test.TestCase):
def testNumpyInputFnWithXAsNonDict(self):
x = list(range(32, 36))
y = np.arange(4)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, 'x must be a dict or array'):
failing_input_fn = numpy_io.numpy_input_fn(
x, y, batch_size=2, shuffle=False, num_epochs=1)
@@ -249,7 +249,7 @@ class NumpyIoTest(test.TestCase):
def testNumpyInputFnWithXIsEmptyDict(self):
x = {}
y = np.arange(4)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError, 'x cannot be an empty'):
failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False)
failing_input_fn()
@@ -257,7 +257,7 @@ class NumpyIoTest(test.TestCase):
def testNumpyInputFnWithXIsEmptyArray(self):
x = np.array([[], []])
y = np.arange(4)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError, 'x cannot be an empty'):
failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False)
failing_input_fn()
@@ -268,7 +268,7 @@ class NumpyIoTest(test.TestCase):
x = {'a': a, 'b': b}
y = None
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = numpy_io.numpy_input_fn(
x, y, batch_size=2, shuffle=False, num_epochs=1)
features_tensor = input_fn()
@@ -291,7 +291,7 @@ class NumpyIoTest(test.TestCase):
def testNumpyInputFnWithNonBoolShuffle(self):
x = np.arange(32, 36)
y = np.arange(4)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError,
'shuffle must be provided and explicitly '
'set as boolean'):
@@ -303,7 +303,7 @@ class NumpyIoTest(test.TestCase):
x = {'__target_key__': array}
y = np.arange(4)
- with self.test_session():
+ with self.cached_session():
input_fn = numpy_io.numpy_input_fn(
x, y, batch_size=2, shuffle=False, num_epochs=1)
input_fn()
@@ -318,7 +318,7 @@ class NumpyIoTest(test.TestCase):
x_mismatch_length = {'a': np.arange(1), 'b': b}
y_longer_length = np.arange(10)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
ValueError, 'Length of tensors in x and y is mismatched.'):
failing_input_fn = numpy_io.numpy_input_fn(
@@ -341,7 +341,7 @@ class NumpyIoTest(test.TestCase):
x = {'a': a, 'b': b}
y = {'y1': np.arange(-32, -28), 'y2': np.arange(32, 28, -1)}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = numpy_io.numpy_input_fn(
x, y, batch_size=2, shuffle=False, num_epochs=1)
features_tensor, targets_tensor = input_fn()
@@ -369,7 +369,7 @@ class NumpyIoTest(test.TestCase):
b = np.arange(32, 36)
x = {'a': a, 'b': b}
y = {}
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError, 'y cannot be empty'):
failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False)
failing_input_fn()
@@ -379,7 +379,7 @@ class NumpyIoTest(test.TestCase):
b = np.arange(32, 36)
x = {'a': a, 'b': b}
y = {'y1': np.arange(-32, -28), 'a': a, 'y2': np.arange(32, 28, -1), 'b': b}
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
ValueError, '2 duplicate keys are found in both x and y'):
failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False)
diff --git a/tensorflow/python/estimator/inputs/pandas_io_test.py b/tensorflow/python/estimator/inputs/pandas_io_test.py
index 6f13bc95d2..9e69fc72dc 100644
--- a/tensorflow/python/estimator/inputs/pandas_io_test.py
+++ b/tensorflow/python/estimator/inputs/pandas_io_test.py
@@ -102,7 +102,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_ProducesExpectedOutputs(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=2, shuffle=False, num_epochs=1)
@@ -116,7 +116,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFnWhenYIsDataFrame_ProducesExpectedOutput(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrameWithYAsDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=2, shuffle=False, num_epochs=1)
@@ -131,7 +131,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFnYIsDataFrame_HandlesOverlappingColumns(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrameWithYAsDataFrame()
y = y.rename(columns={'a_target': 'a', 'b_target': 'b'})
input_fn = pandas_io.pandas_input_fn(
@@ -147,7 +147,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFnYIsDataFrame_HandlesOverlappingColumnsInTargets(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrameWithYAsDataFrame()
y = y.rename(columns={'a_target': 'a', 'b_target': 'a_n'})
input_fn = pandas_io.pandas_input_fn(
@@ -163,7 +163,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_ProducesOutputsForLargeBatchAndMultipleEpochs(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
index = np.arange(100, 102)
a = np.arange(2)
b = np.arange(32, 34)
@@ -191,7 +191,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_ProducesOutputsWhenDataSizeNotDividedByBatchSize(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
index = np.arange(100, 105)
a = np.arange(5)
b = np.arange(32, 37)
@@ -230,7 +230,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_OnlyX(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, _ = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y=None, batch_size=2, shuffle=False, num_epochs=1)
@@ -243,7 +243,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_ExcludesIndex(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=2, shuffle=False, num_epochs=1)
@@ -266,7 +266,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_RespectsEpoch_NoShuffle(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=4, shuffle=False, num_epochs=1)
@@ -276,7 +276,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_RespectsEpoch_WithShuffle(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=4, shuffle=True, num_epochs=1)
@@ -286,7 +286,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_RespectsEpoch_WithShuffleAutosize(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=2, shuffle=True, queue_capacity=None, num_epochs=2)
@@ -297,7 +297,7 @@ class PandasIoTest(test.TestCase):
if not HAS_PANDAS:
return
x, y = self.makeTestDataFrame()
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=3, shuffle=False, num_epochs=1)
diff --git a/tensorflow/python/estimator/keras_test.py b/tensorflow/python/estimator/keras_test.py
index 7e5a0c80a7..3758243d7b 100644
--- a/tensorflow/python/estimator/keras_test.py
+++ b/tensorflow/python/estimator/keras_test.py
@@ -102,6 +102,49 @@ def gen_input_fn(x, y=None, batch_size=128, num_epochs=1, shuffle=False):
return input_fn
+def get_multi_inputs_multi_outputs_data():
+ (a_train, c_train), (a_test, c_test) = testing_utils.get_test_data(
+ train_samples=_TRAIN_SIZE,
+ test_samples=50,
+ input_shape=(16,),
+ num_classes=3,
+ random_seed=_RANDOM_SEED)
+ (b_train, d_train), (b_test, d_test) = testing_utils.get_test_data(
+ train_samples=_TRAIN_SIZE,
+ test_samples=50,
+ input_shape=(16,),
+ num_classes=2,
+ random_seed=_RANDOM_SEED)
+ (m_train, _), (m_test, _) = testing_utils.get_test_data(
+ train_samples=_TRAIN_SIZE,
+ test_samples=50,
+ input_shape=(8,),
+ num_classes=2,
+ random_seed=_RANDOM_SEED)
+
+ c_train = keras.utils.to_categorical(c_train)
+ c_test = keras.utils.to_categorical(c_test)
+ d_train = keras.utils.to_categorical(d_train)
+ d_test = keras.utils.to_categorical(d_test)
+
+ train_data = {
+ 'input_a': a_train,
+ 'input_b': b_train,
+ 'input_m': m_train,
+ 'output_c': c_train,
+ 'output_d': d_train
+ }
+ test_data = {
+ 'input_a': a_test,
+ 'input_b': b_test,
+ 'input_m': m_test,
+ 'output_c': c_test,
+ 'output_d': d_test
+ }
+
+ return (train_data, test_data)
+
+
def get_resource_for_simple_model(model_type='sequential',
is_evaluate=False,):
if model_type == 'sequential':
@@ -159,20 +202,21 @@ def randomize_io_type(array, name):
def multi_inputs_multi_outputs_model():
- a = keras.layers.Input(shape=(16,), name='input_a')
- b = keras.layers.Input(shape=(16,), name='input_b')
- m = keras.layers.Input(shape=(8,), dtype='string', name='input_m')
+ input_a = keras.layers.Input(shape=(16,), name='input_a')
+ input_b = keras.layers.Input(shape=(16,), name='input_b')
+ input_m = keras.layers.Input(shape=(8,), dtype='string', name='input_m')
dense = keras.layers.Dense(8, name='dense_1')
- a_2 = dense(a)
+ interm_a = dense(input_a)
# Read m
- m_2 = keras.layers.Lambda(gen_parsing_ops.string_to_number)(m)
- s_2 = keras.layers.Lambda(lambda k: k[0] * k[1])([m_2, a_2])
- b_2 = dense(b)
- merged = keras.layers.concatenate([s_2, b_2], name='merge')
- c = keras.layers.Dense(3, activation='softmax', name='dense_2')(merged)
- d = keras.layers.Dense(2, activation='softmax', name='dense_3')(merged)
- model = keras.models.Model(inputs=[a, b, m], outputs=[c, d])
+ interm_m = keras.layers.Lambda(gen_parsing_ops.string_to_number)(input_m)
+ interm_s = keras.layers.Lambda(lambda k: k[0] * k[1])([interm_m, interm_a])
+ interm_b = dense(input_b)
+ merged = keras.layers.concatenate([interm_s, interm_b], name='merge')
+ output_c = keras.layers.Dense(3, activation='softmax', name='dense_2')(merged)
+ output_d = keras.layers.Dense(2, activation='softmax', name='dense_3')(merged)
+ model = keras.models.Model(
+ inputs=[input_a, input_b, input_m], outputs=[output_c, output_d])
model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
@@ -414,51 +458,85 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
]
self.assertAllEqual(est_pred, keras_pred)
- def test_multi_inputs_multi_outputs(self):
- np.random.seed(_RANDOM_SEED)
- (a_train, c_train), (a_test, c_test) = testing_utils.get_test_data(
- train_samples=_TRAIN_SIZE,
- test_samples=50,
- input_shape=(16,),
- num_classes=3)
- np.random.seed(_RANDOM_SEED)
- (b_train, d_train), (b_test, d_test) = testing_utils.get_test_data(
- train_samples=_TRAIN_SIZE,
- test_samples=50,
- input_shape=(16,),
- num_classes=2)
- np.random.seed(_RANDOM_SEED)
- (input_m_train, _), (input_m_test, _) = testing_utils.get_test_data(
- train_samples=_TRAIN_SIZE,
- test_samples=50,
- input_shape=(8,),
- num_classes=2)
-
- c_train = keras.utils.to_categorical(c_train)
- c_test = keras.utils.to_categorical(c_test)
- d_train = keras.utils.to_categorical(d_train)
- d_test = keras.utils.to_categorical(d_test)
+ def test_multi_inputs_multi_outputs_with_input_fn_as_dict(self):
+ train_data, test_data = get_multi_inputs_multi_outputs_data()
def train_input_fn():
- input_dict = {'input_a': a_train, 'input_b': b_train,
- 'input_m': input_m_train.astype(np.str)}
- output_dict = {'dense_2': c_train, 'dense_3': d_train}
+ input_dict = {
+ 'input_a': train_data['input_a'],
+ 'input_b': train_data['input_b'],
+ 'input_m': train_data['input_m'].astype(np.str)
+ }
+ output_dict = {
+ 'dense_2': train_data['output_c'],
+ 'dense_3': train_data['output_d']
+ }
return input_dict, output_dict
def eval_input_fn():
- input_dict = {'input_a': a_test, 'input_b': b_test,
- 'input_m': input_m_test.astype(np.str)}
- output_dict = {'dense_2': c_test, 'dense_3': d_test}
+ input_dict = {
+ 'input_a': test_data['input_a'],
+ 'input_b': test_data['input_b'],
+ 'input_m': test_data['input_m'].astype(np.str)
+ }
+ output_dict = {
+ 'dense_2': test_data['output_c'],
+ 'dense_3': test_data['output_d']
+ }
return input_dict, output_dict
+ def pred_input_fn():
+ input_dict = {
+ 'input_a': test_data['input_a'],
+ 'input_b': test_data['input_b'],
+ 'input_m': test_data['input_m'].astype(np.str)
+ }
+ return input_dict
+
+ self.do_test_multi_inputs_multi_outputs_with_input_fn(
+ train_input_fn, eval_input_fn, pred_input_fn)
+
+ def test_multi_inputs_multi_outputs_with_input_fn_as_list(self):
+ train_data, test_data = get_multi_inputs_multi_outputs_data()
+
+ def train_input_fn():
+ input_list = [
+ train_data['input_a'], train_data['input_b'],
+ train_data['input_m'].astype(np.str)
+ ]
+ output_list = [train_data['output_c'], train_data['output_d']]
+ return input_list, output_list
+
+ def eval_input_fn():
+ input_list = [
+ test_data['input_a'], test_data['input_b'],
+ test_data['input_m'].astype(np.str)
+ ]
+ output_list = [test_data['output_c'], test_data['output_d']]
+ return input_list, output_list
+
+ def pred_input_fn():
+ input_list = [
+ test_data['input_a'], test_data['input_b'],
+ test_data['input_m'].astype(np.str)
+ ]
+ return input_list
+
+ self.do_test_multi_inputs_multi_outputs_with_input_fn(
+ train_input_fn, eval_input_fn, pred_input_fn)
+
+ def do_test_multi_inputs_multi_outputs_with_input_fn(
+ self, train_input_fn, eval_input_fn, pred_input_fn):
with self.cached_session():
model = multi_inputs_multi_outputs_model()
est_keras = keras_lib.model_to_estimator(
keras_model=model, config=self._config)
- before_eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)
+ baseline_eval_results = est_keras.evaluate(
+ input_fn=eval_input_fn, steps=1)
est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
- after_eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)
- self.assertLess(after_eval_results['loss'], before_eval_results['loss'])
+ eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)
+ self.assertLess(eval_results['loss'], baseline_eval_results['loss'])
+ est_keras.predict(input_fn=pred_input_fn)
def test_init_from_file(self):
if h5py is None:
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index 2246d2f3e9..9984379e9d 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -169,7 +169,8 @@ def _internal_input_layer(features,
weight_collections=None,
trainable=True,
cols_to_vars=None,
- scope=None):
+ scope=None,
+ cols_to_output_tensors=None):
"""See input_layer. `scope` is a name or variable scope to use."""
feature_columns = _normalize_feature_columns(feature_columns)
@@ -202,14 +203,17 @@ def _internal_input_layer(features,
trainable=trainable)
num_elements = column._variable_shape.num_elements() # pylint: disable=protected-access
batch_size = array_ops.shape(tensor)[0]
- output_tensors.append(
- array_ops.reshape(tensor, shape=(batch_size, num_elements)))
+ output_tensor = array_ops.reshape(
+ tensor, shape=(batch_size, num_elements))
+ output_tensors.append(output_tensor)
if cols_to_vars is not None:
# Retrieve any variables created (some _DenseColumn's don't create
# variables, in which case an empty list is returned).
cols_to_vars[column] = ops.get_collection(
ops.GraphKeys.GLOBAL_VARIABLES,
scope=variable_scope.get_variable_scope().name)
+ if cols_to_output_tensors is not None:
+ cols_to_output_tensors[column] = output_tensor
_verify_static_batch_size_equality(output_tensors, ordered_columns)
return array_ops.concat(output_tensors, 1)
@@ -219,7 +223,8 @@ def input_layer(features,
feature_columns,
weight_collections=None,
trainable=True,
- cols_to_vars=None):
+ cols_to_vars=None,
+ cols_to_output_tensors=None):
"""Returns a dense `Tensor` as input layer based on given `feature_columns`.
Generally a single example in training data is described with FeatureColumns.
@@ -264,6 +269,9 @@ def input_layer(features,
dimension=10): [<tf.Variable 'some_variable:0' shape=(5, 10),
<tf.Variable 'some_variable:1' shape=(5, 10)]}
If a column creates no variables, its value will be an empty list.
+ cols_to_output_tensors: If not `None`, must be a dictionary that will be
+ filled with a mapping from `_FeatureColumn` to the associated
+ output `Tensor`s.
Returns:
A `Tensor` which represents input layer of a model. Its shape
@@ -273,8 +281,13 @@ def input_layer(features,
Raises:
ValueError: if an item in `feature_columns` is not a `_DenseColumn`.
"""
- return _internal_input_layer(features, feature_columns, weight_collections,
- trainable, cols_to_vars)
+ return _internal_input_layer(
+ features,
+ feature_columns,
+ weight_collections=weight_collections,
+ trainable=trainable,
+ cols_to_vars=cols_to_vars,
+ cols_to_output_tensors=cols_to_output_tensors)
# TODO(akshayka): InputLayer should be a subclass of Layer, and it
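
The new `cols_to_output_tensors` argument above lets callers recover the per-column tensors that `input_layer` reshapes and concatenates into its return value. A usage sketch, assuming a TF 1.x build that already contains this change; the feature names are made up:

import tensorflow as tf

price = tf.feature_column.numeric_column('price')
size = tf.feature_column.numeric_column('size')
features = {'price': [[1.0], [2.0]], 'size': [[3.0], [4.0]]}

cols_to_output_tensors = {}
net = tf.feature_column.input_layer(
    features, [price, size],
    cols_to_output_tensors=cols_to_output_tensors)
# The dict now maps each column to the reshaped tensor that was concatenated
# into `net`, e.g. cols_to_output_tensors[price] has shape (batch_size, 1).
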
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index 9b482237ab..abb79efa68 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -1637,6 +1637,40 @@ class LinearModelTest(test.TestCase):
self.assertAllEqual([[0.], [0.]], cols_to_vars[price2][0].eval())
self.assertAllEqual([[0.]], cols_to_vars[price2][1].eval())
+ def test_fills_cols_to_output_tensors(self):
+ # Provide three _DenseColumns to input_layer: a _NumericColumn, a
+ # _BucketizedColumn, and an _EmbeddingColumn. Only the _EmbeddingColumn
+ # creates a Variable.
+ apple_numeric_column = fc.numeric_column('apple_numeric_column')
+ banana_dense_feature = fc.numeric_column('banana_dense_feature')
+ banana_dense_feature_bucketized = fc.bucketized_column(
+ banana_dense_feature, boundaries=[0.])
+ cherry_sparse_column = fc.categorical_column_with_hash_bucket(
+ 'cherry_sparse_feature', hash_bucket_size=5)
+ dragonfruit_embedding_column = fc.embedding_column(
+ cherry_sparse_column, dimension=10)
+ with ops.Graph().as_default():
+ features = {
+ 'apple_numeric_column': [[3.], [4.]],
+ 'banana_dense_feature': [[-1.], [4.]],
+ 'cherry_sparse_feature': [['a'], ['x']],
+ }
+ cols_to_output_tensors = {}
+ all_cols = [
+ apple_numeric_column, banana_dense_feature_bucketized,
+ dragonfruit_embedding_column
+ ]
+ input_layer = fc.input_layer(
+ features, all_cols, cols_to_output_tensors=cols_to_output_tensors)
+
+ # We check the mapping by checking that we have the right keys,
+ # and that the values (output_tensors) were indeed the ones used to
+ # form the input layer.
+ self.assertItemsEqual(all_cols, cols_to_output_tensors.keys())
+ input_layer_inputs = [tensor for tensor in input_layer.op.inputs[:-1]]
+ output_tensors = [tensor for tensor in cols_to_output_tensors.values()]
+ self.assertItemsEqual(input_layer_inputs, output_tensors)
+
def test_dense_collection(self):
price = fc.numeric_column('price')
with ops.Graph().as_default() as g:
diff --git a/tensorflow/python/framework/error_interpolation.py b/tensorflow/python/framework/error_interpolation.py
index 46bda2e621..bc3c81b2a2 100644
--- a/tensorflow/python/framework/error_interpolation.py
+++ b/tensorflow/python/framework/error_interpolation.py
@@ -34,7 +34,7 @@ from tensorflow.python.util import tf_stack
_NAME_REGEX = r"[A-Za-z0-9.][A-Za-z0-9_.\-/]*?"
_TAG_REGEX = r"{{{{({name}) ({name})}}}}".format(name=_NAME_REGEX)
_INTERPOLATION_REGEX = r"^(.*?)({tag})".format(tag=_TAG_REGEX)
-_INTERPOLATION_PATTERN = re.compile(_INTERPOLATION_REGEX)
+_INTERPOLATION_PATTERN = re.compile(_INTERPOLATION_REGEX, re.DOTALL)
_ParseTag = collections.namedtuple("_ParseTag", ["type", "name"])
diff --git a/tensorflow/python/framework/error_interpolation_test.py b/tensorflow/python/framework/error_interpolation_test.py
index d312b825d2..1b77548592 100644
--- a/tensorflow/python/framework/error_interpolation_test.py
+++ b/tensorflow/python/framework/error_interpolation_test.py
@@ -184,9 +184,14 @@ class InterpolateFilenamesAndLineNumbersTest(test.TestCase):
interpolated_string = error_interpolation.interpolate(
two_tags_with_seps, self.graph)
expected_regex = (
- r"^;;;.*constant_op.py:[0-9]+\) ,,,.*constant_op.py:[0-9]*\) ;;;$")
+ r"^;;;.*constant_op.py:[0-9]+\) ,,,.*constant_op.py:[0-9]+\) ;;;$")
self.assertRegexpMatches(interpolated_string, expected_regex)
+ def testNewLine(self):
+ newline = "\n\n{{node One}}"
+ interpolated_string = error_interpolation.interpolate(newline, self.graph)
+ self.assertRegexpMatches(interpolated_string, "constant_op.py:[0-9]+.*")
+
class InterpolateDeviceSummaryTest(test.TestCase):
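
The `re.DOTALL` flag added above matters because `_INTERPOLATION_REGEX` starts with `^(.*?)`: without the flag, `.` does not match a newline, so a tag preceded by blank lines (the case `testNewLine` covers) is never found. A simplified stand-in for the real pattern:

import re

# Simplified stand-in; the real _TAG_REGEX also captures a type and a name.
tag = r"\{\{node (\w+)\}\}"
plain = re.compile(r"^(.*?)(" + tag + r")")
dotall = re.compile(r"^(.*?)(" + tag + r")", re.DOTALL)

message = "\n\n{{node One}}"
print(plain.match(message))   # None: '.*?' cannot cross the leading newlines.
print(dotall.match(message))  # Matches and captures the '{{node One}}' tag.
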
diff --git a/tensorflow/python/framework/file_system_test.py b/tensorflow/python/framework/file_system_test.py
index 5eb59141a2..6901715e5d 100644
--- a/tensorflow/python/framework/file_system_test.py
+++ b/tensorflow/python/framework/file_system_test.py
@@ -37,7 +37,7 @@ class FileSystemTest(test.TestCase):
load_library.load_file_system_library(file_system_library)
def testBasic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.WholeFileReader("test_reader")
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
queue.enqueue_many([["test://foo"]]).run()
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index ee723bacaf..903768a039 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -419,7 +419,7 @@ class FunctionTest(test.TestCase):
with ops.control_dependencies([z]):
return x * 2
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
z = Foo(constant_op.constant(3.0))
self.assertAllEqual(z.eval(), 6.0)
@@ -434,7 +434,7 @@ class FunctionTest(test.TestCase):
# Foo contains a stateful op (Assert).
self.assertEqual([("Assert", "Assert")], Foo.stateful_ops)
g = ops.Graph()
- with g.as_default(), self.test_session():
+ with g.as_default(), self.cached_session():
self.assertAllEqual(Foo(constant_op.constant(3.0)).eval(), 6.0)
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"assertion failed.*-3"):
@@ -448,7 +448,7 @@ class FunctionTest(test.TestCase):
[control_flow_ops.Assert(math_ops.less_equal(x, 10.0), [x])]):
return array_ops.identity(x)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(1.0, MyFn(1.0).eval())
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"assertion"):
@@ -667,7 +667,7 @@ class FunctionTest(test.TestCase):
with ops.Graph().as_default():
z = CubeXPlusY(3.0, -2.0)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(z.eval(), 25.0)
def testNestedDefinedFunction(self):
@@ -683,7 +683,7 @@ class FunctionTest(test.TestCase):
with ops.Graph().as_default():
z = CubeXPlusY(3.0, -2.0)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(z.eval(), 25.0)
def testUnusedFunction(self):
diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py
index 18e7d8aa14..2b4d8e7299 100644
--- a/tensorflow/python/framework/importer_test.py
+++ b/tensorflow/python/framework/importer_test.py
@@ -396,7 +396,7 @@ class ImportGraphDefTest(test.TestCase):
# Run the imported graph.
# TODO(b/76173421): make this work (currently DCHECKS)
- # with self.test_session() as sess:
+ # with self.cached_session() as sess:
# sess.run(imported_init)
# self.assertEqual(sess.run(imported_var), 1.0)
# self.assertEqual(sess.run(imported_assign), 2.0)
@@ -417,7 +417,7 @@ class ImportGraphDefTest(test.TestCase):
imported_r, = importer.import_graph_def(graph_def,
return_elements=[r.name])
self.assertEqual(imported_r.name, "import/" + r.name)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(sess.run(imported_r), 10)
def testImportWhileLoopInCond(self):
@@ -436,7 +436,7 @@ class ImportGraphDefTest(test.TestCase):
pred = array_ops.placeholder(dtypes.bool)
out = control_flow_ops.cond(pred, ImportFn,
lambda: constant_op.constant(1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(sess.run(out, {pred: True}), 10)
self.assertEqual(sess.run(out, {pred: False}), 1)
@@ -457,7 +457,7 @@ class ImportGraphDefTest(test.TestCase):
out = control_flow_ops.while_loop(
lambda i: i < 2, ImportFn, [0],
shape_invariants=[tensor_shape.TensorShape(None)])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(sess.run(out), 10)
def testTypeMismatchInGraphDef(self):
@@ -929,7 +929,7 @@ class ImportGraphDefTest(test.TestCase):
input_map={"a:0": constant_op.constant(5.0)},
name="",
return_elements=["id:0"])
- with self.test_session():
+ with self.cached_session():
self.assertEqual(5.0, t.eval())
def testInvalidInputForReturnOperations(self):
@@ -958,7 +958,7 @@ class ImportGraphDefTest(test.TestCase):
array_ops.stack([c, c], name="pack")
gdef = g.as_graph_def()
- with self.test_session():
+ with self.cached_session():
pack, = importer.import_graph_def(gdef, return_elements=["pack"])
self.assertAllEqual(pack.outputs[0].eval(), [5.0, 5.0])
@@ -1063,7 +1063,7 @@ class ImportGraphDefTest(test.TestCase):
self.assertEqual([10], biases_grad.get_shape())
def testLargeGraph(self):
- with self.test_session():
+ with self.cached_session():
# The default message byte limit is 64M. Ours is 2G with a warning at 512M.
# Adding a 130M entries float32 tensor should exceed the warning, but not
# the hard limit.
@@ -1254,7 +1254,7 @@ class ImportGraphDefTest(test.TestCase):
z = TestFunc()
- with self.test_session():
+ with self.cached_session():
z_val = z.eval()
self.assertEqual(z_val, -2.0)
@@ -1284,7 +1284,7 @@ class ImportGraphDefTest(test.TestCase):
z2 = importer.import_graph_def(gdef, return_elements=["z:0"],
input_map=input_map)[0]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
z1_val, z2_val = sess.run((z1, z2))
self.assertAllEqual(z1_val, z2_val)
diff --git a/tensorflow/python/framework/meta_graph_test.py b/tensorflow/python/framework/meta_graph_test.py
index 6e5f7aafac..fc98b91a01 100644
--- a/tensorflow/python/framework/meta_graph_test.py
+++ b/tensorflow/python/framework/meta_graph_test.py
@@ -117,7 +117,7 @@ class SimpleMetaGraphTest(test.TestCase):
self.assertEqual(new_output_value, output_value)
def testStrippedOpListNestedFunctions(self):
- with self.test_session():
+ with self.cached_session():
# Square two levels deep
@function.Defun(dtypes.int32)
def f0(x):
@@ -169,7 +169,7 @@ class SimpleMetaGraphTest(test.TestCase):
# and "Tout" maps to complex64. Since these attr values map to their
# defaults, they must be stripped unless stripping of default attrs is
# disabled.
- with self.test_session():
+ with self.cached_session():
real_num = constant_op.constant(1.0, dtype=dtypes.float32, name="real")
imag_num = constant_op.constant(2.0, dtype=dtypes.float32, name="imag")
math_ops.complex(real_num, imag_num, name="complex")
@@ -212,7 +212,8 @@ class SimpleMetaGraphTest(test.TestCase):
def testDefaultAttrStrippingNestedFunctions(self):
"""Verifies that default attributes are stripped from function node defs."""
- with self.test_session():
+ with self.cached_session():
+
@function.Defun(dtypes.float32, dtypes.float32)
def f0(i, j):
return math_ops.complex(i, j, name="double_nested_complex")
@@ -251,7 +252,7 @@ class SimpleMetaGraphTest(test.TestCase):
meta_info_def = meta_graph_pb2.MetaGraphDef.MetaInfoDef()
meta_info_def.stripped_op_list.op.add()
- with self.test_session():
+ with self.cached_session():
meta_graph_def = meta_graph.create_meta_graph_def(
meta_info_def=meta_info_def, graph_def=graph_def,
strip_default_attrs=True)
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 9401309c19..343f52fe8f 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -58,6 +58,7 @@ from tensorflow.python.util import decorator_utils
from tensorflow.python.util import deprecation
from tensorflow.python.util import function_utils
from tensorflow.python.util import lock_util
+from tensorflow.python.util import memory
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_stack
from tensorflow.python.util.deprecation import deprecated_args
@@ -5364,6 +5365,7 @@ def enable_eager_execution(config=None,
computational graph).
For example:
+
```python
tf.enable_eager_execution()
@@ -5823,23 +5825,11 @@ def dismantle_graph(graph):
graph: A `Graph` object to destroy. Neither it nor any of its ops are usable
after this function runs.
"""
- # pylint: disable=protected-access
- # OrderedDict, constructed on Graph creation, makes a simple reference loop
- # and hides it in an __attribute in some Python versions. We don't need to
- # throw an error if we can't find it, but if we do find it we can break the
- # loop to avoid creating work for the garbage collector.
- graph_operations = graph.get_operations()
- problematic_cycle = graph._functions.__dict__.get("_OrderedDict__root", None)
- # pylint: enable=protected-access
- if problematic_cycle:
- try:
- del problematic_cycle[0][:]
- except TypeError:
- # This is probably not one of the problematic Python versions. Continue
- # with the rest of our cleanup.
- pass
+ memory.dismantle_ordered_dict(graph._functions) # pylint: disable=protected-access
+
# Now clean up Operation<->Graph reference cycles by clearing all of the
# attributes for the Graph and its ops.
+ graph_operations = graph.get_operations()
for op in graph_operations:
op.__dict__ = {}
graph.__dict__ = {}
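
`memory.dismantle_ordered_dict` itself is not part of this diff; judging from the code it replaces above, a sketch of the cycle-breaking it is expected to perform could look like the following (the real helper lives in `tensorflow/python/util/memory.py`):

def dismantle_ordered_dict(ordered_dict):
  """Breaks the sentinel reference cycle inside an OrderedDict, if any."""
  # Some Python versions keep the OrderedDict's doubly-linked sentinel list in
  # an _OrderedDict__root attribute, which forms a reference loop. Breaking it
  # avoids creating work for the garbage collector.
  problematic_cycle = ordered_dict.__dict__.get("_OrderedDict__root", None)
  if problematic_cycle:
    try:
      del problematic_cycle[0][:]
    except TypeError:
      # Probably not one of the problematic Python versions; nothing to do.
      pass
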
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index ced0581402..d59adf3d48 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -58,12 +58,12 @@ ops._set_call_cpp_shape_fn(common_shapes.call_cpp_shape_fn)
class ResourceTest(test_util.TensorFlowTestCase):
def testBuildGraph(self):
- with self.test_session():
+ with self.cached_session():
pt = test_ops.stub_resource_handle_op(container="a", shared_name="b")
test_ops.resource_create_op(pt).run()
def testInitialize(self):
- with self.test_session():
+ with self.cached_session():
handle = test_ops.stub_resource_handle_op(container="a", shared_name="b")
resources.register_resource(
handle=handle,
@@ -100,35 +100,35 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
pass
def testAddShape(self):
- with self.test_session():
+ with self.cached_session():
a = array_ops.zeros([2, 3])
b = array_ops.ones([1, 3])
c = a + b
self.assertEqual([2, 3], c.shape)
def testUnknownDim(self):
- with self.test_session():
+ with self.cached_session():
a = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
b = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
c = a + b
self.assertEqual([2, None, 3], c.shape.as_list())
def testUnknownShape(self):
- with self.test_session():
+ with self.cached_session():
a = array_ops.placeholder(dtype=dtypes.float32, shape=None)
b = array_ops.ones([1, 3])
c = a + b
self.assertEqual(tensor_shape.unknown_shape(), c.shape)
def testScalarShape(self):
- with self.test_session():
+ with self.cached_session():
a = array_ops.placeholder(dtype=dtypes.float32, shape=[])
b = array_ops.ones([])
c = a + b
self.assertEqual(tensor_shape.scalar(), c.shape)
def testShapeFunctionError(self):
- with self.test_session():
+ with self.cached_session():
a = array_ops.ones([1, 2, 3])
b = array_ops.ones([4, 5, 6])
with self.assertRaisesRegexp(
@@ -141,7 +141,7 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
class IndexedSlicesTest(test_util.TensorFlowTestCase):
def testToTensor(self):
- with self.test_session():
+ with self.cached_session():
values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
indices = constant_op.constant([0, 2])
dense_shape = constant_op.constant([3, 2])
@@ -150,7 +150,7 @@ class IndexedSlicesTest(test_util.TensorFlowTestCase):
self.assertAllEqual(tensor.eval(), [[2, 3], [0, 0], [5, 7]])
def testNegation(self):
- with self.test_session():
+ with self.cached_session():
values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
indices = constant_op.constant([0, 2])
x = -ops.IndexedSlices(values, indices)
@@ -158,7 +158,7 @@ class IndexedSlicesTest(test_util.TensorFlowTestCase):
self.assertAllEqual(x.indices.eval(), [0, 2])
def testScalarMul(self):
- with self.test_session():
+ with self.cached_session():
values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
indices = constant_op.constant([0, 2])
x = math_ops.scalar_mul(-2, ops.IndexedSlices(values, indices))
@@ -307,14 +307,14 @@ class OperationTest(test_util.TensorFlowTestCase):
self.assertEqual(tensor_shape.unknown_shape(), op.get_shape())
def testConvertToTensorNestedArray(self):
- with self.test_session():
+ with self.cached_session():
values = [[2], [3], [5], [7]]
tensor = ops.convert_to_tensor(values)
self.assertAllEqual((4, 1), tensor.get_shape().as_list())
self.assertAllEqual(values, tensor.eval())
def testShapeTuple(self):
- with self.test_session():
+ with self.cached_session():
c = constant_op.constant(1)
self.assertEqual(c._shape_tuple(), ()) # pylint: disable=protected-access
@@ -328,14 +328,14 @@ class OperationTest(test_util.TensorFlowTestCase):
self.assertTrue(isinstance(converted, ops.EagerTensor))
def testConvertToTensorNestedTuple(self):
- with self.test_session():
+ with self.cached_session():
values = ((2,), (3,), (5,), (7,))
tensor = ops.convert_to_tensor(values)
self.assertAllEqual((4, 1), tensor.get_shape().as_list())
self.assertAllEqual(values, ops.convert_to_tensor(values).eval())
def testConvertToTensorNestedTensors(self):
- with self.test_session():
+ with self.cached_session():
values = ((2,), (3,), (5,), (7,))
tensor = ops.convert_to_tensor(
[constant_op.constant(row) for row in values])
@@ -347,25 +347,25 @@ class OperationTest(test_util.TensorFlowTestCase):
self.assertAllEqual(values, tensor.eval())
def testConvertToTensorNestedMix(self):
- with self.test_session():
+ with self.cached_session():
values = ([2], (3,), [constant_op.constant(5)], constant_op.constant([7]))
tensor = ops.convert_to_tensor(values)
self.assertAllEqual((4, 1), tensor.get_shape().as_list())
self.assertAllEqual(((2,), (3,), (5,), (7,)), tensor.eval())
def testConvertToTensorPreferred(self):
- with self.test_session():
+ with self.cached_session():
values = [2, 3, 5, 7]
tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.float32)
self.assertEqual(dtypes.float32, tensor.dtype)
- with self.test_session():
+ with self.cached_session():
# Convert empty tensor to anything.
values = []
tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64)
self.assertEqual(dtypes.int64, tensor.dtype)
- with self.test_session():
+ with self.cached_session():
# The preferred dtype is a type error and will convert to
# float32 instead.
values = [1.23]
@@ -941,7 +941,7 @@ class NameStackTest(test_util.TensorFlowTestCase):
self.assertEqual("bar_2", g.unique_name("bar"))
def testNameAndVariableScope(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with sess.graph.name_scope("l0"):
with variable_scope.variable_scope("l1"):
with sess.graph.name_scope("l1") as scope:
@@ -2164,7 +2164,7 @@ class InitScopeTest(test_util.TensorFlowTestCase):
g = ops.Graph()
with g.as_default():
- with self.test_session():
+ with self.cached_session():
# First ensure that graphs that are not building functions are
# not escaped.
function_with_variables("foo")
@@ -2416,11 +2416,11 @@ class AttrScopeTest(test_util.TensorFlowTestCase):
return (a, b)
def testNoLabel(self):
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual((None, None), self._get_test_attrs())
def testLabelMap(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a1 = self._get_test_attrs()
with sess.graph._attr_scope({
"_A": attr_value_pb2.AttrValue(s=compat.as_bytes("foo"))
@@ -2454,12 +2454,12 @@ ops.RegisterShape("KernelLabel")(common_shapes.scalar_shape)
class KernelLabelTest(test_util.TensorFlowTestCase):
def testNoLabel(self):
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(b"My label is: default",
test_ops.kernel_label().eval())
def testLabelMap(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_1 = test_ops.kernel_label()
# pylint: disable=protected-access
with sess.graph._kernel_label_map({"KernelLabel": "overload_1"}):
@@ -2900,7 +2900,7 @@ class NameScopeTest(test_util.TensorFlowTestCase):
class TracebackTest(test_util.TensorFlowTestCase):
def testTracebackWithStartLines(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a = constant_op.constant(2.0)
sess.run(
a,
diff --git a/tensorflow/python/framework/python_op_gen_internal.cc b/tensorflow/python/framework/python_op_gen_internal.cc
index f2270342b0..f6aef5bc50 100644
--- a/tensorflow/python/framework/python_op_gen_internal.cc
+++ b/tensorflow/python/framework/python_op_gen_internal.cc
@@ -15,18 +15,20 @@ limitations under the License.
#include "tensorflow/python/framework/python_op_gen_internal.h"
+#include <float.h>
#include <stdio.h>
+#include <iomanip>
#include <sstream>
#include <unordered_map>
#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_def.pb_text.h"
#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/op_def.pb_text.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/op_gen_lib.h"
-#include "tensorflow/core/framework/tensor.pb_text.h"
#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor.pb_text.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
@@ -435,7 +437,12 @@ string AttrValueToPython(const string& type, const AttrValue& value,
if (std::isnan(value.f()) || std::isinf(value.f())) {
return strings::StrCat("float('", value.f(), "')");
} else {
- return strings::StrCat(value.f());
+ // Use locale-independent conversion.
+ static_assert(FLT_DIG < 10, "FLT_DIG is too big");
+ std::ostringstream s;
+ s.imbue(std::locale::classic());
+ s << std::setprecision(FLT_DIG) << value.f();
+ return s.str();
}
} else if (type == "bool") {
return value.b() ? "True" : "False";
diff --git a/tensorflow/python/framework/sparse_tensor_test.py b/tensorflow/python/framework/sparse_tensor_test.py
index 2bcfbc17df..22423c4f58 100644
--- a/tensorflow/python/framework/sparse_tensor_test.py
+++ b/tensorflow/python/framework/sparse_tensor_test.py
@@ -45,7 +45,7 @@ class SparseTensorTest(test_util.TensorFlowTestCase):
self.assertEqual(sp.dense_shape.dtype, dtypes.int64)
self.assertEqual(sp.get_shape(), (4, 5))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
value = sp.eval()
self.assertAllEqual(indices, value.indices)
self.assertAllEqual(values, value.values)
@@ -81,14 +81,14 @@ class SparseTensorTest(test_util.TensorFlowTestCase):
class ConvertToTensorOrSparseTensorTest(test_util.TensorFlowTestCase):
def test_convert_dense(self):
- with self.test_session():
+ with self.cached_session():
value = [42, 43]
from_value = sparse_tensor.convert_to_tensor_or_sparse_tensor(
value)
self.assertAllEqual(value, from_value.eval())
def test_convert_sparse(self):
- with self.test_session():
+ with self.cached_session():
indices = [[0, 1], [1, 0]]
values = [42, 43]
shape = [2, 2]
diff --git a/tensorflow/python/framework/subscribe_test.py b/tensorflow/python/framework/subscribe_test.py
index d6de45fdc4..1d594e4078 100644
--- a/tensorflow/python/framework/subscribe_test.py
+++ b/tensorflow/python/framework/subscribe_test.py
@@ -65,7 +65,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
self.assertFalse(c0.op in d.op.control_inputs)
self.assertTrue(c.op in d.op.control_inputs)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
c_out = sess.run([c])
n_out = sess.run([n])
d_out = sess.run([d])
@@ -144,7 +144,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
b = subscribe.subscribe(b,
lambda t: script_ops.py_func(sub, [t], [t.dtype]))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
c_out = sess.run([c])
d_out = sess.run([d])
@@ -204,7 +204,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
self.assertIs(c_sub, c_sub3)
# Expect the three side effect graphs to have been evaluated.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([c_sub])
self.assertIn('graph1', shared)
self.assertIn('graph2', shared)
@@ -227,7 +227,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
v1, lambda t: script_ops.py_func(sub, [t], [t.dtype]))
self.assertTrue(subscribe._is_subscribed_identity(v1_sub))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Initialize the variables first.
sess.run([v1.initializer])
sess.run([v2.initializer])
@@ -272,7 +272,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
self.assertIs(tensor_array_sub, tensor_array.handle)
self.assertFalse(subscribe._is_subscribed_identity(tensor_array.handle))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([reader])
self.assertEqual(0, len(shared))
@@ -303,7 +303,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
subscribe.subscribe(sparse_add.op.outputs,
lambda t: script_ops.py_func(sub, [t], [t.dtype]))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([neg])
# All three ops have been processed.
@@ -374,7 +374,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
# Verify that sub(x1) and sub(branch) are not.
self.assertIsNot(context(subscriptions[0]), context(subscriptions[1]))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(cond)
self.assertEqual(3, len(results))
diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py
index 395cf43b3f..bdf759f220 100644
--- a/tensorflow/python/framework/tensor_util_test.py
+++ b/tensorflow/python/framework/tensor_util_test.py
@@ -768,7 +768,7 @@ class TensorUtilTest(test.TestCase):
def __array__(self, dtype=None):
return np.asarray(self.array, dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ma = MockArray(np.array([10, 20, 30]))
t = ops.convert_to_tensor(ma)
a = sess.run(t)
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 4bece9e25e..b7398238f5 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -69,6 +69,7 @@ from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
+from tensorflow.python.util import memory
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.protobuf import compare
@@ -413,15 +414,13 @@ def enable_cond_v2(fn):
The wrapped function
"""
- # pylint: disable=protected-access
def wrapper(*args, **kwargs):
- prev_value = control_flow_ops._ENABLE_COND_V2
- control_flow_ops._ENABLE_COND_V2 = True
+ prev_value = control_flow_ops.ENABLE_COND_V2
+ control_flow_ops.ENABLE_COND_V2 = True
try:
fn(*args, **kwargs)
finally:
- control_flow_ops._ENABLE_COND_V2 = prev_value
- # pylint: enable=protected-access
+ control_flow_ops.ENABLE_COND_V2 = prev_value
return wrapper
@@ -438,7 +437,7 @@ def with_cond_v2(cls):
Returns:
cls with new test methods added
"""
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return cls
for name, value in cls.__dict__.copy().items():
@@ -780,7 +779,7 @@ def run_in_graph_and_eager_modes(func=None,
def run_eagerly(self, **kwargs):
if not use_gpu:
- with ops.device("/cpu:0"):
+ with ops.device("/device:CPU:0"):
f(self, **kwargs)
else:
f(self, **kwargs)
@@ -1327,9 +1326,17 @@ class TensorFlowTestCase(googletest.TestCase):
def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
a = self._GetNdArray(a)
b = self._GetNdArray(b)
- self.assertEqual(
- a.shape, b.shape,
- "Shape mismatch: expected %s, got %s." % (a.shape, b.shape))
+ # When the array rank is small, print its contents. Numpy array printing is
+ # implemented using inefficient recursion so prints can cause tests to
+ # time out.
+ if a.shape != b.shape and (b.ndim <= 3 or b.size < 500):
+ shape_mismatch_msg = ("Shape mismatch: expected %s, got %s with contents "
+ "%s.") % (a.shape, b.shape, b)
+ else:
+ shape_mismatch_msg = "Shape mismatch: expected %s, got %s." % (a.shape,
+ b.shape)
+ self.assertEqual(a.shape, b.shape, shape_mismatch_msg)
+
if not np.allclose(a, b, rtol=rtol, atol=atol):
# Prints more details than np.testing.assert_allclose.
#
@@ -1655,7 +1662,7 @@ class TensorFlowTestCase(googletest.TestCase):
if any of the elements do not fall in the specified range.
"""
target = self._GetNdArray(target)
- if not (np.issubdtype(target.dtype, np.float) or
+ if not (np.issubdtype(target.dtype, np.floating) or
np.issubdtype(target.dtype, np.integer)):
raise AssertionError(
"The value of %s does not have an ordered numeric type, instead it "
@@ -1832,7 +1839,7 @@ class TensorFlowTestCase(googletest.TestCase):
elif use_gpu:
yield sess
else:
- with sess.graph.device("/cpu:0"):
+ with sess.graph.device("/device:CPU:0"):
yield sess
def _create_session(self, graph, config, force_gpu):
@@ -1847,12 +1854,18 @@ class TensorFlowTestCase(googletest.TestCase):
Returns:
A config_pb2.ConfigProto object.
"""
+ # TODO(b/114333779): Enforce allow_soft_placement=False when
+ # use_gpu=False. Currently many tests rely on the fact that any device
+ # will be used even when a specific device is supposed to be used.
+ allow_soft_placement = not force_gpu
if config is None:
config = config_pb2.ConfigProto()
- config.allow_soft_placement = not force_gpu
+ config.allow_soft_placement = allow_soft_placement
config.gpu_options.per_process_gpu_memory_fraction = 0.3
- elif force_gpu and config.allow_soft_placement:
- config = config_pb2.ConfigProto().CopyFrom(config)
+ elif not allow_soft_placement and config.allow_soft_placement:
+ config_copy = config_pb2.ConfigProto()
+ config_copy.CopyFrom(config)
+ config = config_copy
config.allow_soft_placement = False
# Don't perform optimizations for tests so we don't inadvertently run
# gpu ops on cpu
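
The copy above is split into two statements because protobuf's `CopyFrom` mutates its receiver and returns `None`, so the replaced one-liner left `config` set to `None`. A minimal sketch of the corrected pattern, assuming the standard protobuf Python runtime:

from tensorflow.core.protobuf import config_pb2

original = config_pb2.ConfigProto(allow_soft_placement=True)

broken = config_pb2.ConfigProto().CopyFrom(original)  # broken is None
config_copy = config_pb2.ConfigProto()
config_copy.CopyFrom(original)                        # proper deep copy
config_copy.allow_soft_placement = False              # original is untouched
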
@@ -2002,3 +2015,42 @@ def set_producer_version(graph, producer_version):
with graph.as_default():
importer.import_graph_def(graph_def)
assert graph.graph_def_versions.producer, producer_version
+
+
+def dismantle_func_graph(func_graph):
+ """Removes reference cycles in `func_graph` FuncGraph.
+
+ Helpful for making sure the garbage collector doesn't need to run when
+ the FuncGraph goes out of scope, e.g. in tests using defun with
+ @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True).
+
+ Args:
+ func_graph: A `FuncGraph` object to destroy. `func_graph` is unusable
+ after this function.
+ """
+ # TODO(b/115366440): Delete this method when a custom OrderedDict is added.
+ # Clearing captures using clear() leaves some cycles around.
+ while func_graph.captures:
+ func_graph.captures.popitem()
+ memory.dismantle_ordered_dict(func_graph.captures)
+ ops.dismantle_graph(func_graph)
+
+
+def dismantle_polymorphic_function(func):
+ """Removes reference cycles in PolymorphicFunction `func`.
+
+ Helpful for making sure the garbage collector doesn't need to run when
+ PolymorphicFunction goes out of scope, e.g. in tests using defun with
+ @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True).
+
+ Args:
+ func: A `PolymorphicFunction` object to destroy. `func` is unusable
+ after this function.
+ """
+ # TODO(b/115366440): Delete this method when a custom OrderedDict is added
+ cache = func._function_cache # pylint: disable=protected-access
+ for concrete_func in cache.values():
+ dismantle_func_graph(concrete_func.graph)
+ while cache:
+ cache.popitem()
+ memory.dismantle_ordered_dict(cache)
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 290e182a79..b521b1430d 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -337,11 +337,6 @@ py_test(
size = "large",
srcs = ["layers/convolutional_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "manual",
- "noasan", # times out b/63678675
- "notsan",
- ],
deps = [
":keras",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 7768caeaf0..529b07dc12 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -73,7 +73,16 @@ _SESSION = None
# This dictionary holds a mapping {graph: learning_phase}.
# A learning phase is a bool tensor used to run Keras models in
# either train mode (learning_phase == 1) or test mode (learning_phase == 0).
-_GRAPH_LEARNING_PHASES = {}
+_GRAPH_LEARNING_PHASES = weakref.WeakKeyDictionary()
+
+
+# _DUMMY_EAGER_GRAPH is used as a key in _GRAPH_LEARNING_PHASES.
+# We keep a separate reference to it to make sure it does not get removed from
+# _GRAPH_LEARNING_PHASES. We use a dummy class instead of something like a
+# string because strings are not weakly-referencable.
+class _DummyEagerGraph(object):
+ pass
+_DUMMY_EAGER_GRAPH = _DummyEagerGraph()
# This boolean flag can be set to True to leave variable initialization
# up to the user.
@@ -96,11 +105,11 @@ _LOCAL_DEVICES = None
# This dictionary holds a mapping between a graph and variables to initialize
# in the graph.
-_GRAPH_VARIABLES = {}
+_GRAPH_VARIABLES = weakref.WeakKeyDictionary()
# This dictionary holds a mapping between a graph and TF optimizers created in
# the graph.
-_GRAPH_TF_OPTIMIZERS = {}
+_GRAPH_TF_OPTIMIZERS = weakref.WeakKeyDictionary()
@tf_export('keras.backend.backend')
@@ -359,10 +368,10 @@ def learning_phase():
Learning phase (scalar integer tensor or Python integer).
"""
if context.executing_eagerly():
- if 'eager' not in _GRAPH_LEARNING_PHASES:
+ if _DUMMY_EAGER_GRAPH not in _GRAPH_LEARNING_PHASES:
# Fallback to inference mode as default.
return 0
- return _GRAPH_LEARNING_PHASES['eager']
+ return _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH]
graph = ops.get_default_graph()
if graph not in _GRAPH_LEARNING_PHASES:
@@ -386,7 +395,7 @@ def set_learning_phase(value):
if value not in {0, 1}:
raise ValueError('Expected learning phase to be 0 or 1.')
if context.executing_eagerly():
- _GRAPH_LEARNING_PHASES['eager'] = value
+ _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = value
else:
_GRAPH_LEARNING_PHASES[ops.get_default_graph()] = value
@@ -415,7 +424,7 @@ def learning_phase_scope(value):
finally:
# Restore learning phase to initial value.
if context.executing_eagerly():
- _GRAPH_LEARNING_PHASES['eager'] = previous_value
+ _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = previous_value
else:
_GRAPH_LEARNING_PHASES[ops.get_default_graph()] = previous_value
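
The `_DummyEagerGraph` sentinel above is needed because `_GRAPH_LEARNING_PHASES` is now a `WeakKeyDictionary` and the previous `'eager'` string key cannot live in one: built-in `str` instances do not support weak references. A small illustration:

import weakref


class _Sentinel(object):
  pass


phases = weakref.WeakKeyDictionary()
key = _Sentinel()          # keep a strong reference, as backend.py does
phases[key] = 1            # fine: ordinary objects are weak-referenceable
try:
  phases['eager'] = 1      # TypeError: cannot create weak reference to 'str'
except TypeError as err:
  print(err)
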
diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py
index 266af56611..2f271c4f50 100644
--- a/tensorflow/python/keras/backend_test.py
+++ b/tensorflow/python/keras/backend_test.py
@@ -279,7 +279,7 @@ class BackendUtilsTest(test.TestCase):
keras.backend.get_session().run(fetches=[x, y]), [30., 40.])
def test_function_tf_run_options_with_run_metadata(self):
- with self.test_session():
+ with self.cached_session():
x_placeholder = keras.backend.placeholder(shape=())
y_placeholder = keras.backend.placeholder(shape=())
diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py
index 7675a6586f..b6fae19823 100644
--- a/tensorflow/python/keras/callbacks_test.py
+++ b/tensorflow/python/keras/callbacks_test.py
@@ -63,7 +63,7 @@ class KerasCallbacksTest(test.TestCase):
if h5py is None:
return # Skip test if models cannot be saved.
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
temp_dir = self.get_temp_dir()
@@ -226,7 +226,7 @@ class KerasCallbacksTest(test.TestCase):
mode='unknown')
def test_EarlyStopping(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(123)
(x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
train_samples=TRAIN_SAMPLES,
@@ -265,7 +265,7 @@ class KerasCallbacksTest(test.TestCase):
verbose=0)
def test_EarlyStopping_reuse(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
patience = 3
data = np.random.random((100, 1))
@@ -287,7 +287,7 @@ class KerasCallbacksTest(test.TestCase):
assert len(hist.epoch) >= patience
def test_EarlyStopping_with_baseline(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
baseline = 0.5
(data, labels), _ = testing_utils.get_test_data(
@@ -321,7 +321,7 @@ class KerasCallbacksTest(test.TestCase):
monitor.on_epoch_end(0, logs={'loss': 0.})
def test_LearningRateScheduler(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
(x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
train_samples=TRAIN_SAMPLES,
@@ -368,7 +368,7 @@ class KerasCallbacksTest(test.TestCase):
model.optimizer.lr)) - 0.01 / 4) < keras.backend.epsilon()
def test_ReduceLROnPlateau(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
(x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
train_samples=TRAIN_SAMPLES,
@@ -470,7 +470,7 @@ class KerasCallbacksTest(test.TestCase):
self.assertEqual(reduce_on_plateau.min_delta, 1e-13)
def test_CSVLogger(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
temp_dir = self.get_temp_dir()
self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
@@ -549,7 +549,7 @@ class KerasCallbacksTest(test.TestCase):
tmpdir = self.get_temp_dir()
self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)
- with self.test_session():
+ with self.cached_session():
fp = os.path.join(tmpdir, 'test.csv')
(x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
train_samples=TRAIN_SAMPLES,
@@ -601,7 +601,7 @@ class KerasCallbacksTest(test.TestCase):
assert 'nan' in values[-1], 'The last epoch was not logged.'
def test_TerminateOnNaN(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
(x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
train_samples=TRAIN_SAMPLES,
@@ -666,7 +666,7 @@ class KerasCallbacksTest(test.TestCase):
i %= max_batch_index
# case: Sequential
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(
keras.layers.Dense(
@@ -743,7 +743,7 @@ class KerasCallbacksTest(test.TestCase):
tmpdir = self.get_temp_dir()
self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)
- with self.test_session():
+ with self.cached_session():
filepath = os.path.join(tmpdir, 'logs')
(x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
@@ -815,7 +815,7 @@ class KerasCallbacksTest(test.TestCase):
tmpdir = self.get_temp_dir()
self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)
- with self.test_session():
+ with self.cached_session():
filepath = os.path.join(tmpdir, 'logs')
(x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
@@ -925,7 +925,7 @@ class KerasCallbacksTest(test.TestCase):
y_test = keras.utils.to_categorical(y_test)
y_train = keras.utils.to_categorical(y_train)
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(
keras.layers.Dense(
@@ -969,7 +969,7 @@ class KerasCallbacksTest(test.TestCase):
while True:
yield x, y
- with self.test_session():
+ with self.cached_session():
model = testing_utils.get_small_sequential_mlp(
num_hidden=10, num_classes=10, input_dim=100)
model.compile(
@@ -1011,7 +1011,7 @@ class KerasCallbacksTest(test.TestCase):
os.name == 'nt',
'use_multiprocessing=True does not work on windows properly.')
def test_LambdaCallback(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
(x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
train_samples=TRAIN_SAMPLES,
@@ -1055,7 +1055,7 @@ class KerasCallbacksTest(test.TestCase):
assert not t.is_alive()
def test_TensorBoard_with_ReduceLROnPlateau(self):
- with self.test_session():
+ with self.cached_session():
temp_dir = self.get_temp_dir()
self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
@@ -1194,7 +1194,7 @@ class KerasCallbacksTest(test.TestCase):
def test_RemoteMonitorWithJsonPayload(self):
if requests is None:
self.skipTest('`requests` required to run this test')
- with self.test_session():
+ with self.cached_session():
(x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
train_samples=TRAIN_SAMPLES,
test_samples=TEST_SAMPLES,
diff --git a/tensorflow/python/keras/engine/distributed_training_utils.py b/tensorflow/python/keras/engine/distributed_training_utils.py
index c1c4970025..b28df75493 100644
--- a/tensorflow/python/keras/engine/distributed_training_utils.py
+++ b/tensorflow/python/keras/engine/distributed_training_utils.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.client import session as session_module
+from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks
@@ -212,7 +213,10 @@ def validate_distributed_dataset_inputs(distribution_strategy, x, y):
# validate the input and targets.
x_values_list = validate_per_device_inputs(distribution_strategy, x)
- y_values_list = validate_per_device_inputs(distribution_strategy, y)
+ if y is not None:
+ y_values_list = validate_per_device_inputs(distribution_strategy, y)
+ else:
+ y_values_list = None
# Return the unwrapped values to avoid calling `unwrap` a second time.
return x_values_list, y_values_list
@@ -287,3 +291,74 @@ def configure_and_create_session(distribution_strategy):
session = session_module.Session(config=session_config)
K.set_session(session)
+
+
+def validate_inputs(x, y):
+ """Validate inputs when using DistributionStrategy.
+
+ Args:
+ x: Model Inputs.
+ y: Model Targets.
+
+ Raises:
+ ValueError: if input is not a Dataset or a numpy array.
+ """
+ if isinstance(x, list) or isinstance(y, list):
+ raise ValueError('DistributionStrategy does not support lists of numpy '
+ 'arrays. You must pass a Dataset object or a numpy array '
+ 'as input.')
+
+ if isinstance(x, dict) or isinstance(y, dict):
+ raise ValueError('DistributionStrategy does not support inputs of type '
+ 'dict. You must pass a Dataset object or a numpy array as '
+ 'input.')
+
+ if isinstance(x, iterator_ops.Iterator) or \
+ isinstance(y, iterator_ops.Iterator):
+ raise ValueError('DistributionStrategy does not support inputs of type '
+ 'Iterator. You must pass a Dataset object or a numpy '
+ 'array as input.')
+
+
+def get_input_batch_params(first_x_value, batch_size, current_strategy):
+ """Calculate the number of batches and steps/steps_per_epoch.
+
+ Args:
+ first_x_value: This is the first input numpy array that is passed in as the
+ model input.
+ batch_size: The specified batch_size or the default batch_size of 32.
+ current_strategy: The current DistributionStrategy used to compile the
+ model.
+
+ Returns:
+ The steps or steps_per_epoch argument, depending on whether a user is
+ calling `fit`, `evaluate` or `predict`.
+
+ Raises:
+ ValueError: If the number of batches or steps evaluates to 0.
+
+ """
+ num_batches = first_x_value.shape[0] // batch_size
+ if not num_batches:
+ raise ValueError('Please specify a batch_size that is smaller than '
+ 'the number of input samples %d.' % first_x_value.shape[0])
+ # TODO(anjalisridhar): TPU currently supports using the num_towers property.
+ # We might want to look into implementing worker_devices. In multi worker
+ # strategy, perhaps num_towers works better?
+ steps = num_batches // current_strategy.num_towers
+ if not steps:
+ # TODO(anjalisridhar): Number of towers in the error message may not convey
+ # what we want to the user. Is there another terminology that we can use
+ # that is consistent across different strategies?
+ raise ValueError('The number of batches %d is smaller than the number '
+ 'of towers %d used for DistributionStrategy. ' %
+ (num_batches, current_strategy.num_towers))
+ return steps
+
+
+def get_batch_dimension(iterator):
+ shapes = nest.flatten(iterator.output_shapes)
+ # Take the batch size from the first element, as it should be the same for
+ # all.
+ dims = shapes[0].dims
+ return dims[0] if dims else None
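
The steps/steps_per_epoch computation added above reduces to two integer divisions; a standalone sketch of that arithmetic (the sample count, batch size, and tower count below are made-up illustration values, not part of the patch):

# Sketch of the calculation in get_input_batch_params(); values are hypothetical.
num_samples = 1000          # first_x_value.shape[0]
batch_size = 32             # the specified or default batch size
num_towers = 4              # current_strategy.num_towers

num_batches = num_samples // batch_size   # 31 full batches
steps = num_batches // num_towers         # 7 steps, each feeding all 4 towers
assert num_batches and steps, 'batch_size too large for the available samples/towers'
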
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index 10dd70cf23..5ef8d13487 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -1576,7 +1576,10 @@ class Network(base_layer.Layer):
def get_json_type(obj):
# If obj is any numpy type
if type(obj).__module__ == np.__name__:
- return obj.item()
+ if isinstance(obj, np.ndarray):
+ return obj.tolist()
+ else:
+ return obj.item()
# If obj is a python 'type'
if type(obj).__name__ == type.__name__:
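
The get_json_type change above makes numpy arrays serializable alongside numpy scalars when a model config is written to JSON; a minimal self-contained sketch of the intended behaviour (only the numpy branch is reproduced here):

import json
import numpy as np

def get_json_type(obj):
  # Mirrors the patched branch: arrays become lists, numpy scalars become
  # native Python values.
  if type(obj).__module__ == np.__name__:
    if isinstance(obj, np.ndarray):
      return obj.tolist()
    return obj.item()
  raise TypeError('Not JSON serializable: %s' % obj)

print(json.dumps({'kernel_init': np.ones((2, 2)), 'rate': np.float32(0.5)},
                 default=get_json_type))
# {"kernel_init": [[1.0, 1.0], [1.0, 1.0]], "rate": 0.5}
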
diff --git a/tensorflow/python/keras/engine/saving.py b/tensorflow/python/keras/engine/saving.py
index a2eed7cb46..a2f31fda8f 100644
--- a/tensorflow/python/keras/engine/saving.py
+++ b/tensorflow/python/keras/engine/saving.py
@@ -248,7 +248,7 @@ def load_model(filepath, custom_objects=None, compile=True): # pylint: disable=
loss = convert_custom_objects(training_config['loss'])
metrics = convert_custom_objects(training_config['metrics'])
weighted_metrics = convert_custom_objects(
- training_config['weighted_metrics'])
+ training_config.get('weighted_metrics', None))
sample_weight_mode = training_config['sample_weight_mode']
loss_weights = training_config['loss_weights']
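
Switching from indexing to .get() guards against HDF5 files saved before weighted_metrics existed in the training config; a small sketch of the difference (the config dict is hypothetical):

# Hypothetical training config deserialized from an older HDF5 file.
training_config = {'loss': 'mse', 'metrics': ['acc'],
                   'sample_weight_mode': None, 'loss_weights': None}

# training_config['weighted_metrics'] would raise KeyError for such a file;
# .get() falls back to None and loading proceeds.
weighted_metrics = training_config.get('weighted_metrics', None)
assert weighted_metrics is None
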
diff --git a/tensorflow/python/keras/engine/saving_test.py b/tensorflow/python/keras/engine/saving_test.py
index 441f3f4948..148dd23be7 100644
--- a/tensorflow/python/keras/engine/saving_test.py
+++ b/tensorflow/python/keras/engine/saving_test.py
@@ -48,7 +48,7 @@ except ImportError:
class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
def test_weight_loading(self):
- with self.test_session():
+ with self.cached_session():
a = keras.layers.Input(shape=(2,))
x = keras.layers.Dense(3)(a)
b = keras.layers.Dense(1)(x)
@@ -208,7 +208,7 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
}))
def test_preprocess_weights_for_loading_rnn_should_be_idempotent(
self, layer_class, layer_args):
- with self.test_session():
+ with self.cached_session():
layer = layer_class(**layer_args)
layer.build(input_shape=layer_args.get('input_shape'))
weights1 = layer.get_weights()
@@ -232,7 +232,7 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
batch_size = 5
num_classes = 2
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(num_hidden, input_dim=input_dim))
model.add(keras.layers.Dense(num_classes))
@@ -261,7 +261,7 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
num_hidden = 5
input_dim = 3
num_classes = 2
- with self.test_session():
+ with self.cached_session():
ref_model = keras.models.Sequential()
ref_model.add(keras.layers.Dense(num_hidden, input_dim=input_dim,
name='d1'))
@@ -298,7 +298,7 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
num_hidden = 5
input_dim = 3
num_classes = 2
- with self.test_session():
+ with self.cached_session():
ref_model = keras.models.Sequential()
ref_model.add(keras.layers.Dense(num_hidden, input_dim=input_dim,
name='d1'))
@@ -333,7 +333,7 @@ class TestWholeModelSaving(test.TestCase):
if h5py is None:
self.skipTest('h5py required to run this test')
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(2, input_shape=(3,)))
model.add(keras.layers.RepeatVector(3))
@@ -378,7 +378,7 @@ class TestWholeModelSaving(test.TestCase):
if h5py is None:
self.skipTest('h5py required to run this test')
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(2, input_shape=(3,)))
model.add(keras.layers.RepeatVector(3))
@@ -402,7 +402,7 @@ class TestWholeModelSaving(test.TestCase):
if h5py is None:
self.skipTest('h5py required to run this test')
- with self.test_session():
+ with self.cached_session():
# test with custom optimizer, loss
class CustomOp(keras.optimizers.RMSprop):
@@ -438,7 +438,7 @@ class TestWholeModelSaving(test.TestCase):
if h5py is None:
self.skipTest('h5py required to run this test')
- with self.test_session():
+ with self.cached_session():
inputs = keras.layers.Input(shape=(3,))
x = keras.layers.Dense(2)(inputs)
output = keras.layers.Dense(3)(x)
@@ -474,7 +474,7 @@ class TestWholeModelSaving(test.TestCase):
if h5py is None:
self.skipTest('h5py required to run this test')
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(2, input_shape=(3,)))
model.add(keras.layers.Dense(3))
@@ -490,7 +490,7 @@ class TestWholeModelSaving(test.TestCase):
if h5py is None:
self.skipTest('h5py required to run this test')
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(2, input_shape=(3,)))
model.add(keras.layers.Dense(3))
@@ -508,7 +508,7 @@ class TestWholeModelSaving(test.TestCase):
if h5py is None:
self.skipTest('h5py required to run this test')
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(2, input_shape=(3,)))
model.add(keras.layers.Dense(3))
@@ -522,7 +522,7 @@ class TestWholeModelSaving(test.TestCase):
os.remove(fname)
def test_saving_lambda_numpy_array_arguments(self):
- with self.test_session():
+ with self.cached_session():
if h5py is None:
self.skipTest('h5py required to run this test')
@@ -548,7 +548,7 @@ class TestWholeModelSaving(test.TestCase):
if h5py is None:
self.skipTest('h5py required to run this test')
- with self.test_session():
+ with self.cached_session():
# This layer name will make the `layers_name` HDF5 attribute blow
# out of proportion. Note that it fits into the internal HDF5
# attribute memory limit on its own but because h5py converts
@@ -589,7 +589,7 @@ class TestWholeModelSaving(test.TestCase):
if h5py is None:
self.skipTest('h5py required to run this test')
- with self.test_session():
+ with self.cached_session():
x = keras.Input(shape=(2,), name='nested_model_input')
f = x
for i in range(4):
@@ -634,7 +634,7 @@ class TestWholeModelSaving(test.TestCase):
if h5py is None:
self.skipTest('h5py required to run this test')
- with self.test_session():
+ with self.cached_session():
inputs = keras.Input(shape=(3,))
x = keras.layers.Dense(2)(inputs)
outputs = keras.layers.Dense(3)(x)
@@ -703,7 +703,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def test_tensorflow_format_overwrite(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
model = SubclassedModel()
temp_dir = self.get_temp_dir()
prefix = os.path.join(temp_dir, 'ckpt')
@@ -760,7 +760,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
self.assertEqual(len(graph.get_operations()), op_count)
def _weight_loading_test_template(self, make_model_fn):
- with self.test_session():
+ with self.cached_session():
model = make_model_fn()
model.compile(
loss='mse',
@@ -822,7 +822,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
def _new_layer_weight_loading_test_template(
self, first_model_fn, second_model_fn, restore_init_fn):
- with self.test_session() as session:
+ with self.cached_session() as session:
model = first_model_fn()
temp_dir = self.get_temp_dir()
prefix = os.path.join(temp_dir, 'ckpt')
diff --git a/tensorflow/python/keras/engine/sequential_test.py b/tensorflow/python/keras/engine/sequential_test.py
index 28af8d61bc..9d615c9b0c 100644
--- a/tensorflow/python/keras/engine/sequential_test.py
+++ b/tensorflow/python/keras/engine/sequential_test.py
@@ -132,7 +132,7 @@ class TestSequential(test.TestCase, parameterized.TestCase):
@parameterized.parameters((True,), (False,))
def test_training_and_eval_methods_on_symbolic_tensors(self, deferred):
- with self.test_session():
+ with self.cached_session():
def get_model():
if deferred:
@@ -222,7 +222,7 @@ class TestSequential(test.TestCase, parameterized.TestCase):
val_a = np.random.random((10, 4))
val_out = np.random.random((10, 4))
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.BatchNormalization(input_shape=(4,)))
assert model.updates
diff --git a/tensorflow/python/keras/engine/topology_test.py b/tensorflow/python/keras/engine/topology_test.py
index 079c8dae71..061db8ee34 100644
--- a/tensorflow/python/keras/engine/topology_test.py
+++ b/tensorflow/python/keras/engine/topology_test.py
@@ -342,7 +342,7 @@ class TopologyConstructionTest(test.TestCase):
self.assertListEqual(model.non_trainable_weights, weights)
def test_learning_phase(self):
- with self.test_session():
+ with self.cached_session():
a = keras.layers.Input(shape=(32,), name='input_a')
b = keras.layers.Input(shape=(32,), name='input_b')
@@ -458,7 +458,7 @@ class TopologyConstructionTest(test.TestCase):
self.assertEqual(dense.get_output_mask_at(1), None)
def test_multi_input_layer(self):
- with self.test_session():
+ with self.cached_session():
# test multi-input layer
a = keras.layers.Input(shape=(32,), name='input_a')
b = keras.layers.Input(shape=(32,), name='input_b')
@@ -530,7 +530,7 @@ class TopologyConstructionTest(test.TestCase):
self.assertListEqual([x.shape for x in fn_outputs], [(10, 64), (10, 5)])
def test_recursion(self):
- with self.test_session():
+ with self.cached_session():
a = keras.layers.Input(shape=(32,), name='input_a')
b = keras.layers.Input(shape=(32,), name='input_b')
@@ -591,7 +591,7 @@ class TopologyConstructionTest(test.TestCase):
self.assertListEqual([x.shape for x in fn_outputs], [(10, 7), (10, 64)])
def test_multi_input_multi_output_recursion(self):
- with self.test_session():
+ with self.cached_session():
# test multi-input multi-output
a = keras.layers.Input(shape=(32,), name='input_a')
b = keras.layers.Input(shape=(32,), name='input_b')
@@ -816,7 +816,7 @@ class TopologyConstructionTest(test.TestCase):
self.assertEqual(loss, 4.)
def test_layer_sharing_at_heterogenous_depth(self):
- with self.test_session():
+ with self.cached_session():
x_val = np.random.random((10, 5))
x = input_layer_lib.Input(shape=(5,))
@@ -837,7 +837,7 @@ class TopologyConstructionTest(test.TestCase):
self.assertAllClose(output_val, output_val_2, atol=1e-6)
def test_layer_sharing_at_heterogenous_depth_with_concat(self):
- with self.test_session():
+ with self.cached_session():
input_shape = (16, 9, 3)
input_layer = input_layer_lib.Input(shape=input_shape)
@@ -864,7 +864,7 @@ class TopologyConstructionTest(test.TestCase):
self.assertAllClose(output_val, output_val_2, atol=1e-6)
def test_explicit_training_argument(self):
- with self.test_session():
+ with self.cached_session():
a = keras.layers.Input(shape=(2,))
b = keras.layers.Dropout(0.5)(a)
base_model = keras.models.Model(a, b)
@@ -887,7 +887,8 @@ class TopologyConstructionTest(test.TestCase):
def test_multi_output_model_with_none_masking(self):
- with self.test_session():
+ with self.cached_session():
+
def func(x):
return [x * 0.2, x * 0.3]
@@ -912,6 +913,23 @@ class TopologyConstructionTest(test.TestCase):
assert out.shape == (4, 3, 2, 1)
self.assertAllClose(out, x * 0.2 + x * 0.3, atol=1e-4)
+ def test_constant_initializer_with_numpy(self):
+
+ with self.test_session():
+ initializer = keras.initializers.Constant(np.ones((3, 2)))
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,),
+ kernel_initializer=initializer))
+ model.add(keras.layers.Dense(3))
+ model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
+
+ json_str = model.to_json()
+ keras.models.model_from_json(json_str)
+
+ if yaml is not None:
+ yaml_str = model.to_yaml()
+ keras.models.model_from_yaml(yaml_str)
+
class DeferredModeTest(test.TestCase):
@@ -1169,7 +1187,7 @@ class GraphUtilsTest(test.TestCase):
def testGetReachableFromInputs(self):
- with self.test_session():
+ with self.cached_session():
pl_1 = array_ops.placeholder(shape=None, dtype='float32')
pl_2 = array_ops.placeholder(shape=None, dtype='float32')
pl_3 = array_ops.placeholder(shape=None, dtype='float32')
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index d224dfffdd..fed07c4120 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -20,9 +20,11 @@ from __future__ import print_function
import weakref
import numpy as np
+import six
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.data.ops.dataset_ops import Dataset
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@@ -39,6 +41,7 @@ from tensorflow.python.keras.engine import training_eager
from tensorflow.python.keras.engine import training_generator
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.engine.network import Network
+from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils.generic_utils import slice_arrays
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import weights_broadcast_ops
@@ -206,8 +209,27 @@ class Model(Network):
for metric in metrics:
metric_fn = training_utils.get_metric_function(
metric, output_shape=output_shape, loss_fn=loss_fn)
- metric_name = self._get_metric_name(
- metric, output_index, weighted=weights is not None)
+
+ if (context.executing_eagerly() and y_true is not None and
+ y_pred is not None):
+ # In eager mode, when executing metric_fn during training, we do not
+ # need to generate unique metric name and add it to the model
+ # as we have done that during compile already.
+ prefix = 'weighted_' if weights is not None else ''
+ suffix = metric_fn.name if hasattr(metric_fn,
+ 'name') else metric_fn.__name__
+ metric_name = prefix + suffix
+ else:
+ # Get metric name that is to be added to the model.
+ metric_name = self._get_metric_name(
+ metric, output_index, weighted=weights is not None)
+ # Keep track of metric name.
+ self.metrics_names.append(metric_name)
+
+ # Keep track of stateful metric attributes (name and metric function).
+ if isinstance(metric_fn, base_layer.Layer) and metric_fn.stateful:
+ self.stateful_metric_names.append(metric_name)
+ self.stateful_metric_functions.append(metric_fn)
with K.name_scope(metric_name):
# If both outputs and targets are available, call the metric function.
@@ -247,16 +269,10 @@ class Model(Network):
self.metrics_tensors.append(metric_result)
metric_results.append(metric_result)
- # Keep track of metric name.
- self.metrics_names.append(metric_name)
-
- # Keep track of stateful metric attributes (name and metric function).
- if isinstance(metric_fn, base_layer.Layer) and metric_fn.stateful:
- self.stateful_metric_names.append(metric_name)
- self.stateful_metric_functions.append(metric_fn)
- if not context.executing_eagerly():
- # Keep track of updates created by stateful metrics.
- self.metrics_updates += metric_fn.updates
+ if (isinstance(metric_fn, base_layer.Layer) and metric_fn.stateful and
+ not context.executing_eagerly()):
+ # Keep track of updates created by stateful metrics.
+ self.metrics_updates += metric_fn.updates
return metric_results
def _handle_metrics(self,
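
In eager mode the metric name is now rebuilt from a 'weighted_' prefix plus the metric function's own name instead of being re-registered on the model; a small sketch of that naming rule (the metric function here is a stand-in, not a Keras metric object):

def build_metric_name(metric_fn, weighted):
  # Sketch of the eager-mode branch above: optional prefix + function name.
  prefix = 'weighted_' if weighted else ''
  suffix = metric_fn.name if hasattr(metric_fn, 'name') else metric_fn.__name__
  return prefix + suffix

def binary_accuracy(y_true, y_pred):  # stand-in metric function
  return None

print(build_metric_name(binary_accuracy, weighted=True))   # weighted_binary_accuracy
print(build_metric_name(binary_accuracy, weighted=False))  # binary_accuracy
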
@@ -754,9 +770,8 @@ class Model(Network):
the model.
Args:
- x: Input data. A `tf.data` dataset.
- y: Since `x` is a dataset, `y` should not be specified
- (since targets will be obtained from the iterator).
+ x: Input data. A numpy array or `tf.data` dataset.
+ y: Target data. A numpy array or None if x is a `tf.data` dataset.
sample_weight: An optional sample-weight array passed by the user to
weight the importance of each sample in `x`.
class_weight: An optional class-weight array by the user to
@@ -786,12 +801,51 @@ class Model(Network):
raise NotImplementedError('`class_weight` is currently not supported '
'when using DistributionStrategy.')
+ # Validates `steps` argument right at the beginning since we use it to
+ # construct the dataset object.
+ # TODO(anjalisridhar): This may not be a valid error since we now accept
+ # numpy array inputs. We still want to assert that we have a populated steps
+ # parameter.
+ if check_steps:
+ if steps is None:
+ raise ValueError('When using DistributionStrategy, '
+ 'you should specify the `{steps_name}` argument.'
+ .format(steps_name=steps_name))
+
+ first_x_value = nest.flatten(x)[0]
+ if isinstance(first_x_value, np.ndarray):
+ x_shape = first_x_value.shape
+ x_dtype = first_x_value.dtype
+ if batch_size is None:
+ batch_size = x_shape[0] // steps
+ if y is not None:
+ first_y_value = nest.flatten(y)[0]
+ x = Dataset.from_generator(lambda x=x, y=y: six.moves.zip(x, y),
+ output_types=(x_dtype, first_y_value.dtype),
+ output_shapes=(x_shape[1:],
+ first_y_value.shape[1:]))
+ # TODO(anjalisridhar): What should the buffer size be?
+ x = x.shuffle(10000)
+ x = x.repeat()
+ x = x.batch(batch_size)
+ y = None
+ else:
+ # This case is for the predict call where the dataset only contains
+ # inputs and no targets, i.e. it does not return a tuple.
+ # TODO(anjalisridhar): Raise an error if we are not able to process
+ # all the predict samples. This can happen if the number of batches is
+ # not evenly divisible by the number of worker devices.
+ x = Dataset.from_generator(lambda x=x: x,
+ output_types=x_dtype,
+ output_shapes=x_shape[1:])
+ x = x.repeat()
+ x = x.batch(batch_size)
+
# TODO(anjalisridhar): Can we use the iterator and getnext op cache?
# We require users to pass Datasets since we distribute the dataset across
# multiple devices.
- if not isinstance(x, dataset_ops.Dataset):
- raise ValueError('When using DistributionStrategy, model inputs should be'
- ' Dataset instances; found instead %s.' % type(x))
+ assert isinstance(x, dataset_ops.Dataset)
+
# TODO(anjalisridhar): We want distribute_dataset() to accept a Dataset or a
# function which returns a Dataset. Currently distribute_dataset() only
# accepts a function that returns a Dataset. Once we add support for being
@@ -799,12 +853,6 @@ class Model(Network):
result = self._distribution_strategy.distribute_dataset(lambda: x)
iterator = result.make_initializable_iterator()
K.get_session().run(iterator.initializer)
- # Validates `steps` argument based on x's type.
- if check_steps:
- if steps is None:
- raise ValueError('When using a Dataset instance as input to a model, '
- 'you should specify the `{steps_name}` argument.'
- .format(steps_name=steps_name))
training_utils.validate_iterator_input(x, y, sample_weight,
validation_split)
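
The numpy-to-Dataset conversion above can be pictured as wrapping the arrays in a tf.data pipeline before the strategy distributes it; a rough standalone sketch using the public tf.data API (shapes, dtypes and the shuffle buffer are illustrative, not prescribed by the patch):

import numpy as np
import six
import tensorflow as tf

x = np.random.random((100, 3)).astype(np.float32)   # illustrative inputs
y = np.random.random((100, 1)).astype(np.float32)   # illustrative targets
steps = 10
batch_size = x.shape[0] // steps

# Same recipe as the patched branch: yield (input, target) pairs, then
# shuffle, repeat and batch before handing the dataset to the strategy.
dataset = tf.data.Dataset.from_generator(
    lambda: six.moves.zip(x, y),
    output_types=(x.dtype, y.dtype),
    output_shapes=(x.shape[1:], y.shape[1:]))
dataset = dataset.shuffle(10000).repeat().batch(batch_size)
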
@@ -1304,6 +1352,9 @@ class Model(Network):
initial_epoch=0,
steps_per_epoch=None,
validation_steps=None,
+ max_queue_size=10,
+ workers=1,
+ use_multiprocessing=False,
**kwargs):
"""Trains the model for a fixed number of epochs (iterations on a dataset).
@@ -1316,19 +1367,23 @@ class Model(Network):
- A dict mapping input names to the corresponding array/tensors,
if the model has named inputs.
- A `tf.data` dataset or a dataset iterator. Should return a tuple
- of either (inputs, targets) or (inputs, targets, sample_weights).
+ of either `(inputs, targets)` or
+ `(inputs, targets, sample_weights)`.
+ - A generator or `keras.utils.Sequence` returning `(inputs, targets)`
+ or `(inputs, targets, sample weights)`.
y: Target data. Like the input data `x`,
it could be either Numpy array(s) or TensorFlow tensor(s).
It should be consistent with `x` (you cannot have Numpy inputs and
- tensor targets, or inversely). If `x` is a dataset or dataset
- iterator, `y` should not be specified
- (since targets will be obtained from the iterator).
+ tensor targets, or inversely). If `x` is a dataset, dataset
+ iterator, generator, or `keras.utils.Sequence` instance, `y` should
+ not be specified (since targets will be obtained from `x`).
batch_size: Integer or `None`.
Number of samples per gradient update.
If unspecified, `batch_size` will default to 32.
Do not specify the `batch_size` if your data is in the
- form of symbolic tensors, datasets, or dataset iterators
- (since they generate batches).
+ form of symbolic tensors, dataset, dataset iterators,
+ generators, or `keras.utils.Sequence` instances (since they generate
+ batches).
epochs: Integer. Number of epochs to train the model.
An epoch is an iteration over the entire `x` and `y`
data provided.
@@ -1350,7 +1405,8 @@ class Model(Network):
on this data at the end of each epoch.
The validation data is selected from the last samples
in the `x` and `y` data provided, before shuffling. This argument is
- not supported when `x` is a dataset or a dataset iterator.
+ not supported when `x` is a dataset, dataset iterator, generator or
+ `keras.utils.Sequence` instance.
validation_data: Data on which to evaluate
the loss and any model metrics at the end of each epoch.
The model will not be trained on this data.
@@ -1381,8 +1437,9 @@ class Model(Network):
to apply a different weight to every timestep of every sample.
In this case you should make sure to specify
`sample_weight_mode="temporal"` in `compile()`. This argument is not
- supported when `x` is a dataset or a dataset iterator, instead
- provide the sample_weights as the third element of `x`.
+ supported when `x` is a dataset, dataset iterator, generator, or
+ `keras.utils.Sequence` instance, instead provide the sample_weights
+ as the third element of `x`.
initial_epoch: Integer.
Epoch at which to start training
(useful for resuming a previous training run).
@@ -1396,6 +1453,20 @@ class Model(Network):
validation_steps: Only relevant if `steps_per_epoch`
is specified. Total number of steps (batches of samples)
to validate before stopping.
+ max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
+ input only. Maximum size for the generator queue.
+ If unspecified, `max_queue_size` will default to 10.
+ workers: Integer. Used for generator or `keras.utils.Sequence` input
+ only. Maximum number of processes to spin up
+ when using process-based threading. If unspecified, `workers`
+ will default to 1. If 0, will execute the generator on the main
+ thread.
+ use_multiprocessing: Boolean. Used for generator or
+ `keras.utils.Sequence` input only. If `True`, use process-based
+ threading. If unspecified, `use_multiprocessing` will default to
+ `False`. Note that because this implementation relies on
+ multiprocessing, you should not pass non-picklable arguments to
+ the generator as they can't be passed easily to children processes.
**kwargs: Used for backwards compatibility.
Returns:
@@ -1412,6 +1483,23 @@ class Model(Network):
# TODO(fchollet): this method may be creating reference cycles, which would
# lead to accumulating garbage in memory when called in a loop. Investigate.
+ if data_utils.is_generator_or_sequence(x):
+ training_utils.check_generator_arguments(y, sample_weight)
+ return self.fit_generator(
+ x,
+ steps_per_epoch=steps_per_epoch,
+ epochs=epochs,
+ verbose=verbose,
+ callbacks=callbacks,
+ validation_data=validation_data,
+ validation_steps=validation_steps,
+ class_weight=class_weight,
+ max_queue_size=max_queue_size,
+ workers=workers,
+ use_multiprocessing=use_multiprocessing,
+ shuffle=shuffle,
+ initial_epoch=initial_epoch)
+
# Backwards compatibility
if batch_size is None and steps_per_epoch is None:
batch_size = 32
@@ -1428,6 +1516,13 @@ class Model(Network):
if self._distribution_strategy:
distributed_training_utils.validate_callbacks(callbacks)
+ distributed_training_utils.validate_inputs(x, y)
+
+ first_x_value = nest.flatten(x)[0]
+ if not steps_per_epoch and isinstance(first_x_value, np.ndarray):
+ steps_per_epoch = distributed_training_utils.get_input_batch_params(
+ first_x_value, batch_size, self._distribution_strategy)
+
x, y, sample_weights = self._standardize_user_data(
x,
y,
@@ -1462,6 +1557,13 @@ class Model(Network):
'However we received `validation_data=%s`' % validation_data)
# Validate and standardize validation data.
+ if self._distribution_strategy:
+ distributed_training_utils.validate_inputs(val_x, val_y)
+ first_valx_value = nest.flatten(val_x)[0]
+ if not validation_steps and isinstance(first_valx_value, np.ndarray):
+ validation_steps = distributed_training_utils.get_input_batch_params(
+ first_valx_value, batch_size, self._distribution_strategy)
+
val_x, val_y, val_sample_weights = self._standardize_user_data(
val_x,
val_y,
@@ -1540,7 +1642,10 @@ class Model(Network):
batch_size=None,
verbose=1,
sample_weight=None,
- steps=None):
+ steps=None,
+ max_queue_size=10,
+ workers=1,
+ use_multiprocessing=False):
"""Returns the loss value & metrics values for the model in test mode.
Computation is done in batches.
@@ -1554,18 +1659,21 @@ class Model(Network):
- A dict mapping input names to the corresponding array/tensors,
if the model has named inputs.
- A `tf.data` dataset or a dataset iterator.
+ - A generator or `keras.utils.Sequence` instance.
y: Target data. Like the input data `x`,
it could be either Numpy array(s) or TensorFlow tensor(s).
It should be consistent with `x` (you cannot have Numpy inputs and
tensor targets, or inversely).
- If `x` is a dataset or a dataset iterator, `y` should not be specified
- (since targets will be obtained from the iterator/dataset).
+ If `x` is a dataset, dataset iterator, generator or
+ `keras.utils.Sequence` instance, `y` should not be specified (since
+ targets will be obtained from the iterator/dataset).
batch_size: Integer or `None`.
Number of samples per gradient update.
If unspecified, `batch_size` will default to 32.
Do not specify the `batch_size` if your data is in the
- form of symbolic tensors, datasets, or dataset iterators
- (since they generate batches).
+ form of symbolic tensors, dataset, dataset iterators,
+ generators, or `keras.utils.Sequence` instances (since they generate
+ batches).
verbose: 0 or 1. Verbosity mode.
0 = silent, 1 = progress bar.
sample_weight: Optional Numpy array of weights for
@@ -1579,11 +1687,25 @@ class Model(Network):
to apply a different weight to every timestep of every sample.
In this case you should make sure to specify
`sample_weight_mode="temporal"` in `compile()`. This argument is not
- supported when `x` is a dataset or a dataset iterator.
+ supported when `x` is a dataset or a dataset iterator, instead pass
+ sample weights as the third element of `x`.
steps: Integer or `None`.
Total number of steps (batches of samples)
before declaring the evaluation round finished.
Ignored with the default value of `None`.
+ max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
+ input only. Maximum size for the generator queue.
+ If unspecified, `max_queue_size` will default to 10.
+ workers: Integer. Used for generator or `keras.utils.Sequence` input
+ only. Maximum number of processes to spin up when using
+ process-based threading. If unspecified, `workers` will default
+ to 1. If 0, will execute the generator on the main thread.
+ use_multiprocessing: Boolean. Used for generator or
+ `keras.utils.Sequence` input only. If `True`, use process-based
+ threading. If unspecified, `use_multiprocessing` will default to
+ `False`. Note that because this implementation relies on
+ multiprocessing, you should not pass non-picklable arguments to
+ the generator as they can't be passed easily to children processes.
Returns:
Scalar test loss (if the model has a single output and no metrics)
@@ -1594,11 +1716,28 @@ class Model(Network):
Raises:
ValueError: in case of invalid arguments.
"""
+ if data_utils.is_generator_or_sequence(x):
+ training_utils.check_generator_arguments(y, sample_weight)
+ return self.evaluate_generator(
+ x,
+ steps=steps,
+ verbose=verbose,
+ max_queue_size=max_queue_size,
+ workers=workers,
+ use_multiprocessing=use_multiprocessing)
+
# Backwards compatibility.
if batch_size is None and steps is None:
batch_size = 32
# Validate and standardize user data.
+ if self._distribution_strategy:
+ distributed_training_utils.validate_inputs(x, y)
+ first_x_value = nest.flatten(x)[0]
+ if isinstance(first_x_value, np.ndarray) and not steps:
+ steps = distributed_training_utils.get_input_batch_params(
+ first_x_value, batch_size, self._distribution_strategy)
+
x, y, sample_weights = self._standardize_user_data(
x,
y,
@@ -1633,7 +1772,14 @@ class Model(Network):
verbose=verbose,
steps=steps)
- def predict(self, x, batch_size=None, verbose=0, steps=None):
+ def predict(self,
+ x,
+ batch_size=None,
+ verbose=0,
+ steps=None,
+ max_queue_size=10,
+ workers=1,
+ use_multiprocessing=False):
"""Generates output predictions for the input samples.
Computation is done in batches.
@@ -1645,16 +1791,32 @@ class Model(Network):
- A TensorFlow tensor, or a list of tensors
(in case the model has multiple inputs).
- A `tf.data` dataset or a dataset iterator.
+ - A generator or `keras.utils.Sequence` instance.
batch_size: Integer or `None`.
Number of samples per gradient update.
If unspecified, `batch_size` will default to 32.
Do not specify the `batch_size` if your data is in the
- form of symbolic tensors, dataset, or dataset iterators
- (since they generate batches).
+ form of symbolic tensors, dataset, dataset iterators,
+ generators, or `keras.utils.Sequence` instances (since they generate
+ batches).
verbose: Verbosity mode, 0 or 1.
steps: Total number of steps (batches of samples)
before declaring the prediction round finished.
Ignored with the default value of `None`.
+ max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
+ input only. Maximum size for the generator queue.
+ If unspecified, `max_queue_size` will default to 10.
+ workers: Integer. Used for generator or `keras.utils.Sequence` input
+ only. Maximum number of processes to spin up when using
+ process-based threading. If unspecified, `workers` will default
+ to 1. If 0, will execute the generator on the main thread.
+ use_multiprocessing: Boolean. Used for generator or
+ `keras.utils.Sequence` input only. If `True`, use process-based
+ threading. If unspecified, `use_multiprocessing` will default to
+ `False`. Note that because this implementation relies on
+ multiprocessing, you should not pass non-picklable arguments to
+ the generator as they can't be passed easily to children processes.
+
Returns:
Numpy array(s) of predictions.
@@ -1665,18 +1827,35 @@ class Model(Network):
or in case a stateful model receives a number of samples
that is not a multiple of the batch size.
"""
+ if data_utils.is_generator_or_sequence(x):
+ return self.predict_generator(
+ x,
+ steps=steps,
+ verbose=verbose,
+ max_queue_size=max_queue_size,
+ workers=workers,
+ use_multiprocessing=use_multiprocessing)
+
# Backwards compatibility.
if batch_size is None and steps is None:
batch_size = 32
- # Turn off prefetching since this is currently not deterministic. Once
- # b/112498930 is fixed we can turn it back on.
- # `_prefetch_on_device` is currently a property of only `MirroredStrategy`.
- if (self._distribution_strategy and
- hasattr(self._distribution_strategy, '_prefetch_on_device')):
- self._distribution_strategy._prefetch_on_device = False # pylint: disable=protected-access
+ if self._distribution_strategy:
+ # Turn off prefetching since this is currently not deterministic. Once
+ # b/112498930 is fixed we can turn it back on.
+ # `_prefetch_on_device` is currently a property of only
+ # `MirroredStrategy`.
+ if hasattr(self._distribution_strategy, '_prefetch_on_device'):
+ self._distribution_strategy._prefetch_on_device = False # pylint: disable=protected-access
+ distributed_training_utils.validate_inputs(x, None)
+ first_x_value = nest.flatten(x)[0]
+ if isinstance(first_x_value, np.ndarray) and not steps:
+ steps = distributed_training_utils.get_input_batch_params(
+ first_x_value, batch_size, self._distribution_strategy)
# Validate and standardize user data.
+ # TODO(anjalisridhar): We don't pass batch_size here for some reason. This
+ # means that we end up calculating it twice which we should avoid.
x, _, _ = self._standardize_user_data(
x, check_steps=True, steps_name='steps', steps=steps)
@@ -2008,7 +2187,7 @@ class Model(Network):
Arguments:
generator: Generator yielding tuples (inputs, targets)
or (inputs, targets, sample_weights)
- or an instance of Sequence (keras.utils.Sequence)
+ or an instance of `keras.utils.Sequence`
object in order to avoid duplicate data
when using multiprocessing.
steps: Total number of steps (batches of samples)
@@ -2072,9 +2251,8 @@ class Model(Network):
Arguments:
generator: Generator yielding batches of input samples
- or an instance of Sequence (keras.utils.Sequence)
- object in order to avoid duplicate data
- when using multiprocessing.
+ or an instance of `keras.utils.Sequence` object in order to
+ avoid duplicate data when using multiprocessing.
steps: Total number of steps (batches of samples)
to yield from `generator` before stopping.
Optional for `Sequence`: if unspecified, will use
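
With the new dispatch at the top of fit/evaluate/predict, a generator or keras.utils.Sequence can be passed to fit() directly and is forwarded to fit_generator(); a hedged usage sketch assuming the patched signature (the model and generator are placeholders):

import numpy as np
from tensorflow import keras

def data_gen(batch_size=8):
  # Placeholder generator yielding (inputs, targets) batches forever.
  while True:
    yield np.random.random((batch_size, 3)), np.random.random((batch_size, 1))

model = keras.Sequential([keras.layers.Dense(1, input_shape=(3,))])
model.compile(optimizer='sgd', loss='mse')

# fit() now detects the generator and delegates to fit_generator(), passing
# along the new queue/worker arguments.
model.fit(data_gen(), steps_per_epoch=5, epochs=1,
          max_queue_size=10, workers=1, use_multiprocessing=False)
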
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index 939732cd67..53291c3956 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -27,10 +27,14 @@ from tensorflow.python.keras import optimizers
from tensorflow.python.keras.engine import distributed_training_utils
from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import distribute as distribute_lib
+# TODO(priyag, sourabhbajaj): Refactor this file to address code duplication.
+
+
def fit_loop(
model,
iterator,
@@ -41,13 +45,13 @@ def fit_loop(
initial_epoch=0,
steps_per_epoch=None,
validation_steps=None):
- """fit function when using DistributionStrategy for training.
+ """Fit loop for training with DistributionStrategy.
Arguments:
model: Keras Model instance.
iterator: Iterator for input data.
epochs: Number of times to iterate over the data
- verbose: Verbosity mode, 0, 1 or 2
+ verbose: Integer, Verbosity mode, 0, 1 or 2
callbacks: List of callbacks to be called during training
val_iterator: Iterator for validation data.
initial_epoch: Epoch at which to start training
@@ -73,8 +77,8 @@ def fit_loop(
model, iterator, epochs, verbose, callbacks, initial_epoch,
steps_per_epoch)
- clone_model_on_towers(
- model, current_strategy, make_callback_model=True)
+ if not model._grouped_model:
+ clone_model_on_towers(model, current_strategy, make_callback_model=True)
def _per_device_train_function(model):
model._make_train_function()
@@ -206,13 +210,13 @@ def _experimental_fit_loop(
callbacks=None,
initial_epoch=0,
steps_per_epoch=None):
- """fit function when using TPU DistributionStrategy for training.
+ """Fit loop for training with TPU DistributionStrategy.
Arguments:
model: Keras Model instance.
iterator: Iterator that returns inputs and targets
epochs: Number of times to iterate over the data
- verbose: Verbosity mode, 0, 1 or 2
+ verbose: Integer, Verbosity mode, 0, 1 or 2
callbacks: List of callbacks to be called during training
initial_epoch: Epoch at which to start training
(useful for resuming a previous training run)
@@ -244,7 +248,9 @@ def _experimental_fit_loop(
def step_fn(ctx, inputs, targets):
"""Clones the model and calls make_train_function."""
- # TODO(priyag, sourabhbajaj): Should cache this keyed on input shapes.
+ # TODO(priyag, sourabhbajaj): The model gets cloned every time
+ # fit/test/predict is called. We should look into caching this keyed on
+ # input shapes.
clone_model_on_towers(
model,
current_strategy,
@@ -258,19 +264,22 @@ def _experimental_fit_loop(
(all_inputs, all_outputs, all_updates,
all_session_args) = distributed_training_utils.unwrap_values(
current_strategy, grouped_inputs, grouped_outputs,
- grouped_updates, grouped_session_args, with_loss_tensor=True)
+ grouped_updates, grouped_session_args)
combined_fn = K.Function(
all_inputs, all_outputs,
updates=all_updates,
name='distributed_train_function',
**all_session_args)
- # TODO(priyag, sourabhbajaj): Perhaps the aggregation type needs to be
- # something else for different outputs.
out_labels = model.metrics_names or []
for label, output in zip(out_labels, combined_fn.outputs):
- ctx.set_last_step_output(label, output,
- aggregation=distribute_lib.get_loss_reduction())
+ if label == 'loss':
+ aggregation = distribute_lib.get_loss_reduction()
+ else:
+ # We aggregate all other metrics using mean for now. This is a temporary
+ # workaround until new metrics are in place.
+ aggregation = variable_scope.VariableAggregation.MEAN
+ ctx.set_last_step_output(label, output, aggregation)
# TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
# feed_dict, session kwargs, run options, run_metadata for now. These should
@@ -324,10 +333,9 @@ def _experimental_fit_loop(
callbacks.on_epoch_begin(epoch)
epoch_logs = {}
for step_index in range(0, steps_per_epoch, current_strategy.steps_per_run):
- # TODO(sourabhbajaj): Add the size parameter in batch_logs once callbacks
- # are fixed as we need to replace size with a combination of steps_per_run
+ # TODO(sourabhbajaj): Replace size with a combination of steps_per_run
# and batch_size
- batch_logs = {'batch': step_index}
+ batch_logs = {'batch': step_index, 'size': 1}
callbacks.on_batch_begin(step_index, batch_logs)
try:
_, outputs = K.get_session().run([train_op, output_tensors])
@@ -360,12 +368,12 @@ def _experimental_fit_loop(
def test_loop(model, iterator, verbose=0, steps=None):
- """evaluate method to validate a model that uses DistributionStrategy.
+ """Test loop for evaluating with DistributionStrategy.
Arguments:
model: Keras Model instance.
iterator: Iterator for input data.
- verbose: verbosity mode.
+ verbose: Integer, Verbosity mode 0 or 1.
steps: Total number of steps (batches of samples)
before declaring predictions finished.
Ignored with the default value of `None`.
@@ -374,11 +382,16 @@ def test_loop(model, iterator, verbose=0, steps=None):
Scalar loss (if the model has a single output and no metrics)
or list of scalars (if the model has multiple outputs
and/or metrics). The attribute `model.metrics_names` will give you
- the display labels for the scalar outputs.
+ the display labels for the outputs.
"""
current_strategy = model._distribution_strategy
- clone_model_on_towers(model, current_strategy)
+ # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged.
+ if current_strategy.__class__.__name__ == 'TPUStrategy':
+ return _experimental_test_loop(model, iterator, verbose, steps)
+
+ if not model._grouped_model:
+ clone_model_on_towers(model, current_strategy)
def _per_device_test_function(model):
model._make_test_function()
@@ -429,25 +442,136 @@ def test_loop(model, iterator, verbose=0, steps=None):
distributed_training_utils.set_weights(
current_strategy, distributed_model, orig_model_weights)
- if steps is not None:
- for step in range(steps):
- batch_outs = distributed_test_function(ins)
- batch_outs = _aggregate_metrics_across_towers(
- current_strategy.num_towers, model.metrics_names, batch_outs)
- if isinstance(batch_outs, list):
- if step == 0:
- for _ in enumerate(batch_outs):
- outs.append(0.)
- for i, batch_out in enumerate(batch_outs):
- outs[i] += batch_out
+ assert steps is not None
+ for step in range(steps):
+ batch_outs = distributed_test_function(ins)
+ batch_outs = _aggregate_metrics_across_towers(
+ current_strategy.num_towers, model.metrics_names, batch_outs)
+ if isinstance(batch_outs, list):
+ if step == 0:
+ outs = [0.] * len(batch_outs)
+ for i, batch_out in enumerate(batch_outs):
+ outs[i] += batch_out
+ else:
+ if step == 0:
+ outs.append(0.)
+ outs[0] += batch_outs
+ if verbose >= 1:
+ progbar.update(step + 1)
+ for i in range(len(outs)):
+ outs[i] /= steps
+
+ if len(outs) == 1:
+ return outs[0]
+ return outs
+
+
+def _experimental_test_loop(model, iterator, verbose=0, steps=None):
+ """Test loop for evaluating with TPU DistributionStrategy.
+
+ Arguments:
+ model: Keras Model instance.
+ iterator: Iterator for input data.
+ verbose: Integer, Verbosity mode 0 or 1.
+ steps: Total number of steps (batches of samples)
+ before declaring predictions finished.
+ Ignored with the default value of `None`.
+
+ Returns:
+ Scalar loss (if the model has a single output and no metrics)
+ or list of scalars (if the model has multiple outputs
+ and/or metrics). The attribute `model.metrics_names` will give you
+ the display labels for the outputs.
+ """
+ current_strategy = model._distribution_strategy
+ K.get_session().run(current_strategy.initialize())
+
+ def _per_device_test_function(model):
+ model._make_test_function()
+ return (model.test_function.inputs,
+ model.test_function.outputs,
+ model.test_function.updates_op,
+ model.test_function.session_kwargs)
+
+ # TODO(priyag, sourabhbajaj): This should likely not be hardcoded here.
+ K.set_learning_phase(0)
+
+ def step_fn(ctx, inputs, targets):
+ """Clones the model and calls make_test_function."""
+ # TODO(priyag, sourabhbajaj): The model gets cloned every time
+ # fit/test/predict is called. We should look into caching this keyed on
+ # input shapes.
+ clone_model_on_towers(
+ model,
+ current_strategy,
+ make_callback_model=False,
+ inputs=inputs,
+ targets=targets)
+
+ (grouped_inputs, grouped_outputs, grouped_updates,
+ grouped_session_args) = current_strategy.call_for_each_tower(
+ _per_device_test_function, model._grouped_model)
+
+ (all_inputs, all_outputs, all_updates,
+ all_session_args) = distributed_training_utils.unwrap_values(
+ current_strategy, grouped_inputs, grouped_outputs, grouped_updates,
+ grouped_session_args)
+
+ combined_fn = K.Function(
+ all_inputs, all_outputs,
+ updates=all_updates,
+ name='distributed_test_function',
+ **all_session_args)
+
+ for label, output in zip(model.metrics_names, combined_fn.outputs):
+ if label == 'loss':
+ aggregation = distribute_lib.get_loss_reduction()
else:
- if step == 0:
- outs.append(0.)
- outs[0] += batch_outs
- if verbose == 1:
- progbar.update(step + 1)
- for i in range(len(outs)):
- outs[i] /= steps
+ # We aggregate all other metrics using mean for now. This is a temporary
+ # workaround until new metrics are in place.
+ aggregation = variable_scope.VariableAggregation.MEAN
+ ctx.set_last_step_output(label, output, aggregation)
+
+ return combined_fn.updates_op
+
+ # Add initial dummy values for loss and other metric tensors.
+ initial_loop_values = {}
+ initial_loop_values['loss'] = constant_op.constant(1e7)
+ for name, tensor in zip(model.metrics_names[1:], model.metrics_tensors):
+ initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)
+
+ with current_strategy.scope():
+ # TODO(priyag): Use steps_per_run when we use new metrics as they will
+ # allow handling metric computation at each step using variables.
+ ctx = current_strategy.run_steps_on_dataset(
+ step_fn, iterator, iterations=1,
+ initial_loop_values=initial_loop_values)
+
+ test_op = ctx.run_op
+ output_tensors = ctx.last_step_outputs
+
+ if verbose == 1:
+ progbar = Progbar(target=steps)
+
+ # Copy the weights from the original model to each of the replicated models.
+ orig_model_weights = model.get_weights()
+ with current_strategy.scope():
+ distributed_model = current_strategy.unwrap(model._grouped_model)[0]
+ distributed_training_utils.set_weights(
+ current_strategy, distributed_model, orig_model_weights)
+
+ assert steps is not None
+ outs = [0.] * len(model.metrics_names)
+ for step in range(steps):
+ _, batch_outs = K.get_session().run([test_op, output_tensors])
+ for i, label in enumerate(model.metrics_names):
+ outs[i] += batch_outs[label]
+ if verbose >= 1:
+ progbar.update(step + 1)
+ for i in range(len(outs)):
+ outs[i] /= steps
+
+ K.get_session().run(current_strategy.finalize())
if len(outs) == 1:
return outs[0]
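
The evaluation loop above accumulates per-step outputs and divides by the step count at the end; a small numeric sketch of that averaging (the per-step values are made up):

steps = 4
metrics_names = ['loss', 'acc']
# Hypothetical per-step outputs as returned by the distributed test function.
per_step_outs = [[0.9, 0.50], [0.8, 0.55], [0.7, 0.60], [0.6, 0.65]]

outs = [0.] * len(metrics_names)
for batch_outs in per_step_outs:
  for i, batch_out in enumerate(batch_outs):
    outs[i] += batch_out
outs = [total / steps for total in outs]
print(dict(zip(metrics_names, outs)))  # {'loss': 0.75, 'acc': 0.575}
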
@@ -455,12 +579,12 @@ def test_loop(model, iterator, verbose=0, steps=None):
def predict_loop(model, iterator, verbose=0, steps=None):
- """Abstract method to loop over some data in batches.
+ """Predict loop for predicting with DistributionStrategy.
Arguments:
model: Keras Model instance.
iterator: Iterator for input data.
- verbose: verbosity mode.
+ verbose: Integer, Verbosity mode 0 or 1.
steps: Total number of steps (batches of samples)
before declaring `_predict_loop` finished.
Ignored with the default value of `None`.
@@ -472,7 +596,12 @@ def predict_loop(model, iterator, verbose=0, steps=None):
"""
current_strategy = model._distribution_strategy
- clone_model_on_towers(model, current_strategy)
+ # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged.
+ if current_strategy.__class__.__name__ == 'TPUStrategy':
+ return _experimental_predict_loop(model, iterator, verbose, steps)
+
+ if not model._grouped_model:
+ clone_model_on_towers(model, current_strategy)
def _per_device_predict_function(model):
model._make_predict_function()
@@ -528,9 +657,11 @@ def predict_loop(model, iterator, verbose=0, steps=None):
if step == 0:
for _ in batch_outs:
unconcatenated_outs.append([])
+ # TODO(anjalisridhar): Should combine the outputs from multiple towers
+ # correctly here.
for i, batch_out in enumerate(batch_outs):
unconcatenated_outs[i].append(batch_out)
- if verbose == 1:
+ if verbose >= 1:
progbar.update(step + 1)
if len(unconcatenated_outs) == 1:
return np.concatenate(unconcatenated_outs[0], axis=0)
@@ -540,6 +671,122 @@ def predict_loop(model, iterator, verbose=0, steps=None):
]
+def _experimental_predict_loop(model, iterator, verbose=0, steps=None):
+ """Predict loop for predicting with TPU DistributionStrategy.
+
+ Arguments:
+ model: Keras Model instance.
+ iterator: Iterator for input data.
+ verbose: Integer, Verbosity mode 0 or 1.
+ steps: Total number of steps (batches of samples)
+ before declaring `_predict_loop` finished.
+ Ignored with the default value of `None`.
+
+ Returns:
+ Array of predictions (if the model has a single output)
+ or list of arrays of predictions
+ (if the model has multiple outputs).
+ """
+ current_strategy = model._distribution_strategy
+ K.get_session().run(current_strategy.initialize())
+
+ # TODO(priyag, sourabhbajaj): This should likely not be hardcoded here.
+ K.set_learning_phase(0)
+
+ def _per_device_predict_function(model):
+ model._make_predict_function()
+ return (model.predict_function.inputs,
+ model.predict_function.outputs,
+ model.predict_function.updates_op,
+ model.predict_function.session_kwargs)
+
+ def step_fn(ctx, inputs, targets):
+ """Clones the model and calls make_predict_function."""
+
+ # TODO(anjalisridhar): Support predict input correctly as it will not
+ # contain targets, only inputs.
+ del targets
+
+ # TODO(priyag, sourabhbajaj): The model gets cloned every time
+ # fit/test/predict is called. We should look into caching this keyed on
+ # input shapes.
+ clone_model_on_towers(
+ model,
+ current_strategy,
+ make_callback_model=False,
+ inputs=inputs)
+
+ (grouped_inputs, grouped_outputs, grouped_updates,
+ grouped_session_args) = current_strategy.call_for_each_tower(
+ _per_device_predict_function, model._grouped_model)
+
+ (all_inputs, all_outputs, all_updates,
+ all_session_args) = distributed_training_utils.unwrap_values(
+ current_strategy, grouped_inputs, grouped_outputs, grouped_updates,
+ grouped_session_args)
+
+ combined_fn = K.Function(
+ all_inputs, all_outputs,
+ updates=all_updates,
+ name='distributed_predict_function',
+ **all_session_args)
+
+ for label, output in zip(model.output_names, combined_fn.outputs):
+ ctx.set_last_step_output(label, output)
+
+ return combined_fn.updates_op
+
+ # Add initial dummy values for outputs.
+ initial_loop_values = {}
+ batch_dimension = distributed_training_utils.get_batch_dimension(iterator)
+ for name, tensor in zip(model.output_names, model.outputs):
+ # TODO(priyag): This is a workaround as we do not know the batch dimension
+ # of the model's output at this point.
+ tensor.shape.dims = [batch_dimension] + tensor.shape.dims[1:]
+ initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)
+
+ with current_strategy.scope():
+ # TODO(priyag, sourabhbajaj): Support steps_per_run if/when we add outfeed.
+ ctx = current_strategy.run_steps_on_dataset(
+ step_fn, iterator, iterations=1,
+ initial_loop_values=initial_loop_values)
+
+ predict_op = ctx.run_op
+ output_tensors = ctx.last_step_outputs
+
+ if verbose == 1:
+ progbar = Progbar(target=steps)
+
+ # Copy the weights from the original model to each of the replicated models.
+ orig_model_weights = model.get_weights()
+ with current_strategy.scope():
+ distributed_model = current_strategy.unwrap(model._grouped_model)[0]
+ distributed_training_utils.set_weights(
+ current_strategy, distributed_model, orig_model_weights)
+
+ assert steps is not None
+ # Since we do not know how many samples we will see, we cannot pre-allocate
+ # the returned Numpy arrays. Instead, we store one array per batch seen
+ # and concatenate them upon returning.
+ unconcatenated_outs = [[] for _ in model.outputs]
+ for step in range(steps):
+ _, batch_outs = K.get_session().run([predict_op, output_tensors])
+ # TODO(priyag): maybe need to unwrap the outputs first for MirroredStrategy.
+ for i, label in enumerate(model.output_names):
+ unconcatenated_outs[i].extend(batch_outs[label])
+ if verbose >= 1:
+ progbar.update(step + 1)
+
+ K.get_session().run(current_strategy.finalize())
+
+ if len(unconcatenated_outs) == 1:
+ return np.concatenate(unconcatenated_outs[0], axis=0)
+ return [
+ np.concatenate(unconcatenated_outs[i], axis=0)
+ for i in range(len(unconcatenated_outs))
+ ]
+
+
def _clone_and_build_model(model, inputs=None, targets=None):
"""Clone and build the given keras_model."""
# We need to set the import here since we run into a circular dependency
@@ -572,13 +819,12 @@ def _clone_and_build_model(model, inputs=None, targets=None):
def clone_model_on_towers(
model, strategy, make_callback_model=False, inputs=None, targets=None):
- """Create a cloned model on each tower, unless already created."""
- if not model._grouped_model:
- with strategy.scope():
- model._grouped_model = strategy.call_for_each_tower(
- _clone_and_build_model, model, inputs, targets)
- if make_callback_model:
- model._make_callback_model()
+ """Create a cloned model on each tower."""
+ with strategy.scope():
+ model._grouped_model = strategy.call_for_each_tower(
+ _clone_and_build_model, model, inputs, targets)
+ if make_callback_model:
+ model._make_callback_model()
def _aggregate_metrics_across_towers(num_devices, out_labels, outs):
@@ -615,14 +861,12 @@ def _aggregate_metrics_across_towers(num_devices, out_labels, outs):
def _get_input_from_iterator(iterator, model):
"""Get elements from the iterator and verify the input shape and type."""
next_element = iterator.get_next()
- # TODO(anjalisridhar): Support predict input correctly as it will not contain
- # targets, only inputs.
- if not isinstance(next_element, (list, tuple)) or len(next_element) != 2:
- raise ValueError('Please provide model inputs as a list or tuple of 2 '
- 'elements: input and target pair. '
- 'Received %s' % next_element)
-
- x, y = next_element
+
+ if isinstance(next_element, tuple):
+ x, y = next_element
+ else:
+ x = next_element
+ y = None
# Validate that all the elements in x and y are of the same type and shape.
# We can then pass the first element of x and y to `_standardize_weights`
# below and be confident of the output.
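The relaxed check above lets an iterator yield inputs without targets, which is what the predict path needs. A minimal sketch of the tuple/no-tuple split, using tf.data datasets with made-up shapes (the per-element validation that follows in the real function is omitted):

    import numpy as np
    from tensorflow.python.data.ops import dataset_ops

    # Dataset yielding (inputs, targets) pairs, as fit/evaluate expect.
    ds_pair = dataset_ops.Dataset.from_tensor_slices(
        (np.ones((4, 3), np.float32), np.ones((4, 1), np.float32))).batch(2)
    # Dataset yielding inputs only, as predict now allows.
    ds_inputs = dataset_ops.Dataset.from_tensor_slices(
        np.ones((4, 3), np.float32)).batch(2)

    for ds in (ds_pair, ds_inputs):
      next_element = ds.make_one_shot_iterator().get_next()
      if isinstance(next_element, tuple):
        x, y = next_element
      else:
        x, y = next_element, None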
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index 1d0d113e40..30be4131a4 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -366,7 +366,7 @@ class TrainingTest(test.TestCase):
if scipy_sparse is None:
return
- with self.test_session():
+ with self.cached_session():
test_inputs = [
scipy_sparse.random(6, 3, density=0.25).tocsr() for _ in range(2)
]
@@ -389,7 +389,7 @@ class TrainingTest(test.TestCase):
model.evaluate(test_inputs, test_outputs, batch_size=2)
def test_compile_with_sparse_placeholders(self):
- with self.test_session():
+ with self.cached_session():
input_layer = keras.layers.Input(shape=(10,), sparse=True)
weights = variables_lib.Variable(
np.ones((10, 1)).astype(np.float32), name='weights')
@@ -405,7 +405,7 @@ class TrainingTest(test.TestCase):
val_a = np.random.random((10, 4))
val_out = np.random.random((10, 4))
- with self.test_session():
+ with self.cached_session():
a = keras.layers.Input(shape=(4,))
layer = keras.layers.BatchNormalization(input_shape=(4,))
b = layer(a)
@@ -441,7 +441,7 @@ class TrainingTest(test.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_compile_warning_for_loss_missing_output(self):
- with self.test_session():
+ with self.cached_session():
inp = keras.layers.Input(shape=(16,), name='input_a')
out_1 = keras.layers.Dense(8, name='dense_1')(inp)
out_2 = keras.layers.Dense(3, activation='softmax', name='dense_2')(out_1)
@@ -654,7 +654,7 @@ class LossWeightingTest(test.TestCase):
timesteps = 3
learning_rate = 0.001
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(
keras.layers.TimeDistributed(
@@ -741,7 +741,7 @@ class LossWeightingTest(test.TestCase):
timesteps = 3
learning_rate = 0.001
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(
keras.layers.TimeDistributed(
@@ -810,7 +810,7 @@ class LossWeightingTest(test.TestCase):
timesteps = 3
learning_rate = 0.001
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(
keras.layers.TimeDistributed(
@@ -854,7 +854,7 @@ class LossMaskingTest(test.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_masking_graph_sequential(self):
- with self.test_session():
+ with self.cached_session():
x = np.array([[[1], [1]], [[0], [0]]])
model = keras.models.Sequential()
model.add(keras.layers.Masking(mask_value=0, input_shape=(2, 1)))
@@ -868,7 +868,7 @@ class LossMaskingTest(test.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_masking_deferred_sequential(self):
- with self.test_session():
+ with self.cached_session():
x = np.array([[[1], [1]], [[0], [0]]])
model = keras.models.Sequential()
model.add(keras.layers.Masking(mask_value=0))
@@ -882,7 +882,7 @@ class LossMaskingTest(test.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_masking_functional(self):
- with self.test_session():
+ with self.cached_session():
x = np.array([[[1], [1]], [[0], [0]]])
inputs = keras.layers.Input((2, 1))
outputs = keras.layers.Masking(mask_value=0)(inputs)
@@ -912,7 +912,7 @@ class LossMaskingTest(test.TestCase):
def compute_output_shape(self, input_shape):
return input_shape
- with self.test_session():
+ with self.cached_session():
x = np.random.random((5, 3))
inputs = keras.layers.Input((3,))
masked = keras.layers.Masking(mask_value=0)(inputs)
@@ -924,7 +924,7 @@ class LossMaskingTest(test.TestCase):
model.train_on_batch(x, y)
def test_loss_masking(self):
- with self.test_session():
+ with self.cached_session():
weighted_loss = weighted_masked_objective(keras.losses.get('mae'))
shape = (3, 4, 2)
x = np.arange(24).reshape(shape)
@@ -945,12 +945,12 @@ class LossMaskingTest(test.TestCase):
class LearningPhaseTest(test.TestCase):
def test_empty_model_no_learning_phase(self):
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
self.assertFalse(model.uses_learning_phase)
def test_dropout_has_learning_phase(self):
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(2, input_dim=3))
model.add(keras.layers.Dropout(0.5))
@@ -961,7 +961,7 @@ class LearningPhaseTest(test.TestCase):
class TestDynamicTrainability(test.TestCase):
def test_trainable_warning(self):
- with self.test_session():
+ with self.cached_session():
x = np.random.random((5, 3))
y = np.random.random((5, 2))
@@ -974,7 +974,7 @@ class TestDynamicTrainability(test.TestCase):
self.assertRaises(Warning)
def test_trainable_argument(self):
- with self.test_session():
+ with self.cached_session():
x = np.random.random((5, 3))
y = np.random.random((5, 2))
@@ -997,7 +997,7 @@ class TestDynamicTrainability(test.TestCase):
self.assertAllClose(out, out_2)
def test_layer_trainability_switch(self):
- with self.test_session():
+ with self.cached_session():
# with constructor argument, in Sequential
model = keras.models.Sequential()
model.add(keras.layers.Dense(2, trainable=False, input_dim=1))
@@ -1027,7 +1027,7 @@ class TestDynamicTrainability(test.TestCase):
self.assertListEqual(model.trainable_weights, [])
def test_model_trainability_switch(self):
- with self.test_session():
+ with self.cached_session():
# a non-trainable model has no trainable weights
x = keras.layers.Input(shape=(1,))
y = keras.layers.Dense(2)(x)
@@ -1042,7 +1042,7 @@ class TestDynamicTrainability(test.TestCase):
self.assertListEqual(model.trainable_weights, [])
def test_nested_model_trainability(self):
- with self.test_session():
+ with self.cached_session():
# a Sequential inside a Model
inner_model = keras.models.Sequential()
inner_model.add(keras.layers.Dense(2, input_dim=1))
@@ -1121,7 +1121,7 @@ class TestGeneratorMethods(test.TestCase):
y = arr_labels[start: end]
yield x, y
- with self.test_session():
+ with self.cached_session():
x = keras.Input((2,))
y = keras.layers.Dense(1)(x)
fn_model = keras.models.Model(x, y)
@@ -1207,7 +1207,7 @@ class TestGeneratorMethods(test.TestCase):
w = arr_sample_weights[start: end]
yield x, y, w
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(1, input_shape=(2,)))
model.compile(
@@ -1244,7 +1244,7 @@ class TestGeneratorMethods(test.TestCase):
while 1:
yield 0
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(1, input_shape=(2,)))
model.compile(loss='mse', optimizer='sgd')
@@ -1302,7 +1302,7 @@ class TestGeneratorMethods(test.TestCase):
w = arr_sample_weights[start: end]
yield x, y, w
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(1, input_shape=(2,)))
model.compile(loss='mse', optimizer='sgd')
@@ -1322,6 +1322,57 @@ class TestGeneratorMethods(test.TestCase):
workers=0,
use_multiprocessing=False)
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_generator_input_to_fit_eval_predict(self):
+ val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
+
+ def custom_generator():
+ while True:
+ yield np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
+
+ inputs = keras.layers.Input(shape=(10,))
+ x = keras.layers.Dense(10, activation='relu')(inputs)
+ outputs = keras.layers.Dense(1, activation='sigmoid')(x)
+ model = keras.Model(inputs, outputs)
+
+ model.compile(RMSPropOptimizer(0.001), 'binary_crossentropy')
+ model.fit(
+ custom_generator(),
+ steps_per_epoch=2,
+ validation_data=val_data,
+ epochs=2)
+ model.evaluate(custom_generator(), steps=2)
+ model.predict(custom_generator(), steps=2)
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_sequence_input_to_fit_eval_predict(self):
+ val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
+
+ class CustomSequence(keras.utils.Sequence):
+
+ def __getitem__(self, idx):
+ return np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
+
+ def __len__(self):
+ return 2
+
+ inputs = keras.layers.Input(shape=(10,))
+ x = keras.layers.Dense(10, activation='relu')(inputs)
+ outputs = keras.layers.Dense(1, activation='sigmoid')(x)
+ model = keras.Model(inputs, outputs)
+
+ model.compile(RMSPropOptimizer(0.001), 'binary_crossentropy')
+ model.fit(CustomSequence(), validation_data=val_data, epochs=2)
+ model.evaluate(CustomSequence())
+ model.predict(CustomSequence())
+
+ with self.assertRaisesRegexp(ValueError, '`y` argument is not supported'):
+ model.fit(CustomSequence(), y=np.ones([10, 1]))
+
+ with self.assertRaisesRegexp(ValueError,
+ '`sample_weight` argument is not supported'):
+ model.fit(CustomSequence(), sample_weight=np.ones([10, 1]))
+
class TestTrainingUtils(test.TestCase):
@@ -1360,7 +1411,7 @@ class TestTrainingUtils(test.TestCase):
class TestTrainingWithDataTensors(test.TestCase):
def test_training_and_eval_methods_on_symbolic_tensors_single_io(self):
- with self.test_session():
+ with self.cached_session():
x = keras.layers.Input(shape=(3,), name='input')
y = keras.layers.Dense(4, name='dense')(x)
model = keras.Model(x, y)
@@ -1400,7 +1451,7 @@ class TestTrainingWithDataTensors(test.TestCase):
validation_data=(inputs, targets), validation_steps=2)
def test_training_and_eval_methods_on_symbolic_tensors_multi_io(self):
- with self.test_session():
+ with self.cached_session():
a = keras.layers.Input(shape=(3,), name='input_a')
b = keras.layers.Input(shape=(3,), name='input_b')
@@ -1501,7 +1552,7 @@ class TestTrainingWithDataTensors(test.TestCase):
by only passing them data for the placeholder inputs
in the model.
"""
- with self.test_session():
+ with self.cached_session():
input_a_np = np.random.random((10, 3))
input_b_np = np.random.random((10, 3))
@@ -1632,7 +1683,7 @@ class TestTrainingWithDataTensors(test.TestCase):
self.assertEqual(out.shape, (10 * 3, 4))
def test_model_with_partial_loss(self):
- with self.test_session():
+ with self.cached_session():
a = keras.Input(shape=(3,), name='input_a')
a_2 = keras.layers.Dense(4, name='dense_1')(a)
dp = keras.layers.Dropout(0.5, name='dropout')
@@ -1673,7 +1724,7 @@ class TestTrainingWithDataTensors(test.TestCase):
_ = model.evaluate(input_a_np, [output_a_np])
def test_model_with_external_loss(self):
- with self.test_session():
+ with self.cached_session():
# None loss, only regularization loss.
a = keras.Input(shape=(3,), name='input_a')
a_2 = keras.layers.Dense(4, name='dense_1',
@@ -1803,7 +1854,7 @@ class TestTrainingWithDataTensors(test.TestCase):
self.assertEqual(out[1].shape, (10 * 3, 4))
def test_target_tensors(self):
- with self.test_session():
+ with self.cached_session():
# single-output, as list
model = keras.models.Sequential()
model.add(keras.layers.Dense(4, input_shape=(4,), name='dense'))
@@ -1864,7 +1915,7 @@ class TestTrainingWithDataTensors(test.TestCase):
sample_weight={'dense_a': np.random.random((10,))})
def test_model_custom_target_tensors(self):
- with self.test_session():
+ with self.cached_session():
a = keras.Input(shape=(3,), name='input_a')
b = keras.Input(shape=(3,), name='input_b')
@@ -2154,7 +2205,7 @@ class TestTrainingWithDataset(test.TestCase):
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
def test_dataset_input_shape_validation(self):
- with self.test_session():
+ with self.cached_session():
model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
model.compile(optimizer=RMSPropOptimizer(learning_rate=0.001), loss='mse')
@@ -2205,7 +2256,26 @@ class TestTrainingWithMetrics(test.TestCase):
'dense_binary_accuracy', 'dropout_mean_squared_error',
'dropout_binary_accuracy'
]
+ reference_stateful_metric_names = [
+ 'dense_binary_accuracy', 'dropout_binary_accuracy'
+ ]
+ self.assertEqual(reference_metric_names, model.metrics_names)
+ self.assertEqual(reference_stateful_metric_names,
+ model.stateful_metric_names)
+
+ # Verify that model metric names are not altered during training.
+ input_a_np = np.random.random((10, 3))
+ input_b_np = np.random.random((10, 3))
+
+ output_d_np = np.random.random((10, 4))
+ output_e_np = np.random.random((10, 4))
+
+ model.fit([input_a_np, input_b_np], [output_d_np, output_e_np],
+ epochs=1,
+ batch_size=5)
self.assertEqual(reference_metric_names, model.metrics_names)
+ self.assertEqual(reference_stateful_metric_names,
+ model.stateful_metric_names)
@tf_test_util.run_in_graph_and_eager_modes
def test_metrics_correctness(self):
@@ -2333,7 +2403,7 @@ class TestTrainingWithMetrics(test.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_metrics_masking(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
model = keras.models.Sequential()
model.add(keras.layers.Masking(mask_value=0, input_shape=(2, 1)))
diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py
index 898e9223cb..8e9fab81d6 100644
--- a/tensorflow/python/keras/engine/training_utils.py
+++ b/tensorflow/python/keras/engine/training_utils.py
@@ -797,6 +797,18 @@ def validate_iterator_input(x, y, sample_weight, validation_split=None):
'Received: x=%s, validation_split=%f' % (x, validation_split))
+def check_generator_arguments(y=None, sample_weight=None):
+ """Validates arguments passed when using a generator."""
+ if y is not None:
+    raise ValueError('`y` argument is not supported when data is '
+                     'a generator or Sequence instance. Instead pass targets'
+                     ' as the second element of the generator.')
+  if sample_weight is not None:
+    raise ValueError('`sample_weight` argument is not supported when data is '
+                     'a generator or Sequence instance. Instead pass sample'
+                     ' weights as the third element of the generator.')
+
+
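A small usage sketch of the validator added above; the generator-style inputs are a made-up stand-in, and the real call sites are the generator/Sequence paths of fit, evaluate and predict:

    import numpy as np
    from tensorflow.python.keras.engine import training_utils

    # No explicit targets or sample weights next to the generator: accepted.
    training_utils.check_generator_arguments(y=None, sample_weight=None)

    # Passing `y` alongside a generator is rejected with the message above.
    try:
      training_utils.check_generator_arguments(y=np.ones((10, 1)))
    except ValueError as err:
      print(err)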
def check_steps_argument(input_data, steps, steps_name):
"""Validates `steps` argument based on input data's type.
diff --git a/tensorflow/python/keras/layers/convolutional.py b/tensorflow/python/keras/layers/convolutional.py
index a57ac121ed..d00def07bb 100644
--- a/tensorflow/python/keras/layers/convolutional.py
+++ b/tensorflow/python/keras/layers/convolutional.py
@@ -64,7 +64,7 @@ class Conv(Layer):
specifying the stride length of the convolution.
Specifying any stride value != 1 is incompatible with specifying
any `dilation_rate` value != 1.
- padding: One of `"valid"` or `"same"` (case-insensitive).
+ padding: One of `"valid"`, `"same"`, or `"causal"` (case-insensitive).
data_format: A string, one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
@@ -126,6 +126,10 @@ class Conv(Layer):
kernel_size, rank, 'kernel_size')
self.strides = conv_utils.normalize_tuple(strides, rank, 'strides')
self.padding = conv_utils.normalize_padding(padding)
+ if (self.padding == 'causal' and not isinstance(self,
+ (Conv1D, SeparableConv1D))):
+      raise ValueError('Causal padding is only supported for `Conv1D` '
+                       'and `SeparableConv1D`.')
self.data_format = conv_utils.normalize_data_format(data_format)
self.dilation_rate = conv_utils.normalize_tuple(
dilation_rate, rank, 'dilation_rate')
@@ -172,12 +176,16 @@ class Conv(Layer):
self.bias = None
self.input_spec = InputSpec(ndim=self.rank + 2,
axes={channel_axis: input_dim})
+ if self.padding == 'causal':
+ op_padding = 'valid'
+ else:
+ op_padding = self.padding
self._convolution_op = nn_ops.Convolution(
input_shape,
filter_shape=self.kernel.get_shape(),
dilation_rate=self.dilation_rate,
strides=self.strides,
- padding=self.padding.upper(),
+ padding=op_padding.upper(),
data_format=conv_utils.convert_data_format(self.data_format,
self.rank + 2))
self.built = True
@@ -264,6 +272,15 @@ class Conv(Layer):
base_config = super(Conv, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
+ def _compute_causal_padding(self):
+ """Calculates padding for 'causal' option for 1-d conv layers."""
+ left_pad = self.dilation_rate[0] * (self.kernel_size[0] - 1)
+ if self.data_format == 'channels_last':
+ causal_padding = [[0, 0], [left_pad, 0], [0, 0]]
+ else:
+ causal_padding = [[0, 0], [0, 0], [left_pad, 0]]
+ return causal_padding
+
@tf_export('keras.layers.Conv1D', 'keras.layers.Convolution1D')
class Conv1D(Conv):
@@ -361,6 +378,11 @@ class Conv1D(Conv):
bias_constraint=constraints.get(bias_constraint),
**kwargs)
+ def call(self, inputs):
+ if self.padding == 'causal':
+ inputs = array_ops.pad(inputs, self._compute_causal_padding())
+ return super(Conv1D, self).call(inputs)
+
@tf_export('keras.layers.Conv2D', 'keras.layers.Convolution2D')
class Conv2D(Conv):
@@ -1261,31 +1283,44 @@ class SeparableConv(Conv):
def get_config(self):
config = {
- 'filters': self.filters,
- 'kernel_size': self.kernel_size,
- 'strides': self.strides,
- 'padding': self.padding,
- 'data_format': self.data_format,
- 'dilation_rate': self.dilation_rate,
- 'activation': activations.serialize(self.activation),
- 'use_bias': self.use_bias,
+ 'filters':
+ self.filters,
+ 'kernel_size':
+ self.kernel_size,
+ 'strides':
+ self.strides,
+ 'padding':
+ self.padding,
+ 'data_format':
+ self.data_format,
+ 'depth_multiplier':
+ self.depth_multiplier,
+ 'dilation_rate':
+ self.dilation_rate,
+ 'activation':
+ activations.serialize(self.activation),
+ 'use_bias':
+ self.use_bias,
'depthwise_initializer':
initializers.serialize(self.depthwise_initializer),
'pointwise_initializer':
initializers.serialize(self.pointwise_initializer),
- 'bias_initializer': initializers.serialize(self.bias_initializer),
+ 'bias_initializer':
+ initializers.serialize(self.bias_initializer),
'depthwise_regularizer':
regularizers.serialize(self.depthwise_regularizer),
'pointwise_regularizer':
regularizers.serialize(self.pointwise_regularizer),
- 'bias_regularizer': regularizers.serialize(self.bias_regularizer),
+ 'bias_regularizer':
+ regularizers.serialize(self.bias_regularizer),
'activity_regularizer':
regularizers.serialize(self.activity_regularizer),
'depthwise_constraint':
constraints.serialize(self.depthwise_constraint),
'pointwise_constraint':
constraints.serialize(self.pointwise_constraint),
- 'bias_constraint': constraints.serialize(self.bias_constraint)
+ 'bias_constraint':
+ constraints.serialize(self.bias_constraint)
}
base_config = super(SeparableConv, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@@ -1311,7 +1346,7 @@ class SeparableConv1D(SeparableConv):
of the convolution.
Specifying any `stride` value != 1 is incompatible with specifying
any `dilation_rate` value != 1.
- padding: One of `"valid"` or `"same"` (case-insensitive).
+ padding: One of `"valid"`, `"same"`, or `"causal"` (case-insensitive).
data_format: A string, one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
@@ -1397,6 +1432,8 @@ class SeparableConv1D(SeparableConv):
**kwargs)
def call(self, inputs):
+ if self.padding == 'causal':
+ inputs = array_ops.pad(inputs, self._compute_causal_padding())
if self.data_format == 'channels_last':
strides = (1,) + self.strides * 2 + (1,)
spatial_start_dim = 1
@@ -1411,12 +1448,16 @@ class SeparableConv1D(SeparableConv):
pointwise_kernel = array_ops.expand_dims(self.pointwise_kernel, 0)
dilation_rate = (1,) + self.dilation_rate
+ if self.padding == 'causal':
+ op_padding = 'valid'
+ else:
+ op_padding = self.padding
outputs = nn.separable_conv2d(
inputs,
depthwise_kernel,
pointwise_kernel,
strides=strides,
- padding=self.padding.upper(),
+ padding=op_padding.upper(),
rate=dilation_rate,
data_format=conv_utils.convert_data_format(self.data_format, ndim=4))
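A NumPy-only sketch of what the causal option does for channels_last data, mirroring _compute_causal_padding above: the input is left-padded by dilation_rate * (kernel_size - 1) so that a VALID convolution preserves the number of timesteps and never reads future frames (shapes below are illustrative):

    import numpy as np

    batch, timesteps, channels = 2, 8, 3
    kernel_size, dilation_rate = 3, 2
    x = np.random.random((batch, timesteps, channels))

    left_pad = dilation_rate * (kernel_size - 1)      # 4
    causal_padding = [[0, 0], [left_pad, 0], [0, 0]]  # channels_last layout
    x_padded = np.pad(x, causal_padding, mode='constant')

    # A VALID convolution over the padded input keeps the original length:
    dilated_kernel = kernel_size + (kernel_size - 1) * (dilation_rate - 1)  # 5
    out_len = x_padded.shape[1] - dilated_kernel + 1
    assert out_len == timesteps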
diff --git a/tensorflow/python/keras/layers/convolutional_test.py b/tensorflow/python/keras/layers/convolutional_test.py
index f904744422..2d3d38a5ce 100644
--- a/tensorflow/python/keras/layers/convolutional_test.py
+++ b/tensorflow/python/keras/layers/convolutional_test.py
@@ -52,7 +52,7 @@ class Convolution1DTest(test.TestCase):
'kernel_size': 3,
}
- self._run_test(kwargs, 'padding', ['valid', 'same'])
+ self._run_test(kwargs, 'padding', ['valid', 'same', 'causal'])
self._run_test(kwargs, 'strides', [2])
self._run_test(kwargs, 'dilation_rate', [2])
@@ -329,7 +329,7 @@ class SeparableConv1DTest(test.TestCase):
'kernel_size': 3,
}
- self._run_test(kwargs, 'padding', ['valid', 'same'])
+ self._run_test(kwargs, 'padding', ['valid', 'same', 'causal'])
self._run_test(kwargs, 'strides', [2])
self._run_test(kwargs, 'dilation_rate', [2])
self._run_test(kwargs, 'depth_multiplier', [2])
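With the layer change above, 'causal' can be passed straight to the Keras layer; a short sketch (filter count and shapes are arbitrary) showing that the time dimension is preserved:

    import numpy as np
    from tensorflow.python import keras

    layer = keras.layers.Conv1D(
        filters=4, kernel_size=3, dilation_rate=2, padding='causal')
    out = layer(keras.backend.constant(np.random.random((2, 8, 3))))
    assert out.shape.as_list() == [2, 8, 4]  # (batch, steps, filters)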
diff --git a/tensorflow/python/keras/layers/gru_test.py b/tensorflow/python/keras/layers/gru_test.py
index afef997b00..9988c9fae5 100644
--- a/tensorflow/python/keras/layers/gru_test.py
+++ b/tensorflow/python/keras/layers/gru_test.py
@@ -87,7 +87,7 @@ class GRULayerTest(test.TestCase):
embedding_dim = 4
units = 2
layer_class = keras.layers.GRU
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(
keras.layers.Embedding(
@@ -146,7 +146,7 @@ class GRULayerTest(test.TestCase):
def test_regularizers_GRU(self):
embedding_dim = 4
layer_class = keras.layers.GRU
- with self.test_session():
+ with self.cached_session():
layer = layer_class(
5,
return_sequences=False,
@@ -166,7 +166,7 @@ class GRULayerTest(test.TestCase):
def test_constraints_GRU(self):
embedding_dim = 4
layer_class = keras.layers.GRU
- with self.test_session():
+ with self.cached_session():
k_constraint = keras.constraints.max_norm(0.01)
r_constraint = keras.constraints.max_norm(0.01)
b_constraint = keras.constraints.max_norm(0.01)
@@ -186,7 +186,7 @@ class GRULayerTest(test.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_with_masking_layer_GRU(self):
layer_class = keras.layers.GRU
- with self.test_session():
+ with self.cached_session():
inputs = np.random.random((2, 3, 4))
targets = np.abs(np.random.random((2, 3, 5)))
targets /= targets.sum(axis=-1, keepdims=True)
diff --git a/tensorflow/python/keras/layers/lstm_test.py b/tensorflow/python/keras/layers/lstm_test.py
index 9802820fd0..f536915324 100644
--- a/tensorflow/python/keras/layers/lstm_test.py
+++ b/tensorflow/python/keras/layers/lstm_test.py
@@ -102,7 +102,7 @@ class LSTMLayerTest(test.TestCase):
embedding_dim = 4
units = 2
layer_class = keras.layers.LSTM
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(
keras.layers.Embedding(
@@ -161,7 +161,7 @@ class LSTMLayerTest(test.TestCase):
def test_regularizers_LSTM(self):
embedding_dim = 4
layer_class = keras.layers.LSTM
- with self.test_session():
+ with self.cached_session():
layer = layer_class(
5,
return_sequences=False,
@@ -180,7 +180,7 @@ class LSTMLayerTest(test.TestCase):
def test_constraints_LSTM(self):
embedding_dim = 4
layer_class = keras.layers.LSTM
- with self.test_session():
+ with self.cached_session():
k_constraint = keras.constraints.max_norm(0.01)
r_constraint = keras.constraints.max_norm(0.01)
b_constraint = keras.constraints.max_norm(0.01)
@@ -200,7 +200,7 @@ class LSTMLayerTest(test.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_with_masking_layer_LSTM(self):
layer_class = keras.layers.LSTM
- with self.test_session():
+ with self.cached_session():
inputs = np.random.random((2, 3, 4))
targets = np.abs(np.random.random((2, 3, 5)))
targets /= targets.sum(axis=-1, keepdims=True)
@@ -225,7 +225,7 @@ class LSTMLayerTest(test.TestCase):
units = 3
num_samples = 2
- with self.test_session():
+ with self.cached_session():
# Test with Keras tensor
inputs = keras.Input((timesteps, embedding_dim))
initial_state = [keras.Input((units,)) for _ in range(num_states)]
@@ -252,7 +252,7 @@ class LSTMLayerTest(test.TestCase):
units = 3
num_samples = 2
- with self.test_session():
+ with self.cached_session():
# Test with non-Keras tensor
inputs = keras.Input((timesteps, embedding_dim))
initial_state = [keras.backend.random_normal_variable(
@@ -275,7 +275,7 @@ class LSTMLayerTest(test.TestCase):
units = 3
num_samples = 2
- with self.test_session():
+ with self.cached_session():
layer = keras.layers.LSTM(units, stateful=True)
layer.build((num_samples, timesteps, embedding_dim))
layer.reset_states()
@@ -306,7 +306,7 @@ class LSTMLayerTest(test.TestCase):
units = 3
num_samples = 2
- with self.test_session():
+ with self.cached_session():
inputs = keras.Input((timesteps, embedding_dim))
_ = keras.layers.Masking()(inputs)
initial_state = [keras.Input((units,)) for _ in range(num_states)]
@@ -329,7 +329,7 @@ class LSTMLayerTest(test.TestCase):
units = 3
num_samples = 2
- with self.test_session():
+ with self.cached_session():
inputs = keras.Input(batch_shape=(num_samples, timesteps, embedding_dim))
layer = keras.layers.LSTM(units, return_state=True, stateful=True)
outputs = layer(inputs)
@@ -347,7 +347,7 @@ class LSTMLayerTest(test.TestCase):
units = 3
num_samples = 2
- with self.test_session():
+ with self.cached_session():
inputs = keras.Input(batch_shape=(num_samples, timesteps, embedding_dim))
layer = keras.layers.LSTM(units, return_state=True, return_sequences=True)
outputs = layer(inputs)
@@ -366,7 +366,7 @@ class LSTMLayerTest(test.TestCase):
num_states = 2
layer_class = keras.layers.LSTM
- with self.test_session():
+ with self.cached_session():
# Test with Keras tensor
main_inputs = keras.Input((timesteps, embedding_dim))
initial_state = [keras.Input((units,)) for _ in range(num_states)]
diff --git a/tensorflow/python/keras/layers/recurrent_test.py b/tensorflow/python/keras/layers/recurrent_test.py
index a3861e44d5..b9e90095e4 100644
--- a/tensorflow/python/keras/layers/recurrent_test.py
+++ b/tensorflow/python/keras/layers/recurrent_test.py
@@ -530,7 +530,9 @@ class RNNTest(test.TestCase):
y_np_2 = model.predict(x_np)
self.assertAllClose(y_np, y_np_2, atol=1e-4)
- def test_stacked_rnn_dropout(self):
+ def DISABLED_test_stacked_rnn_dropout(self):
+    # Temporarily disabled test due to an occasional Grappler segfault.
+ # See b/115523414
cells = [keras.layers.LSTMCell(3, dropout=0.1, recurrent_dropout=0.1),
keras.layers.LSTMCell(3, dropout=0.1, recurrent_dropout=0.1)]
layer = keras.layers.RNN(cells)
diff --git a/tensorflow/python/keras/layers/simplernn_test.py b/tensorflow/python/keras/layers/simplernn_test.py
index 1429537648..2f2295a793 100644
--- a/tensorflow/python/keras/layers/simplernn_test.py
+++ b/tensorflow/python/keras/layers/simplernn_test.py
@@ -87,7 +87,7 @@ class SimpleRNNLayerTest(test.TestCase):
embedding_dim = 4
units = 2
layer_class = keras.layers.SimpleRNN
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(
keras.layers.Embedding(
@@ -146,7 +146,7 @@ class SimpleRNNLayerTest(test.TestCase):
def test_regularizers_SimpleRNN(self):
embedding_dim = 4
layer_class = keras.layers.SimpleRNN
- with self.test_session():
+ with self.cached_session():
layer = layer_class(
5,
return_sequences=False,
@@ -166,7 +166,7 @@ class SimpleRNNLayerTest(test.TestCase):
def test_constraints_SimpleRNN(self):
embedding_dim = 4
layer_class = keras.layers.SimpleRNN
- with self.test_session():
+ with self.cached_session():
k_constraint = keras.constraints.max_norm(0.01)
r_constraint = keras.constraints.max_norm(0.01)
b_constraint = keras.constraints.max_norm(0.01)
@@ -186,7 +186,7 @@ class SimpleRNNLayerTest(test.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_with_masking_layer_SimpleRNN(self):
layer_class = keras.layers.SimpleRNN
- with self.test_session():
+ with self.cached_session():
inputs = np.random.random((2, 3, 4))
targets = np.abs(np.random.random((2, 3, 5)))
targets /= targets.sum(axis=-1, keepdims=True)
diff --git a/tensorflow/python/keras/model_subclassing_test.py b/tensorflow/python/keras/model_subclassing_test.py
index 71c1987cee..3a1b00041f 100644
--- a/tensorflow/python/keras/model_subclassing_test.py
+++ b/tensorflow/python/keras/model_subclassing_test.py
@@ -463,7 +463,7 @@ class ModelSubclassingTest(test.TestCase):
num_samples = 10
input_dim = 50
- with self.test_session():
+ with self.cached_session():
model = SimpleTestModel(num_classes=num_classes,
use_dp=True,
use_bn=True)
@@ -481,7 +481,7 @@ class ModelSubclassingTest(test.TestCase):
num_samples = 10
input_dim = 50
- with self.test_session():
+ with self.cached_session():
model = MultiIOTestModel(num_classes=num_classes,
use_dp=True,
use_bn=True)
@@ -501,7 +501,7 @@ class ModelSubclassingTest(test.TestCase):
num_samples = 10
input_dim = 50
- with self.test_session():
+ with self.cached_session():
model = SimpleTestModel(num_classes=num_classes, use_dp=True, use_bn=True)
model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001))
@@ -521,7 +521,7 @@ class ModelSubclassingTest(test.TestCase):
num_samples = 1000
input_dim = 50
- with self.test_session():
+ with self.cached_session():
model = MultiIOTestModel(num_classes=num_classes,
use_dp=True,
use_bn=True)
@@ -610,7 +610,7 @@ class ModelSubclassingTest(test.TestCase):
def call(self, x):
return self.bn(self.fc(x))
- with self.test_session():
+ with self.cached_session():
model = TestModel1()
x = array_ops.ones(shape=[100, 784], dtype='float32')
@@ -631,7 +631,7 @@ class ModelSubclassingTest(test.TestCase):
def call(self, x):
return self.bn(self.fc(x))
- with self.test_session():
+ with self.cached_session():
model = TestModel2()
x = array_ops.ones(shape=[100, 784], dtype='float32')
@@ -655,7 +655,7 @@ class ModelSubclassingTest(test.TestCase):
def call(self, x):
return self.bn(self.fc(x))
- with self.test_session():
+ with self.cached_session():
model = TestModel3()
x = array_ops.ones(shape=[100, 784], dtype='float32')
diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py
index f0733a9105..41c5e3cccf 100644
--- a/tensorflow/python/keras/models.py
+++ b/tensorflow/python/keras/models.py
@@ -444,6 +444,8 @@ def clone_and_build_model(
clone = model
_in_place_subclassed_model_reset(clone)
if input_tensors is not None:
+ if isinstance(input_tensors, (list, tuple)) and len(input_tensors) == 1:
+ input_tensors = input_tensors[0]
clone._set_inputs(input_tensors)
# Compile/Build model
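The two added lines simply normalize a single-element list of input tensors before _set_inputs; a tiny standalone illustration of the same check (the helper name is made up):

    def _normalize_input_tensors(input_tensors):
      # A one-element list or tuple is treated like the bare tensor itself.
      if isinstance(input_tensors, (list, tuple)) and len(input_tensors) == 1:
        return input_tensors[0]
      return input_tensors

    assert _normalize_input_tensors(['t0']) == 't0'
    assert _normalize_input_tensors(['t0', 't1']) == ['t0', 't1']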
diff --git a/tensorflow/python/keras/optimizers_test.py b/tensorflow/python/keras/optimizers_test.py
index 9a68fc0e35..8d7493462e 100644
--- a/tensorflow/python/keras/optimizers_test.py
+++ b/tensorflow/python/keras/optimizers_test.py
@@ -85,23 +85,23 @@ def _test_optimizer(optimizer, target=0.75):
class KerasOptimizersTest(test.TestCase):
def test_sgd(self):
- with self.test_session():
+ with self.cached_session():
_test_optimizer(keras.optimizers.SGD(lr=0.01,
momentum=0.9,
nesterov=True))
def test_rmsprop(self):
- with self.test_session():
+ with self.cached_session():
_test_optimizer(keras.optimizers.RMSprop())
_test_optimizer(keras.optimizers.RMSprop(decay=1e-3))
def test_adagrad(self):
- with self.test_session():
+ with self.cached_session():
_test_optimizer(keras.optimizers.Adagrad())
_test_optimizer(keras.optimizers.Adagrad(decay=1e-3))
def test_adadelta(self):
- with self.test_session():
+ with self.cached_session():
_test_optimizer(keras.optimizers.Adadelta(), target=0.6)
# Accuracy seems dependent on the initialization. Even adding tf.Print
# nodes in the graph seemed to affect the initialization seed, and hence
@@ -109,28 +109,28 @@ class KerasOptimizersTest(test.TestCase):
_test_optimizer(keras.optimizers.Adadelta(decay=1e-3), target=0.4)
def test_adam(self):
- with self.test_session():
+ with self.cached_session():
_test_optimizer(keras.optimizers.Adam())
_test_optimizer(keras.optimizers.Adam(decay=1e-3))
_test_optimizer(keras.optimizers.Adam(amsgrad=True))
def test_adamax(self):
- with self.test_session():
+ with self.cached_session():
_test_optimizer(keras.optimizers.Adamax())
_test_optimizer(keras.optimizers.Adamax(decay=1e-3))
def test_nadam(self):
- with self.test_session():
+ with self.cached_session():
_test_optimizer(keras.optimizers.Nadam())
def test_clipnorm(self):
- with self.test_session():
+ with self.cached_session():
_test_optimizer(keras.optimizers.SGD(lr=0.01,
momentum=0.9,
clipnorm=0.5))
def test_clipvalue(self):
- with self.test_session():
+ with self.cached_session():
_test_optimizer(keras.optimizers.SGD(lr=0.01,
momentum=0.9,
clipvalue=0.5))
@@ -158,7 +158,7 @@ class KerasOptimizersTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def test_tfoptimizer_iterations(self):
- with self.test_session():
+ with self.cached_session():
optimizer = keras.optimizers.TFOptimizer(AdamOptimizer(0.01))
model = keras.models.Sequential()
model.add(keras.layers.Dense(
diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py
index 58405c550b..501b50ba5f 100644
--- a/tensorflow/python/keras/testing_utils.py
+++ b/tensorflow/python/keras/testing_utils.py
@@ -29,7 +29,8 @@ from tensorflow.python.util import tf_inspect
def get_test_data(train_samples,
test_samples,
input_shape,
- num_classes):
+ num_classes,
+ random_seed=None):
"""Generates test data to train a model on.
Arguments:
@@ -37,10 +38,13 @@ def get_test_data(train_samples,
test_samples: Integer, how many test samples to generate.
input_shape: Tuple of integers, shape of the inputs.
num_classes: Integer, number of classes for the data and targets.
+ random_seed: Integer, random seed used by numpy to generate data.
Returns:
A tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
"""
+ if random_seed is not None:
+ np.random.seed(random_seed)
num_sample = train_samples + test_samples
templates = 2 * num_classes * np.random.random((num_classes,) + input_shape)
y = np.random.randint(0, num_classes, size=(num_sample,))
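The new random_seed argument makes the generated data reproducible; a short sketch of two identical calls (sample counts and shapes are arbitrary):

    import numpy as np
    from tensorflow.python.keras import testing_utils

    (x1, y1), _ = testing_utils.get_test_data(
        train_samples=20, test_samples=10, input_shape=(4,), num_classes=2,
        random_seed=1337)
    (x2, y2), _ = testing_utils.get_test_data(
        train_samples=20, test_samples=10, input_shape=(4,), num_classes=2,
        random_seed=1337)

    np.testing.assert_allclose(x1, x2)
    np.testing.assert_array_equal(y1, y2)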
diff --git a/tensorflow/python/keras/utils/conv_utils.py b/tensorflow/python/keras/utils/conv_utils.py
index 3a176c3316..8ebca1418d 100644
--- a/tensorflow/python/keras/utils/conv_utils.py
+++ b/tensorflow/python/keras/utils/conv_utils.py
@@ -93,7 +93,7 @@ def conv_output_length(input_length, filter_size, padding, stride, dilation=1):
Arguments:
input_length: integer.
filter_size: integer.
- padding: one of "same", "valid", "full".
+    padding: one of "same", "valid", "full", "causal".
stride: integer.
dilation: dilation rate, integer.
@@ -102,9 +102,9 @@ def conv_output_length(input_length, filter_size, padding, stride, dilation=1):
"""
if input_length is None:
return None
- assert padding in {'same', 'valid', 'full'}
+ assert padding in {'same', 'valid', 'full', 'causal'}
dilated_filter_size = filter_size + (filter_size - 1) * (dilation - 1)
- if padding == 'same':
+ if padding in ['same', 'causal']:
output_length = input_length
elif padding == 'valid':
output_length = input_length - dilated_filter_size + 1
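A worked instance of the rule above with stride 1: a dilated kernel covers filter_size + (filter_size - 1) * (dilation - 1) input steps, so 'causal' (like 'same') keeps the input length while 'valid' shrinks it:

    from tensorflow.python.keras.utils import conv_utils

    input_length, filter_size, stride, dilation = 10, 3, 1, 2
    # Dilated kernel spans 3 + (3 - 1) * (2 - 1) = 5 steps.
    assert conv_utils.conv_output_length(
        input_length, filter_size, 'causal', stride, dilation) == 10
    assert conv_utils.conv_output_length(
        input_length, filter_size, 'valid', stride, dilation) == 6  # 10 - 5 + 1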
diff --git a/tensorflow/python/keras/utils/data_utils.py b/tensorflow/python/keras/utils/data_utils.py
index c1ee34ae46..b736daa46d 100644
--- a/tensorflow/python/keras/utils/data_utils.py
+++ b/tensorflow/python/keras/utils/data_utils.py
@@ -40,6 +40,7 @@ from six.moves.urllib.error import URLError
from six.moves.urllib.request import urlopen
from tensorflow.python.keras.utils.generic_utils import Progbar
+from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
@@ -93,6 +94,11 @@ else:
from six.moves.urllib.request import urlretrieve
+def is_generator_or_sequence(x):
+ """Check if `x` is a Keras generator type."""
+ return tf_inspect.isgenerator(x) or isinstance(x, Sequence)
+
+
def _extract_archive(file_path, path='.', archive_format='auto'):
"""Extracts an archive if it matches tar, tar.gz, tar.bz, or zip formats.
@@ -494,6 +500,7 @@ class SequenceEnqueuer(object):
raise NotImplementedError
+@tf_export('keras.utils.OrderedEnqueuer')
class OrderedEnqueuer(SequenceEnqueuer):
"""Builds a Enqueuer from a Sequence.
@@ -550,7 +557,7 @@ class OrderedEnqueuer(SequenceEnqueuer):
self.executor_fn = lambda seqs: multiprocessing.Pool( # pylint: disable=g-long-lambda
workers, initializer=init_pool, initargs=(seqs,))
else:
- # We do not need the init since it's threads.
+ # We do not need the init since it's threads.
self.executor_fn = lambda _: ThreadPool(workers)
self.workers = workers
self.queue = queue.Queue(max_queue_size)
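A quick sketch of the is_generator_or_sequence helper added earlier in this file, checking the three kinds of input it has to distinguish (the Sequence subclass is a made-up two-batch example):

    import numpy as np
    from tensorflow.python.keras.utils import data_utils

    def gen():
      while True:
        yield np.ones((2, 2)), np.ones((2, 1))

    class TwoBatchSequence(data_utils.Sequence):

      def __getitem__(self, idx):
        return np.ones((2, 2)), np.ones((2, 1))

      def __len__(self):
        return 2

    assert data_utils.is_generator_or_sequence(gen())
    assert data_utils.is_generator_or_sequence(TwoBatchSequence())
    assert not data_utils.is_generator_or_sequence(np.ones((2, 2)))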
diff --git a/tensorflow/python/keras/utils/layer_utils.py b/tensorflow/python/keras/utils/layer_utils.py
index 1f28c59ea4..158a9a5e76 100644
--- a/tensorflow/python/keras/utils/layer_utils.py
+++ b/tensorflow/python/keras/utils/layer_utils.py
@@ -26,6 +26,7 @@ from tensorflow.python.keras.utils.conv_utils import convert_kernel
from tensorflow.python.util.tf_export import tf_export
+@tf_export('keras.utils.get_source_inputs')
def get_source_inputs(tensor, layer=None, node_index=None):
"""Returns the list of input tensors necessary to compute `tensor`.
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 0403211d92..6bba99b9e7 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -286,7 +286,10 @@ tf_py_test(
srcs = ["decode_csv_op_test.py"],
additional_deps = [
"//third_party/py/numpy",
+ "//tensorflow/python/eager:context",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_test_lib",
"//tensorflow/python:parsing_ops",
],
)
@@ -1011,6 +1014,7 @@ tf_py_test(
size = "small",
srcs = ["substr_op_test.py"],
additional_deps = [
+ "@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
diff --git a/tensorflow/python/kernel_tests/accumulate_n_test.py b/tensorflow/python/kernel_tests/accumulate_n_test.py
index b793906fac..0bc5268f38 100644
--- a/tensorflow/python/kernel_tests/accumulate_n_test.py
+++ b/tensorflow/python/kernel_tests/accumulate_n_test.py
@@ -76,7 +76,7 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase):
# Putting them here so that everything that exercises AccumulateNV2 is in
# one place and the default build runs all unit tests.
def testSimple(self):
- with self.test_session():
+ with self.cached_session():
random_arrays = [
np.random.rand(16, 16, 16, 16).astype(np.float32) for _ in range(20)
]
@@ -91,27 +91,27 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase):
self.assertAllClose(np_val, tf_val.eval())
def testZeroArgs(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
tf_val = math_ops.accumulate_n([])
tf_val.eval()
def testWrongShape(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
a = variables.Variable(0.2)
b = variables.Variable(0.1)
math_ops.accumulate_n([a, b], shape=[2, 2]) # Should be shape=[]
def testIncompatibleShapes(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
a = variables.Variable(np.array([0.1, 0.2]))
b = variables.Variable(np.array([[0.3], [0.4]]))
math_ops.accumulate_n([a, b])
def testWrongType(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(TypeError):
a = variables.Variable(0.2, dtype=np.float32)
b = variables.Variable(0.1, dtype=np.float32)
@@ -119,7 +119,7 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase):
def testWrongTypeOneInput(self):
# Scenario that used to trigger a bug, even when testWrongType() worked
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(TypeError):
a = variables.Variable(0.2, dtype=np.float32)
math_ops.accumulate_n([a], tensor_dtype=np.int32)
diff --git a/tensorflow/python/kernel_tests/ackermann_test.py b/tensorflow/python/kernel_tests/ackermann_test.py
index 5e0d87c783..d267e49752 100644
--- a/tensorflow/python/kernel_tests/ackermann_test.py
+++ b/tensorflow/python/kernel_tests/ackermann_test.py
@@ -34,7 +34,7 @@ class AckermannTest(test.TestCase):
self.assertEqual(len(ackermann.OP_LIST.op), 1)
self.assertEqual(ackermann.OP_LIST.op[0].name, 'Ackermann')
- with self.test_session():
+ with self.cached_session():
self.assertEqual(ackermann.ackermann().eval(), b'A(m, 0) == A(m-1, 1)')
diff --git a/tensorflow/python/kernel_tests/argmax_op_test.py b/tensorflow/python/kernel_tests/argmax_op_test.py
index 1202c463e8..127d14c250 100644
--- a/tensorflow/python/kernel_tests/argmax_op_test.py
+++ b/tensorflow/python/kernel_tests/argmax_op_test.py
@@ -104,20 +104,20 @@ class ArgMaxTest(test.TestCase):
self._testDim(np.int64)
def testEmpty(self):
- with self.test_session():
+ with self.cached_session():
for op in math_ops.argmin, math_ops.argmax:
with self.assertRaisesOpError(
r"Reduction axis 0 is empty in shape \[0\]"):
op([], 0).eval()
def testDefaultAxis(self):
- with self.test_session():
+ with self.cached_session():
for op in math_ops.argmin, math_ops.argmax:
ans = op([1]).eval()
self.assertAllEqual(ans, 0)
def testOutputEmpty(self):
- with self.test_session():
+ with self.cached_session():
for op in math_ops.argmin, math_ops.argmax:
ret = op(array_ops.zeros(shape=[1, 0, 2]), axis=-1).eval()
self.assertEqual(ret.shape, (1, 0))
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index a164682227..573bb8614f 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -50,7 +50,7 @@ class BatchMatrixTransposeTest(test_util.TensorFlowTestCase):
def testNonBatchMatrix(self):
matrix = [[1, 2, 3], [4, 5, 6]] # Shape (2, 3)
expected_transposed = [[1, 4], [2, 5], [3, 6]] # Shape (3, 2)
- with self.test_session():
+ with self.cached_session():
transposed = array_ops.matrix_transpose(matrix)
self.assertEqual((3, 2), transposed.get_shape())
self.assertAllEqual(expected_transposed, transposed.eval())
@@ -58,7 +58,7 @@ class BatchMatrixTransposeTest(test_util.TensorFlowTestCase):
def testConjugate(self):
m = [[1 + 1j, 2 + 2j, 3 + 3j], [4 + 4j, 5 + 5j, 6 + 6j]]
expected_transposed = [[1 - 1j, 4 - 4j], [2 - 2j, 5 - 5j], [3 - 3j, 6 - 6j]]
- with self.test_session():
+ with self.cached_session():
matrix = ops.convert_to_tensor(m)
transposed = array_ops.matrix_transpose(matrix, conjugate=True)
self.assertEqual((3, 2), transposed.get_shape())
@@ -71,7 +71,7 @@ class BatchMatrixTransposeTest(test_util.TensorFlowTestCase):
matrix_1_t = [[11, 44], [22, 55], [33, 66]]
batch_matrix = [matrix_0, matrix_1] # Shape (2, 2, 3)
expected_transposed = [matrix_0_t, matrix_1_t] # Shape (2, 3, 2)
- with self.test_session():
+ with self.cached_session():
transposed = array_ops.matrix_transpose(batch_matrix)
self.assertEqual((2, 3, 2), transposed.get_shape())
self.assertAllEqual(expected_transposed, transposed.eval())
@@ -79,7 +79,7 @@ class BatchMatrixTransposeTest(test_util.TensorFlowTestCase):
def testNonBatchMatrixDynamicallyDefined(self):
matrix = [[1, 2, 3], [4, 5, 6]] # Shape (2, 3)
expected_transposed = [[1, 4], [2, 5], [3, 6]] # Shape (3, 2)
- with self.test_session():
+ with self.cached_session():
matrix_ph = array_ops.placeholder(dtypes.int32)
transposed = array_ops.matrix_transpose(matrix_ph)
self.assertAllEqual(
@@ -94,7 +94,7 @@ class BatchMatrixTransposeTest(test_util.TensorFlowTestCase):
matrix_1_t = [[11, 44], [22, 55], [33, 66]]
batch_matrix = [matrix_0, matrix_1] # Shape (2, 2, 3)
expected_transposed = [matrix_0_t, matrix_1_t] # Shape (2, 3, 2)
- with self.test_session():
+ with self.cached_session():
batch_matrix_ph = array_ops.placeholder(dtypes.int32)
transposed = array_ops.matrix_transpose(batch_matrix_ph)
self.assertAllEqual(
@@ -105,7 +105,7 @@ class BatchMatrixTransposeTest(test_util.TensorFlowTestCase):
def testTensorWithStaticRankLessThanTwoRaisesBecauseNotAMatrix(self):
vector = [1, 2, 3]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError, "should be a "):
array_ops.matrix_transpose(vector)
@@ -129,7 +129,7 @@ class BooleanMaskTest(test_util.TensorFlowTestCase):
masked_arr = arr[:, mask]
elif axis == 2:
masked_arr = arr[:, :, mask]
- with self.test_session():
+ with self.cached_session():
masked_tensor = array_ops.boolean_mask(arr, mask, axis=axis)
# Leading dimension size of masked_tensor is always unknown until runtime
@@ -176,7 +176,7 @@ class BooleanMaskTest(test_util.TensorFlowTestCase):
numpy_result = arr[mask]
tf_result = array_ops.boolean_mask(arr, mask)
self.assertAllEqual(numpy_result.shape[1:], tf_result.get_shape()[1:])
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(numpy_result, tf_result.eval())
def testEmptyInput1D(self):
@@ -185,7 +185,7 @@ class BooleanMaskTest(test_util.TensorFlowTestCase):
numpy_result = arr[mask]
tf_result = array_ops.boolean_mask(arr, mask)
self.assertAllEqual(numpy_result.shape[1:], tf_result.get_shape()[1:])
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(numpy_result, tf_result.eval())
def testEmptyOutput(self):
@@ -199,7 +199,7 @@ class BooleanMaskTest(test_util.TensorFlowTestCase):
def testWorksWithDimensionsEqualToNoneDuringGraphBuild(self):
# The rank of the mask tensor must be specified. This is explained
# in the docstring as well.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ph_tensor = array_ops.placeholder(dtypes.int32, shape=None)
ph_mask = array_ops.placeholder(dtypes.bool, shape=[None])
@@ -217,7 +217,7 @@ class BooleanMaskTest(test_util.TensorFlowTestCase):
def testMaskDimensionsSetToNoneRaises(self):
# The rank of the mask tensor must be specified. This is explained
# in the docstring as well.
- with self.test_session():
+ with self.cached_session():
tensor = array_ops.placeholder(dtypes.int32, shape=[None, 2])
mask = array_ops.placeholder(dtypes.bool, shape=None)
with self.assertRaisesRegexp(ValueError, "dimensions must be specified"):
@@ -226,21 +226,21 @@ class BooleanMaskTest(test_util.TensorFlowTestCase):
def testMaskHasMoreDimsThanTensorRaises(self):
mask = [[True, True], [False, False]]
tensor = [1, 2, 3, 4]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError, "incompatible"):
array_ops.boolean_mask(tensor, mask).eval()
def testMaskIsScalarRaises(self):
mask = True
tensor = 1
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError, "mask.*scalar"):
array_ops.boolean_mask(tensor, mask).eval()
def testMaskShapeDifferentThanFirstPartOfTensorShapeRaises(self):
mask = [True, True, True]
tensor = [[1, 2], [3, 4]]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError, "incompatible"):
array_ops.boolean_mask(tensor, mask).eval()
@@ -345,7 +345,7 @@ class ReverseV2Test(test_util.TensorFlowTestCase):
def testInvalid(self):
x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
axis = array_ops.placeholder(dtypes.int32)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"is out of valid range"):
array_ops.reverse_v2(x_np, axis).eval(feed_dict={axis: [-30]})
@@ -954,7 +954,7 @@ class StridedSliceAssignChecker(object):
class SliceAssignTest(test_util.TensorFlowTestCase):
def testInvalidSlice(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
foo = constant_op.constant([1, 2, 3])
with self.assertRaisesRegexp(ValueError, "Sliced assignment"
" is only supported for variables"):
@@ -1000,7 +1000,7 @@ class SliceAssignTest(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp(
errors.FailedPreconditionError,
"Attempting to use uninitialized value Variable"):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v = variables.Variable([1, 2])
sess.run(v[:].assign([1, 2]))
@@ -1019,7 +1019,7 @@ class SliceAssignTest(test_util.TensorFlowTestCase):
too_small_val = constant_op.constant([3, 4], dtype=dtypes.int8)
too_large_val = constant_op.constant([3, 4], dtype=dtypes.int64)
v = resource_variable_ops.ResourceVariable(init_val)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(v.initializer)
with self.assertRaises(ValueError):
sess.run(v[:].assign(too_large_val))
@@ -1066,12 +1066,12 @@ class ShapeSizeRankTest(test_util.TensorFlowTestCase):
class SequenceMaskTest(test_util.TensorFlowTestCase):
def testExceptions(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError, "maxlen must be scalar"):
array_ops.sequence_mask([10, 20], [10, 20])
def testOneDimensionalWithMaxlen(self):
- with self.test_session():
+ with self.cached_session():
res = array_ops.sequence_mask(constant_op.constant([1, 3, 2]), 5)
self.assertAllEqual(res.get_shape(), [3, 5])
self.assertAllEqual(
@@ -1081,7 +1081,7 @@ class SequenceMaskTest(test_util.TensorFlowTestCase):
@test_util.enable_c_shapes
def testOneDimensionalDtypeWithoutMaxlen(self):
- with self.test_session():
+ with self.cached_session():
# test dtype and default maxlen:
res = array_ops.sequence_mask(constant_op.constant([0, 1, 4]),
dtype=dtypes.float32)
@@ -1092,7 +1092,7 @@ class SequenceMaskTest(test_util.TensorFlowTestCase):
@test_util.enable_c_shapes
def testOneDimensionalWithoutMaxlen(self):
- with self.test_session():
+ with self.cached_session():
res = array_ops.sequence_mask(
constant_op.constant([0, 1, 4]))
self.assertAllEqual(res.get_shape().as_list(), [3, 4])
@@ -1104,7 +1104,7 @@ class SequenceMaskTest(test_util.TensorFlowTestCase):
@test_util.enable_c_shapes
def testTwoDimensional(self):
- with self.test_session():
+ with self.cached_session():
res = array_ops.sequence_mask(constant_op.constant([[1, 3, 2]]), 5)
self.assertAllEqual(res.get_shape(), [1, 3, 5])
self.assertAllEqual(res.eval(), [[[True, False, False, False, False], [
@@ -1137,7 +1137,7 @@ class SequenceMaskTest(test_util.TensorFlowTestCase):
[[True, False, False, False, False], [True, True, True, False, False],
[True, True, False, False, False]])
- with self.test_session():
+ with self.cached_session():
check_dtypes(dtypes.int32, dtypes.int32)
check_dtypes(dtypes.int32, dtypes.int64)
check_dtypes(dtypes.int64, dtypes.int32)
@@ -1216,7 +1216,7 @@ class UnravelIndexTest(test_util.TensorFlowTestCase):
# TODO(b/73086570): Reenable test.
@unittest.skip("Test does not pass internally.")
def testUnravelIndex(self):
- with self.test_session():
+ with self.cached_session():
for dtype in [dtypes.int32, dtypes.int64]:
indices_1 = constant_op.constant(1621, dtype=dtype)
dims_1 = constant_op.constant([6, 7, 8, 9], dtype=dtype)
@@ -1237,13 +1237,13 @@ class UnravelIndexTest(test_util.TensorFlowTestCase):
class GuaranteeConstOpTest(test_util.TensorFlowTestCase):
def testSimple(self):
- with self.test_session():
+ with self.cached_session():
a = array_ops.constant(10)
guarantee_a = array_ops.guarantee_const(a)
self.assertEqual(10, guarantee_a.eval())
def testVariables(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for use_resource in [False, True]:
a = variable_scope.get_variable(
"var_{}".format(use_resource), [],
@@ -1254,7 +1254,7 @@ class GuaranteeConstOpTest(test_util.TensorFlowTestCase):
self.assertEqual(10.0, guarantee_a.eval())
def testResourceRejection(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a = variable_scope.get_variable(
"resource_var", [],
initializer=init_ops.constant_initializer(10.0),
diff --git a/tensorflow/python/kernel_tests/as_string_op_test.py b/tensorflow/python/kernel_tests/as_string_op_test.py
index 51aa17babe..dd4a90e5f6 100644
--- a/tensorflow/python/kernel_tests/as_string_op_test.py
+++ b/tensorflow/python/kernel_tests/as_string_op_test.py
@@ -32,7 +32,7 @@ class AsStringOpTest(test.TestCase):
0, 1, -1, 0.5, 0.25, 0.125, float("INF"), float("NAN"), float("-INF")
]
- with self.test_session():
+ with self.cached_session():
for dtype in (dtypes.float32, dtypes.float64):
input_ = array_ops.placeholder(dtype)
@@ -84,7 +84,7 @@ class AsStringOpTest(test.TestCase):
int_inputs_ = [0, -1, 1, -128, 127, -101, 101, -0]
s = lambda strs: [x.decode("ascii") for x in strs]
- with self.test_session():
+ with self.cached_session():
for dtype in (dtypes.int32, dtypes.int64, dtypes.int8):
input_ = array_ops.placeholder(dtype)
@@ -117,7 +117,7 @@ class AsStringOpTest(test.TestCase):
# testing int8
s = lambda strs: [x.decode("ascii") for x in strs]
- with self.test_session():
+ with self.cached_session():
input_ = array_ops.placeholder(dtypes.int32)
int_inputs_ = [np.iinfo(np.int32).min, np.iinfo(np.int32).max]
output = string_ops.as_string(input_)
@@ -133,7 +133,7 @@ class AsStringOpTest(test.TestCase):
def testHalfInt(self):
s = lambda strs: [x.decode("ascii") for x in strs]
- with self.test_session():
+ with self.cached_session():
input_ = array_ops.placeholder(dtypes.int16)
int_inputs_ = [np.iinfo(np.int16).min, np.iinfo(np.int16).max]
output = string_ops.as_string(input_)
@@ -144,7 +144,7 @@ class AsStringOpTest(test.TestCase):
bool_inputs_ = [False, True]
s = lambda strs: [x.decode("ascii") for x in strs]
- with self.test_session():
+ with self.cached_session():
for dtype in (dtypes.bool,):
input_ = array_ops.placeholder(dtype)
@@ -159,7 +159,7 @@ class AsStringOpTest(test.TestCase):
]
complex_inputs_ = [(x + (x + 1) * 1j) for x in float_inputs_]
- with self.test_session():
+ with self.cached_session():
for dtype in (dtypes.complex64, dtypes.complex128):
input_ = array_ops.placeholder(dtype)
diff --git a/tensorflow/python/kernel_tests/atrous_convolution_test.py b/tensorflow/python/kernel_tests/atrous_convolution_test.py
index b98e5fd386..6b16fca29d 100644
--- a/tensorflow/python/kernel_tests/atrous_convolution_test.py
+++ b/tensorflow/python/kernel_tests/atrous_convolution_test.py
@@ -263,7 +263,7 @@ class AtrousConvolutionTest(test.TestCase):
self.assertLess(err, err_tolerance)
def testGradient(self):
- with self.test_session():
+ with self.cached_session():
for padding in ["SAME", "VALID"]:
for rate_width in range(1, 3):
for rate_height in range(1, 3):
diff --git a/tensorflow/python/kernel_tests/attention_ops_test.py b/tensorflow/python/kernel_tests/attention_ops_test.py
index fb74698660..1e09ba5b65 100644
--- a/tensorflow/python/kernel_tests/attention_ops_test.py
+++ b/tensorflow/python/kernel_tests/attention_ops_test.py
@@ -84,7 +84,7 @@ class ExtractGlimpseTest(test.TestCase):
image_ops.extract_glimpse(t_cols_4d, t1, t2), [0, 2, 1, 3]))
# Evaluate the TensorFlow Graph.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
value_rows, value_cols = sess.run([glimpse_rows, glimpse_cols])
# Check dimensions of returned glimpse.
@@ -118,7 +118,7 @@ class ExtractGlimpseTest(test.TestCase):
def testEmptyTensor(self):
empty_image = np.zeros((0, 4, 3, 0))
offsets = np.zeros((0, 2))
- with self.test_session():
+ with self.cached_session():
result = image_ops.extract_glimpse(empty_image, [1, 1], offsets)
self.assertAllEqual(
np.zeros(
diff --git a/tensorflow/python/kernel_tests/barrier_ops_test.py b/tensorflow/python/kernel_tests/barrier_ops_test.py
index 7f49c63957..4d36b3a465 100644
--- a/tensorflow/python/kernel_tests/barrier_ops_test.py
+++ b/tensorflow/python/kernel_tests/barrier_ops_test.py
@@ -67,7 +67,7 @@ class BarrierTest(test.TestCase):
""", b.barrier_ref.op.node_def)
def testInsertMany(self):
- with self.test_session():
+ with self.cached_session():
b = data_flow_ops.Barrier(
(dtypes.float32, dtypes.float32), shapes=((), ()), name="B")
size_t = b.ready_size()
@@ -83,7 +83,7 @@ class BarrierTest(test.TestCase):
self.assertEquals(size_t.eval(), [3])
def testInsertManyEmptyTensor(self):
- with self.test_session():
+ with self.cached_session():
error_message = ("Empty tensors are not supported, but received shape "
r"\'\(0,\)\' at index 1")
with self.assertRaisesRegexp(ValueError, error_message):
@@ -91,7 +91,7 @@ class BarrierTest(test.TestCase):
(dtypes.float32, dtypes.float32), shapes=((1,), (0,)), name="B")
def testInsertManyEmptyTensorUnknown(self):
- with self.test_session():
+ with self.cached_session():
b = data_flow_ops.Barrier((dtypes.float32, dtypes.float32), name="B")
size_t = b.ready_size()
self.assertEqual([], size_t.get_shape())
@@ -103,7 +103,7 @@ class BarrierTest(test.TestCase):
insert_0_op.run()
def testTakeMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
b = data_flow_ops.Barrier(
(dtypes.float32, dtypes.float32), shapes=((), ()), name="B")
size_t = b.ready_size()
@@ -128,7 +128,7 @@ class BarrierTest(test.TestCase):
self.assertEqual(values_1_val[idx], v1)
def testTakeManySmallBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
b = data_flow_ops.Barrier(
(dtypes.float32, dtypes.float32), shapes=((), ()), name="B")
size_t = b.ready_size()
@@ -192,7 +192,7 @@ class BarrierTest(test.TestCase):
insert_1_3_op.run()
def testUseBarrierWithShape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
b = data_flow_ops.Barrier(
(dtypes.float32, dtypes.float32), shapes=((2, 2), (8,)), name="B")
size_t = b.ready_size()
@@ -221,7 +221,7 @@ class BarrierTest(test.TestCase):
self.assertAllEqual(values_1_val[idx], v1)
def testParallelInsertMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
b = data_flow_ops.Barrier(dtypes.float32, shapes=())
size_t = b.ready_size()
keys = [str(x).encode("ascii") for x in range(10)]
@@ -241,7 +241,7 @@ class BarrierTest(test.TestCase):
self.assertEqual(values_val[idx], v)
def testParallelTakeMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
b = data_flow_ops.Barrier(dtypes.float32, shapes=())
size_t = b.ready_size()
keys = [str(x).encode("ascii") for x in range(10)]
@@ -275,7 +275,7 @@ class BarrierTest(test.TestCase):
zip(keys, values), [(k[0], v[0]) for k, v in zip(key_vals, value_vals)])
def testBlockingTakeMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
b = data_flow_ops.Barrier(dtypes.float32, shapes=())
keys = [str(x).encode("ascii") for x in range(10)]
values = [float(x) for x in range(10)]
@@ -297,7 +297,7 @@ class BarrierTest(test.TestCase):
t.join()
def testParallelInsertManyTakeMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
b = data_flow_ops.Barrier(
(dtypes.float32, dtypes.int64), shapes=((), (2,)))
num_iterations = 100
@@ -376,7 +376,7 @@ class BarrierTest(test.TestCase):
self.assertAllEqual(taken_i["values_1"], expected_values_1)
def testClose(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
b = data_flow_ops.Barrier(
(dtypes.float32, dtypes.float32), shapes=((), ()), name="B")
size_t = b.ready_size()
@@ -434,7 +434,7 @@ class BarrierTest(test.TestCase):
sess.run(take_t[0])
def testCancel(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
b = data_flow_ops.Barrier(
(dtypes.float32, dtypes.float32), shapes=((), ()), name="B")
size_t = b.ready_size()
@@ -487,7 +487,7 @@ class BarrierTest(test.TestCase):
sess.run(take_t[0])
def _testClosedEmptyBarrierTakeManyAllowSmallBatchRaises(self, cancel):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
b = data_flow_ops.Barrier(
(dtypes.float32, dtypes.float32), shapes=((), ()), name="B")
take_t = b.take_many(1, allow_small_batch=True)
@@ -500,7 +500,7 @@ class BarrierTest(test.TestCase):
self._testClosedEmptyBarrierTakeManyAllowSmallBatchRaises(cancel=True)
def _testParallelInsertManyTakeManyCloseHalfwayThrough(self, cancel):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
b = data_flow_ops.Barrier(
(dtypes.float32, dtypes.int64), shapes=((), (2,)))
num_iterations = 50
@@ -576,7 +576,7 @@ class BarrierTest(test.TestCase):
self._testParallelInsertManyTakeManyCloseHalfwayThrough(cancel=True)
def _testParallelPartialInsertManyTakeManyCloseHalfwayThrough(self, cancel):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
b = data_flow_ops.Barrier(
(dtypes.float32, dtypes.int64), shapes=((), (2,)))
num_iterations = 100
@@ -676,7 +676,7 @@ class BarrierTest(test.TestCase):
self._testParallelPartialInsertManyTakeManyCloseHalfwayThrough(cancel=True)
def testIncompatibleSharedBarrierErrors(self):
- with self.test_session():
+ with self.cached_session():
# Do component types and shapes.
b_a_1 = data_flow_ops.Barrier(
(dtypes.float32,), shapes=(()), shared_name="b_a")
diff --git a/tensorflow/python/kernel_tests/base64_ops_test.py b/tensorflow/python/kernel_tests/base64_ops_test.py
index be96f45497..1b399942ef 100644
--- a/tensorflow/python/kernel_tests/base64_ops_test.py
+++ b/tensorflow/python/kernel_tests/base64_ops_test.py
@@ -48,7 +48,7 @@ class Base64OpsTest(test_util.TensorFlowTestCase):
return base64_msg
def _RunTest(self, msg, pad):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if pad:
encoded, decoded = sess.run([self._encoded_t, self._decoded_t],
feed_dict={self._msg: msg})
@@ -92,7 +92,7 @@ class Base64OpsTest(test_util.TensorFlowTestCase):
encoded = string_ops.encode_base64(msg, pad=pad)
decoded = string_ops.decode_base64(encoded)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
encoded_value, decoded_value = sess.run([encoded, decoded])
self.assertEqual(encoded_value.shape, msg.shape)
@@ -102,7 +102,7 @@ class Base64OpsTest(test_util.TensorFlowTestCase):
def try_decode(enc):
self._decoded_f.eval(feed_dict={self._encoded_f: enc})
- with self.test_session():
+ with self.cached_session():
# Invalid length.
msg = np.random.bytes(99)
enc = base64.urlsafe_b64encode(msg)
diff --git a/tensorflow/python/kernel_tests/basic_gpu_test.py b/tensorflow/python/kernel_tests/basic_gpu_test.py
index 987a6ffcd4..e651fa0070 100644
--- a/tensorflow/python/kernel_tests/basic_gpu_test.py
+++ b/tensorflow/python/kernel_tests/basic_gpu_test.py
@@ -174,7 +174,7 @@ class BroadcastSimpleTest(test.TestCase):
numeric_gradient_type=None):
z = np_func(x, y)
zs = list(z.shape)
- with self.test_session():
+ with self.cached_session():
inx = ops.convert_to_tensor(x)
iny = ops.convert_to_tensor(y)
if x.dtype in (np.float32, np.float64):
@@ -195,7 +195,7 @@ class BroadcastSimpleTest(test.TestCase):
numeric_gradient_type=None):
z = np_func(x, y)
zs = list(z.shape)
- with self.test_session():
+ with self.cached_session():
inx = ops.convert_to_tensor(x)
iny = ops.convert_to_tensor(y)
if x.dtype in (np.float32, np.float64):
diff --git a/tensorflow/python/kernel_tests/batch_gather_op_test.py b/tensorflow/python/kernel_tests/batch_gather_op_test.py
index 8e7ae89f9d..7dd347989a 100644
--- a/tensorflow/python/kernel_tests/batch_gather_op_test.py
+++ b/tensorflow/python/kernel_tests/batch_gather_op_test.py
@@ -86,7 +86,7 @@ class GatherTest(test.TestCase):
def testString(self):
params = np.array([[b"asdf", b"zxcv"], [b"qwer", b"uiop"]])
- with self.test_session():
+ with self.cached_session():
indices_tf = constant_op.constant([1])
self.assertAllEqual([[b"qwer", b"uiop"]],
array_ops.batch_gather(params, indices_tf).eval())
diff --git a/tensorflow/python/kernel_tests/batchtospace_op_test.py b/tensorflow/python/kernel_tests/batchtospace_op_test.py
index 6143cd3baa..03f3f64353 100644
--- a/tensorflow/python/kernel_tests/batchtospace_op_test.py
+++ b/tensorflow/python/kernel_tests/batchtospace_op_test.py
@@ -60,7 +60,7 @@ class BatchToSpaceDepthToSpace(test.TestCase, PythonOpImpl):
array_ops.depth_to_space(
array_ops.transpose(x, [3, 1, 2, 0]), block_size=block_size),
[3, 1, 2, 0])
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(y1.eval(), y2.eval())
@@ -235,7 +235,7 @@ class BatchToSpaceGradientTest(test.TestCase, PythonOpImpl):
# Check the gradients.
def _checkGrad(self, x, crops, block_size):
assert 4 == x.ndim
- with self.test_session():
+ with self.cached_session():
tf_x = ops.convert_to_tensor(x)
tf_y = self.batch_to_space(tf_x, crops, block_size)
epsilon = 1e-5
@@ -293,7 +293,7 @@ class BatchToSpaceNDGradientTest(test.TestCase):
block_shape = np.array(block_shape)
crops = constant_op.constant(
np.array(crops).reshape((len(block_shape), 2)), crops_dtype)
- with self.test_session():
+ with self.cached_session():
tf_x = ops.convert_to_tensor(x)
tf_y = array_ops.batch_to_space_nd(tf_x, block_shape, crops)
epsilon = 1e-5
diff --git a/tensorflow/python/kernel_tests/bcast_ops_test.py b/tensorflow/python/kernel_tests/bcast_ops_test.py
index 3305e55c05..3ec820aead 100644
--- a/tensorflow/python/kernel_tests/bcast_ops_test.py
+++ b/tensorflow/python/kernel_tests/bcast_ops_test.py
@@ -28,11 +28,11 @@ from tensorflow.python.platform import test
class BcastOpsTest(test.TestCase):
def _GetBroadcastShape(self, xs, ys):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
return sess.run(broadcast_args(xs, ys))
def _GetGradientArgs(self, xs, ys):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
return sess.run(broadcast_gradient_args(xs, ys))
def testBasic(self):
diff --git a/tensorflow/python/kernel_tests/betainc_op_test.py b/tensorflow/python/kernel_tests/betainc_op_test.py
index 16fdedac41..92d21462d5 100644
--- a/tensorflow/python/kernel_tests/betainc_op_test.py
+++ b/tensorflow/python/kernel_tests/betainc_op_test.py
@@ -47,7 +47,7 @@ class BetaincTest(test.TestCase):
tf_b_s = constant_op.constant(b_s, dtype=dtype)
tf_x_s = constant_op.constant(x_s, dtype=dtype)
tf_out_t = math_ops.betainc(tf_a_s, tf_b_s, tf_x_s)
- with self.test_session():
+ with self.cached_session():
tf_out = tf_out_t.eval()
scipy_out = special.betainc(a_s, b_s, x_s).astype(np_dt)
@@ -60,13 +60,13 @@ class BetaincTest(test.TestCase):
# Test out-of-range values (most should return nan output)
combinations = list(itertools.product([-1, 0, 0.5, 1.0, 1.5], repeat=3))
a_comb, b_comb, x_comb = np.asarray(list(zip(*combinations)), dtype=np_dt)
- with self.test_session():
+ with self.cached_session():
tf_comb = math_ops.betainc(a_comb, b_comb, x_comb).eval()
scipy_comb = special.betainc(a_comb, b_comb, x_comb).astype(np_dt)
self.assertAllCloseAccordingToType(scipy_comb, tf_comb)
# Test broadcasting between scalars and other shapes
- with self.test_session():
+ with self.cached_session():
self.assertAllCloseAccordingToType(
special.betainc(0.1, b_s, x_s).astype(np_dt),
math_ops.betainc(0.1, b_s, x_s).eval(),
@@ -96,7 +96,7 @@ class BetaincTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, "must be equal"):
math_ops.betainc(0.5, [0.5], [[0.5]])
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Shapes of .* are inconsistent"):
a_p = array_ops.placeholder(dtype)
b_p = array_ops.placeholder(dtype)
@@ -140,7 +140,7 @@ class BetaincTest(test.TestCase):
self._testBetaInc(a_s, b_s, x_s, dtypes.float32)
def testBetaIncFpropAndBpropAreNeverNAN(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
space = np.logspace(-8, 5).tolist()
space_x = np.linspace(1e-16, 1 - 1e-16).tolist()
ga_s, gb_s, gx_s = zip(*list(itertools.product(space, space, space_x)))
@@ -161,7 +161,7 @@ class BetaincTest(test.TestCase):
def testBetaIncGrads(self):
err_tolerance = 1e-3
- with self.test_session():
+ with self.cached_session():
# Test gradient
ga_s = np.abs(np.random.randn(2, 2) * 30) # in (0, infty)
gb_s = np.abs(np.random.randn(2, 2) * 30) # in (0, infty)
diff --git a/tensorflow/python/kernel_tests/bincount_op_test.py b/tensorflow/python/kernel_tests/bincount_op_test.py
index 2767df127e..8a58b3f97e 100644
--- a/tensorflow/python/kernel_tests/bincount_op_test.py
+++ b/tensorflow/python/kernel_tests/bincount_op_test.py
@@ -93,7 +93,7 @@ class BincountTest(test_util.TensorFlowTestCase):
def test_negative(self):
# unsorted_segment_sum will only report InvalidArgumentError on CPU
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(errors.InvalidArgumentError):
math_ops.bincount([1, 2, 3, -1, 6, 8]).eval()
diff --git a/tensorflow/python/kernel_tests/boosted_trees/BUILD b/tensorflow/python/kernel_tests/boosted_trees/BUILD
index 4f92ab0795..20446781f0 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/BUILD
+++ b/tensorflow/python/kernel_tests/boosted_trees/BUILD
@@ -74,3 +74,16 @@ tf_py_test(
"//tensorflow/python:resources",
],
)
+
+tf_py_test(
+ name = "quantile_ops_test",
+ size = "small",
+ srcs = ["quantile_ops_test.py"],
+ additional_deps = [
+ "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_py",
+ "//tensorflow/python:boosted_trees_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:resources",
+ ],
+)
diff --git a/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
index 4e31b1ea2a..dee96102fb 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
@@ -30,7 +30,7 @@ class TrainingPredictionOpsTest(test_util.TensorFlowTestCase):
def testCachedPredictionOnEmptyEnsemble(self):
"""Tests that prediction on a dummy ensemble does not fail."""
- with self.test_session() as session:
+ with self.cached_session() as session:
# Create a dummy ensemble.
tree_ensemble = boosted_trees_ops.TreeEnsemble(
'ensemble', serialized_proto='')
@@ -63,7 +63,7 @@ class TrainingPredictionOpsTest(test_util.TensorFlowTestCase):
def testNoCachedPredictionButTreeExists(self):
"""Tests that predictions are updated once trees are added."""
- with self.test_session() as session:
+ with self.cached_session() as session:
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
text_format.Merge("""
trees {
@@ -129,7 +129,7 @@ class TrainingPredictionOpsTest(test_util.TensorFlowTestCase):
def testCachedPredictionIsCurrent(self):
"""Tests that prediction based on previous node in the tree works."""
- with self.test_session() as session:
+ with self.cached_session() as session:
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
text_format.Merge("""
trees {
@@ -201,7 +201,7 @@ class TrainingPredictionOpsTest(test_util.TensorFlowTestCase):
def testCachedPredictionFromTheSameTree(self):
"""Tests that prediction based on previous node in the tree works."""
- with self.test_session() as session:
+ with self.cached_session() as session:
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
text_format.Merge("""
trees {
@@ -315,7 +315,7 @@ class TrainingPredictionOpsTest(test_util.TensorFlowTestCase):
def testCachedPredictionFromPreviousTree(self):
"""Tests the predictions work when we have cache from previous trees."""
- with self.test_session() as session:
+ with self.cached_session() as session:
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
text_format.Merge("""
trees {
@@ -447,7 +447,7 @@ class TrainingPredictionOpsTest(test_util.TensorFlowTestCase):
def testCachedPredictionFromTheSameTreeWithPostPrunedNodes(self):
"""Tests that prediction based on previous node in the tree works."""
- with self.test_session() as session:
+ with self.cached_session() as session:
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
text_format.Merge("""
trees {
@@ -577,7 +577,7 @@ class TrainingPredictionOpsTest(test_util.TensorFlowTestCase):
def testCachedPredictionFromThePreviousTreeWithPostPrunedNodes(self):
"""Tests that prediction based on previous node in the tree works."""
- with self.test_session() as session:
+ with self.cached_session() as session:
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
text_format.Merge("""
trees {
@@ -722,7 +722,7 @@ class TrainingPredictionOpsTest(test_util.TensorFlowTestCase):
def testCachedPredictionTheWholeTreeWasPruned(self):
"""Tests that prediction based on previous node in the tree works."""
- with self.test_session() as session:
+ with self.cached_session() as session:
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
text_format.Merge("""
trees {
@@ -794,7 +794,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
def testPredictionOnEmptyEnsemble(self):
"""Tests that prediction on a empty ensemble does not fail."""
- with self.test_session() as session:
+ with self.cached_session() as session:
# Create an empty ensemble.
tree_ensemble = boosted_trees_ops.TreeEnsemble(
'ensemble', serialized_proto='')
@@ -816,7 +816,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
def testPredictionMultipleTree(self):
"""Tests the predictions work when we have multiple trees."""
- with self.test_session() as session:
+ with self.cached_session() as session:
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
text_format.Merge("""
trees {
@@ -930,7 +930,7 @@ class FeatureContribsOpsTest(test_util.TensorFlowTestCase):
def testContribsMultipleTree(self):
"""Tests that the contribs work when we have multiple trees."""
- with self.test_session() as session:
+ with self.cached_session() as session:
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
text_format.Merge(
"""
diff --git a/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py
new file mode 100644
index 0000000000..c71b8df4ad
--- /dev/null
+++ b/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py
@@ -0,0 +1,140 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test for checking quantile related ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import boosted_trees_ops
+from tensorflow.python.ops import resources
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_handle_op as resource_handle_op
+from tensorflow.python.ops.gen_boosted_trees_ops import is_boosted_trees_quantile_stream_resource_initialized as resource_initialized
+from tensorflow.python.platform import googletest
+
+
+class QuantileOpsTest(test_util.TensorFlowTestCase):
+
+ def create_resource(self, name, eps, max_elements, num_streams=1):
+ quantile_accumulator_handle = resource_handle_op(
+ container="", shared_name=name, name=name)
+ create_op = boosted_trees_ops.create_quantile_stream_resource(
+ quantile_accumulator_handle,
+ epsilon=eps,
+ max_elements=max_elements,
+ num_streams=num_streams)
+ is_initialized_op = resource_initialized(quantile_accumulator_handle)
+ resources.register_resource(quantile_accumulator_handle, create_op,
+ is_initialized_op)
+ return quantile_accumulator_handle
+
+ def setUp(self):
+ """Sets up the quantile ops test as follows.
+
+ Create a batch of 6 examples having 2 features
+ The data looks like this
+ | Instance | instance weights | Feature 0 | Feature 1
+ | 0 | 10 | 1.2 | 2.3
+ | 1 | 1 | 12.1 | 1.2
+ | 2 | 1 | 0.3 | 1.1
+ | 3 | 1 | 0.5 | 2.6
+ | 4 | 1 | 0.6 | 3.2
+ | 5 | 1 | 2.2 | 0.8
+ """
+
+ self._feature_0 = constant_op.constant(
+ [[1.2], [12.1], [0.3], [0.5], [0.6], [2.2]], dtype=dtypes.float32)
+ self._feature_1 = constant_op.constant(
+ [[2.3], [1.2], [1.1], [2.6], [3.2], [0.8]], dtype=dtypes.float32)
+ self._feature_0_boundaries = constant_op.constant(
+ [0.3, 0.6, 1.2, 12.1], dtype=dtypes.float32)
+ self._feature_1_boundaries = constant_op.constant(
+ [0.8, 1.2, 2.3, 3.2], dtype=dtypes.float32)
+ self._feature_0_quantiles = constant_op.constant(
+ [[2], [3], [0], [1], [1], [3]], dtype=dtypes.int32)
+ self._feature_1_quantiles = constant_op.constant(
+ [[2], [1], [1], [3], [3], [0]], dtype=dtypes.int32)
+
+ self._example_weights = constant_op.constant(
+ [10, 1, 1, 1, 1, 1], dtype=dtypes.float32)
+
+ self.eps = 0.01
+ self.max_elements = 1 << 16
+ self.num_quantiles = constant_op.constant(3, dtype=dtypes.int64)
+
+ def testBasicQuantileBucketsSingleResource(self):
+ with self.test_session() as sess:
+ quantile_accumulator_handle = self.create_resource("floats", self.eps,
+ self.max_elements, 2)
+ resources.initialize_resources(resources.shared_resources()).run()
+ summaries = boosted_trees_ops.make_quantile_summaries(
+ [self._feature_0, self._feature_1], self._example_weights,
+ epsilon=self.eps)
+ summary_op = boosted_trees_ops.quantile_add_summaries(
+ quantile_accumulator_handle, summaries)
+ flush_op = boosted_trees_ops.quantile_flush(
+ quantile_accumulator_handle, self.num_quantiles)
+ buckets = boosted_trees_ops.get_bucket_boundaries(
+ quantile_accumulator_handle, num_features=2)
+ quantiles = boosted_trees_ops.boosted_trees_bucketize(
+ [self._feature_0, self._feature_1], buckets)
+ sess.run(summary_op)
+ sess.run(flush_op)
+ self.assertAllClose(self._feature_0_boundaries, buckets[0].eval())
+ self.assertAllClose(self._feature_1_boundaries, buckets[1].eval())
+
+ self.assertAllClose(self._feature_0_quantiles, quantiles[0].eval())
+ self.assertAllClose(self._feature_1_quantiles, quantiles[1].eval())
+
+ def testBasicQuantileBucketsMultipleResources(self):
+ with self.test_session() as sess:
+ quantile_accumulator_handle_0 = self.create_resource("float_0", self.eps,
+ self.max_elements)
+ quantile_accumulator_handle_1 = self.create_resource("float_1", self.eps,
+ self.max_elements)
+ resources.initialize_resources(resources.shared_resources()).run()
+ summaries = boosted_trees_ops.make_quantile_summaries(
+ [self._feature_0, self._feature_1], self._example_weights,
+ epsilon=self.eps)
+ summary_op_0 = boosted_trees_ops.quantile_add_summaries(
+ quantile_accumulator_handle_0,
+ [summaries[0]])
+ summary_op_1 = boosted_trees_ops.quantile_add_summaries(
+ quantile_accumulator_handle_1,
+ [summaries[1]])
+ flush_op_0 = boosted_trees_ops.quantile_flush(
+ quantile_accumulator_handle_0, self.num_quantiles)
+ flush_op_1 = boosted_trees_ops.quantile_flush(
+ quantile_accumulator_handle_1, self.num_quantiles)
+ bucket_0 = boosted_trees_ops.get_bucket_boundaries(
+ quantile_accumulator_handle_0, num_features=1)
+ bucket_1 = boosted_trees_ops.get_bucket_boundaries(
+ quantile_accumulator_handle_1, num_features=1)
+ quantiles = boosted_trees_ops.boosted_trees_bucketize(
+ [self._feature_0, self._feature_1], bucket_0 + bucket_1)
+ sess.run([summary_op_0, summary_op_1])
+ sess.run([flush_op_0, flush_op_1])
+ self.assertAllClose(self._feature_0_boundaries, bucket_0[0].eval())
+ self.assertAllClose(self._feature_1_boundaries, bucket_1[0].eval())
+
+ self.assertAllClose(self._feature_0_quantiles, quantiles[0].eval())
+ self.assertAllClose(self._feature_1_quantiles, quantiles[1].eval())
+
+
+if __name__ == "__main__":
+ googletest.main()
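The expected quantile fixtures hard-coded in setUp() above follow directly from the bucket boundaries. A rough NumPy sketch of the bucketization the new test asserts, assuming each value maps to the index of the first boundary that is greater than or equal to it; this reproduces the fixtures but only as an illustration, not as the boosted-trees C++ kernel:

    import numpy as np

    feature_0 = np.array([1.2, 12.1, 0.3, 0.5, 0.6, 2.2])
    boundaries_0 = np.array([0.3, 0.6, 1.2, 12.1])

    # Index of the first boundary >= value, i.e. searchsorted with side='left'.
    buckets = np.searchsorted(boundaries_0, feature_0, side="left")
    print(buckets)  # [2 3 0 1 1 3], matching self._feature_0_quantiles

The same rule applied to self._feature_1 against its boundaries yields [2 1 1 3 3 0], matching self._feature_1_quantiles.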
diff --git a/tensorflow/python/kernel_tests/boosted_trees/resource_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/resource_ops_test.py
index d5f0c22d6e..65bb9ab55f 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/resource_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/resource_ops_test.py
@@ -31,7 +31,7 @@ class ResourceOpsTest(test_util.TensorFlowTestCase):
"""Tests resource_ops."""
def testCreate(self):
- with self.test_session():
+ with self.cached_session():
ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
resources.initialize_resources(resources.shared_resources()).run()
stamp_token = ensemble.get_stamp_token()
@@ -44,7 +44,7 @@ class ResourceOpsTest(test_util.TensorFlowTestCase):
self.assertAllEqual([0, 1], nodes_range.eval())
def testCreateWithProto(self):
- with self.test_session():
+ with self.cached_session():
ensemble_proto = boosted_trees_pb2.TreeEnsemble()
text_format.Merge(
"""
@@ -161,7 +161,7 @@ class ResourceOpsTest(test_util.TensorFlowTestCase):
self.assertAllEqual([16, 19], nodes_range.eval())
def testSerializeDeserialize(self):
- with self.test_session():
+ with self.cached_session():
# Initialize.
ensemble = boosted_trees_ops.TreeEnsemble('ensemble', stamp_token=5)
resources.initialize_resources(resources.shared_resources()).run()
diff --git a/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py
index 568e695fd5..09e9cfa3af 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py
@@ -30,7 +30,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
def testCalculateBestGainsWithoutRegularization(self):
"""Testing Gain calculation without any regularization."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
max_splits = 7
node_id_range = [1, 3] # node 1 through 2 will be processed.
stats_summary_list = [
@@ -78,7 +78,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
def testCalculateBestGainsWithL2(self):
"""Testing Gain calculation with L2."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
max_splits = 7
node_id_range = [1, 3] # node 1 through 2 will be processed.
stats_summary_list = [
@@ -126,7 +126,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
def testCalculateBestGainsWithL1(self):
"""Testing Gain calculation with L1."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
max_splits = 7
node_id_range = [1, 3] # node 1 through 2 will be processed.
stats_summary_list = [
@@ -177,7 +177,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
def testCalculateBestGainsWithTreeComplexity(self):
"""Testing Gain calculation with L2."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
max_splits = 7
node_id_range = [1, 3] # node 1 through 2 will be processed.
stats_summary_list = [
@@ -229,7 +229,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
def testCalculateBestGainsWithMinNodeWeight(self):
"""Testing Gain calculation without any regularization."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
max_splits = 7
node_id_range = [1, 3] # node 1 through 2 will be processed.
stats_summary_list = [
@@ -276,7 +276,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
def testCalculateBestGainsWithMinNodeWeightNoSplitOnFeturePossible(self):
"""Testing Gain calculation without any regularization."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
max_splits = 7
node_id_range = [1, 3] # node 1 through 2 will be processed.
stats_summary_list = [
@@ -329,7 +329,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
def testMakeStatsSummarySimple(self):
"""Simple test for MakeStatsSummary."""
- with self.test_session():
+ with self.cached_session():
self.assertAllClose([[[[1., 5.], [2., 6.]], [[3., 7.], [4., 8.]]]],
boosted_trees_ops.make_stats_summary(
node_ids=[0, 0, 1, 1],
@@ -341,7 +341,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
def testMakeStatsSummaryAccumulate(self):
"""Tests that Summary actually accumulates."""
- with self.test_session():
+ with self.cached_session():
max_splits = 3
num_buckets = 4
node_ids = [1, 1, 2, 2, 1, 1, 2, 0]
@@ -363,7 +363,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
def testMakeStatsSummaryMultipleFeatures(self):
"""Tests that MakeStatsSummary works for multiple features."""
- with self.test_session():
+ with self.cached_session():
max_splits = 3
num_buckets = 4
node_ids = [1, 1, 2, 2, 1, 1, 2, 0]
@@ -392,7 +392,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
result.eval())
def _verify_precision(self, length):
- with self.test_session():
+ with self.cached_session():
max_splits = 1
num_buckets = 1
node_ids = array_ops.fill([length], 0)
diff --git a/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py
index d55240297a..ea022820e4 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py
@@ -32,7 +32,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testGrowWithEmptyEnsemble(self):
"""Test growing an empty ensemble."""
- with self.test_session() as session:
+ with self.cached_session() as session:
# Create empty ensemble.
tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
tree_ensemble_handle = tree_ensemble.resource_handle
@@ -141,7 +141,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testBiasCenteringOnEmptyEnsemble(self):
"""Test growing with bias centering on an empty ensemble."""
- with self.test_session() as session:
+ with self.cached_session() as session:
# Create empty ensemble.
tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
tree_ensemble_handle = tree_ensemble.resource_handle
@@ -184,7 +184,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testGrowExistingEnsembleTreeNotFinalized(self):
"""Test growing an existing ensemble with the last tree not finalized."""
- with self.test_session() as session:
+ with self.cached_session() as session:
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
text_format.Merge("""
trees {
@@ -368,7 +368,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testGrowExistingEnsembleTreeFinalized(self):
"""Test growing an existing ensemble with the last tree finalized."""
- with self.test_session() as session:
+ with self.cached_session() as session:
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
text_format.Merge("""
trees {
@@ -517,7 +517,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testPrePruning(self):
"""Test growing an existing ensemble with pre-pruning."""
- with self.test_session() as session:
+ with self.cached_session() as session:
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
text_format.Merge("""
trees {
@@ -673,7 +673,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testMetadataWhenCantSplitDueToEmptySplits(self):
"""Test that the metadata is updated even though we can't split."""
- with self.test_session() as session:
+ with self.cached_session() as session:
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
text_format.Merge(
"""
@@ -784,7 +784,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testMetadataWhenCantSplitDuePrePruning(self):
"""Test metadata is updated correctly when no split due to prepruning."""
- with self.test_session() as session:
+ with self.cached_session() as session:
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
text_format.Merge(
"""
@@ -919,7 +919,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testPostPruningOfSomeNodes(self):
"""Test growing an ensemble with post-pruning."""
- with self.test_session() as session:
+ with self.cached_session() as session:
# Create empty ensemble.
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
tree_ensemble = boosted_trees_ops.TreeEnsemble(
@@ -1253,7 +1253,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testPostPruningOfAllNodes(self):
"""Test growing an ensemble with post-pruning, with all nodes are pruned."""
- with self.test_session() as session:
+ with self.cached_session() as session:
# Create empty ensemble.
# Create empty ensemble.
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
@@ -1436,7 +1436,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testPostPruningChangesNothing(self):
"""Test growing an ensemble with post-pruning with all gains >0."""
- with self.test_session() as session:
+ with self.cached_session() as session:
# Create empty ensemble.
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
tree_ensemble = boosted_trees_ops.TreeEnsemble(
diff --git a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py
index 6a1bd958ba..bd2339f31d 100644
--- a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py
+++ b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py
@@ -21,8 +21,10 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradient_checker
from tensorflow.python.platform import test as test_lib
@@ -81,5 +83,47 @@ class BroadcastToTest(test_util.TensorFlowTestCase):
# check shape inference when shape input is constant
self.assertAllEqual(shape, v_np.shape)
+ def testGradientForScalar(self):
+ # TODO(alextp): There is a bug with broadcast_to on GPU from scalars,
+ # hence we make this test cpu-only.
+ with ops.device("cpu:0"):
+ x = constant_op.constant(1, dtype=dtypes.float32)
+ v = array_ops.broadcast_to(x, [2, 4, 3])
+ out = 2 * v
+ with self.test_session():
+ err = gradient_checker.compute_gradient_error(x, x.get_shape(),
+ out, out.get_shape())
+ self.assertLess(err, 1e-4)
+
+ def testGradientWithSameRank(self):
+ x = constant_op.constant(np.reshape(np.arange(6), (2, 1, 3)),
+ dtype=dtypes.float32)
+ v = array_ops.broadcast_to(x, [2, 5, 3])
+ out = 2 * v
+ with self.test_session():
+ err = gradient_checker.compute_gradient_error(x, x.get_shape(),
+ out, out.get_shape())
+ self.assertLess(err, 1e-4)
+
+ def testGradientWithIncreasingRank(self):
+ x = constant_op.constant([[1], [2]],
+ dtype=dtypes.float32)
+ v = array_ops.broadcast_to(x, [5, 2, 3])
+ out = 2 * v
+ with self.test_session():
+ err = gradient_checker.compute_gradient_error(x, x.get_shape(),
+ out, out.get_shape())
+ self.assertLess(err, 1e-4)
+
+ def testGradientWithBroadcastAllDimensions(self):
+ x = constant_op.constant([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32)
+ v = array_ops.broadcast_to(x, [5, 4, 6])
+ out = 2 * v
+ with self.test_session():
+ err = gradient_checker.compute_gradient_error(x, x.get_shape(),
+ out, out.get_shape())
+ self.assertLess(err, 1e-4)
+
+
if __name__ == "__main__":
test_lib.main()
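The four gradient tests added above rely on the gradient of broadcast_to reducing the upstream gradient over the broadcast dimensions back to the input shape. A small NumPy sketch of that reduction for the rank-increasing case, shown only to illustrate what gradient_checker.compute_gradient_error verifies numerically; it is not the registered TensorFlow gradient:

    import numpy as np

    x = np.array([[1.0], [2.0]])   # shape (2, 1), as in testGradientWithIncreasingRank
    g = np.ones((5, 2, 3))         # upstream gradient for broadcast_to(x, [5, 2, 3])

    # Sum over the new leading axis and the broadcast column axis, keeping
    # dims so the result matches x's shape.
    dx = g.sum(axis=0).sum(axis=1, keepdims=True)
    print(dx.shape)                # (2, 1)
    print(dx)                      # every entry is 15.0 = 5 * 3 broadcast copies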
diff --git a/tensorflow/python/kernel_tests/candidate_sampler_ops_test.py b/tensorflow/python/kernel_tests/candidate_sampler_ops_test.py
index 28b3dc45e9..b19077db56 100644
--- a/tensorflow/python/kernel_tests/candidate_sampler_ops_test.py
+++ b/tensorflow/python/kernel_tests/candidate_sampler_ops_test.py
@@ -38,7 +38,7 @@ class RangeSamplerOpsTest(test.TestCase):
TRUE_LABELS = [[1, 2], [0, 4], [3, 3]]
def testTrueCandidates(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
indices = constant_op.constant([0, 0, 1, 1, 2, 2])
true_candidates_vec = constant_op.constant([1, 2, 0, 4, 3, 3])
true_candidates_matrix = array_ops.reshape(
@@ -50,7 +50,7 @@ class RangeSamplerOpsTest(test.TestCase):
self.assertAllEqual(true_candidates_val, self.TRUE_LABELS)
def testSampledCandidates(self):
- with self.test_session():
+ with self.cached_session():
true_classes = constant_op.constant(
[[1, 2], [0, 4], [3, 3]], dtype=dtypes.int64)
sampled_candidates, _, _ = candidate_sampling_ops.all_candidate_sampler(
@@ -62,7 +62,7 @@ class RangeSamplerOpsTest(test.TestCase):
self.assertEqual(sampled_candidates.get_shape(), [self.NUM_SAMPLED])
def testTrueLogExpectedCount(self):
- with self.test_session():
+ with self.cached_session():
true_classes = constant_op.constant(
[[1, 2], [0, 4], [3, 3]], dtype=dtypes.int64)
_, true_expected_count, _ = candidate_sampling_ops.all_candidate_sampler(
@@ -77,7 +77,7 @@ class RangeSamplerOpsTest(test.TestCase):
[self.BATCH_SIZE, self.NUM_TRUE])
def testSampledLogExpectedCount(self):
- with self.test_session():
+ with self.cached_session():
true_classes = constant_op.constant(
[[1, 2], [0, 4], [3, 3]], dtype=dtypes.int64)
_, _, sampled_expected_count = candidate_sampling_ops.all_candidate_sampler( # pylint: disable=line-too-long
@@ -90,7 +90,7 @@ class RangeSamplerOpsTest(test.TestCase):
self.assertEqual(sampled_log_expected_count.get_shape(), [self.NUM_SAMPLED])
def testAccidentalHits(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
true_classes = constant_op.constant(
[[1, 2], [0, 4], [3, 3]], dtype=dtypes.int64)
sampled_candidates, _, _ = candidate_sampling_ops.all_candidate_sampler(
@@ -109,7 +109,7 @@ class RangeSamplerOpsTest(test.TestCase):
def testSeed(self):
def draw(seed):
- with self.test_session():
+ with self.cached_session():
true_classes = constant_op.constant(
[[1, 2], [0, 4], [3, 3]], dtype=dtypes.int64)
sampled, _, _ = candidate_sampling_ops.log_uniform_candidate_sampler(
diff --git a/tensorflow/python/kernel_tests/cast_op_test.py b/tensorflow/python/kernel_tests/cast_op_test.py
index 214d5cb3c0..c90520e46d 100644
--- a/tensorflow/python/kernel_tests/cast_op_test.py
+++ b/tensorflow/python/kernel_tests/cast_op_test.py
@@ -174,7 +174,7 @@ class CastOpTest(test.TestCase):
self.assertAllEqual(np.isnan(self._cast(np.nan, np.float64, True)), True)
def _OpError(self, x, dtype, err):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError(err):
math_ops.cast(x, dtype).eval()
@@ -182,7 +182,7 @@ class CastOpTest(test.TestCase):
self._OpError(np.arange(0, 10), dtypes.string, "Cast.*int64.*string.*")
def testCastToTypeOfVariable(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = variables.Variable(5, dtype=dtypes.float32)
y = variables.Variable(True, dtype=dtypes.bool)
cast = math_ops.cast(y, x.dtype)
@@ -193,7 +193,7 @@ class CastOpTest(test.TestCase):
t = [dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128]
for src_t in t:
for dst_t in t:
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(1.0, src_t)
z = array_ops.identity(x)
y = math_ops.cast(z, dst_t)
@@ -209,7 +209,7 @@ class SparseTensorCastTest(test.TestCase):
shape = constant_op.constant([3], dtypes.int64)
st = sparse_tensor.SparseTensor(indices, values, shape)
st_cast = math_ops.cast(st, dtypes.float32)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(st_cast.indices.eval(), [[0], [1], [2]])
self.assertAllEqual(st_cast.values.eval(),
np.array([1, 2, 3], np.float32))
@@ -221,7 +221,7 @@ class SaturateCastTest(test.TestCase):
def testSaturate(self):
in_types = dtypes.float32,
out_types = dtypes.int8, dtypes.uint8, dtypes.int16, dtypes.float32
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for in_type in in_types:
for out_type in out_types:
lo, hi = in_type.min, in_type.max
diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py
index 680d0c97cc..27a674e223 100644
--- a/tensorflow/python/kernel_tests/check_ops_test.py
+++ b/tensorflow/python/kernel_tests/check_ops_test.py
@@ -33,6 +33,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
@@ -819,6 +820,18 @@ class EnsureShapeTest(test.TestCase):
with self.test_session() as sess:
sess.run(derived, feed_dict={placeholder: feed_val})
+ def testGradient(self):
+ placeholder = array_ops.placeholder(dtypes.float32)
+ derived = check_ops.ensure_shape(placeholder, (None, None))
+ gradient = gradients.gradients(derived, placeholder)
+
+ feed_val = [[4.0], [-1.0]]
+ with self.test_session() as sess:
+ gradient_values, = sess.run(gradient, feed_dict={placeholder: feed_val})
+
+ expected = [[1.0], [1.0]]
+ self.assertAllEqual(gradient_values, expected)
+
class EnsureShapeBenchmark(test.Benchmark):
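The testGradient case added above checks that ensure_shape is transparent to backpropagation: it validates the runtime shape but acts as the identity on values, so the incoming gradient (the default ones supplied by gradients.gradients) passes through unchanged. A minimal standalone sketch of the same property, reusing the internal module paths this test file already imports; TF 1.x graph mode assumed, and the class name is illustrative:

    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import array_ops
    from tensorflow.python.ops import check_ops
    from tensorflow.python.ops import gradients
    from tensorflow.python.platform import test


    class EnsureShapeGradSketch(test.TestCase):

      def testPassThroughGradient(self):
        x = array_ops.placeholder(dtypes.float32)
        y = check_ops.ensure_shape(x, (None, None))  # identity on values
        dy_dx, = gradients.gradients(y, x)
        with self.cached_session() as sess:
          grad = sess.run(dy_dx, feed_dict={x: [[4.0], [-1.0]]})
        self.assertAllEqual([[1.0], [1.0]], grad)    # gradient passes straight through


    if __name__ == "__main__":
      test.main()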
diff --git a/tensorflow/python/kernel_tests/checkpoint_ops_test.py b/tensorflow/python/kernel_tests/checkpoint_ops_test.py
index 7f147ba53a..51611b75af 100644
--- a/tensorflow/python/kernel_tests/checkpoint_ops_test.py
+++ b/tensorflow/python/kernel_tests/checkpoint_ops_test.py
@@ -57,7 +57,7 @@ class GenerateVocabRemappingTest(test.TestCase):
new_vocab_offset=0)
expected_remapping = range(0, 3)
expected_num_present = 3
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_remapping, remapping.eval())
self.assertAllEqual(expected_num_present, num_present.eval())
@@ -70,7 +70,7 @@ class GenerateVocabRemappingTest(test.TestCase):
new_vocab_offset=0)
expected_remapping = [2, 0, 1]
expected_num_present = 3
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_remapping, remapping.eval())
self.assertAllEqual(expected_num_present, num_present.eval())
@@ -83,7 +83,7 @@ class GenerateVocabRemappingTest(test.TestCase):
new_vocab_offset=1)
expected_remapping = [0]
expected_num_present = 1
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_remapping, remapping.eval())
self.assertAllEqual(expected_num_present, num_present.eval())
@@ -98,7 +98,7 @@ class GenerateVocabRemappingTest(test.TestCase):
old_vocab_size=2)
expected_remapping = [-1, 0, 1]
expected_num_present = 2
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_remapping, remapping.eval())
self.assertAllEqual(expected_num_present, num_present.eval())
@@ -122,7 +122,7 @@ class LoadAndRemapMatrixTest(test.TestCase):
self.old_tensor_name = 'some_scope/matrix'
save = saver.Saver([matrix])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables.global_variables_initializer().run()
self.bundle_file = os.path.join(test.get_temp_dir(), 'bundle_checkpoint')
save.save(sess, self.bundle_file)
@@ -140,7 +140,7 @@ class LoadAndRemapMatrixTest(test.TestCase):
initializing_values=[],
num_rows=2,
num_cols=self.old_num_cols)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(self.matrix_value[row_remapping],
remapped_matrix.eval())
@@ -155,7 +155,7 @@ class LoadAndRemapMatrixTest(test.TestCase):
initializing_values=[],
num_rows=len(row_remapping),
num_cols=len(col_remapping))
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(self.matrix_value[row_remapping][:, col_remapping],
remapped_matrix.eval())
@@ -170,7 +170,7 @@ class LoadAndRemapMatrixTest(test.TestCase):
initializing_values=[],
num_rows=len(row_remapping),
num_cols=len(col_remapping))
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(self.matrix_value[row_remapping][:, col_remapping],
remapped_matrix.eval())
@@ -189,7 +189,7 @@ class LoadAndRemapMatrixTest(test.TestCase):
expected_remapped_matrix = np.reshape(
[33, init_val, init_val, init_val, 1, init_val], [3, 2])
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(expected_remapped_matrix, remapped_matrix.eval())
def test_load_and_remap_all_missing_rows(self):
@@ -204,7 +204,7 @@ class LoadAndRemapMatrixTest(test.TestCase):
initializing_values=initializing_values,
num_rows=num_rows,
num_cols=self.old_num_cols)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(
np.reshape(initializing_values, (num_rows, self.old_num_cols)),
remapped_matrix.eval())
@@ -222,7 +222,7 @@ class LoadAndRemapMatrixTest(test.TestCase):
initializing_values=initializing_values,
num_rows=num_rows,
num_cols=num_cols)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(
np.reshape(initializing_values, (num_rows, num_cols)),
remapped_matrix.eval())
@@ -243,7 +243,7 @@ class LoadAndRemapMatrixTest(test.TestCase):
initializing_values=[],
num_rows=len(invalid_remapping),
num_cols=self.old_num_cols)
- with self.test_session(), self.assertRaises(errors.UnimplementedError):
+ with self.cached_session(), self.assertRaises(errors.UnimplementedError):
remapped_matrix.eval()
# Invalid column remapping.
@@ -255,7 +255,7 @@ class LoadAndRemapMatrixTest(test.TestCase):
initializing_values=[],
num_rows=self.old_num_rows,
num_cols=len(invalid_remapping))
- with self.test_session(), self.assertRaises(errors.UnimplementedError):
+ with self.cached_session(), self.assertRaises(errors.UnimplementedError):
remapped_matrix.eval()
def test_load_and_remap_incorrect_initializing_values(self):
@@ -272,7 +272,7 @@ class LoadAndRemapMatrixTest(test.TestCase):
initializing_values=[],
num_rows=3,
num_cols=2)
- with self.test_session(), self.assertRaises(errors.InvalidArgumentError):
+ with self.cached_session(), self.assertRaises(errors.InvalidArgumentError):
remapped_matrix.eval()
remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
@@ -284,7 +284,7 @@ class LoadAndRemapMatrixTest(test.TestCase):
initializing_values=[0] * 5,
num_rows=3,
num_cols=2)
- with self.test_session(), self.assertRaises(errors.InvalidArgumentError):
+ with self.cached_session(), self.assertRaises(errors.InvalidArgumentError):
remapped_matrix.eval()
@@ -306,7 +306,7 @@ class LoadAndRemapMatrixWithMaxRowsTest(test.TestCase):
initializer=constant_op.constant(np_value, dtype=dtypes.float32),
partitioner=partitioner)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ckpt_path = os.path.join(test.get_temp_dir(), 'temp_ckpt')
save = saver.Saver([matrix])
variables.global_variables_initializer().run()
diff --git a/tensorflow/python/kernel_tests/clip_ops_test.py b/tensorflow/python/kernel_tests/clip_ops_test.py
index de52a70cc0..bb7b645da2 100644
--- a/tensorflow/python/kernel_tests/clip_ops_test.py
+++ b/tensorflow/python/kernel_tests/clip_ops_test.py
@@ -39,7 +39,7 @@ class ClipTest(test.TestCase):
min_val = constant_op.constant([0.5, 0.5, 0.5, 0.5], dtype=dtypes.float32)
max_val = constant_op.constant([3.5, 3.5, 3.5, 3.5], dtype=dtypes.float32)
outputs_2 = clip_ops.clip_by_value(inputs, min_val, max_val)
- with self.test_session():
+ with self.cached_session():
error_1 = gradient_checker.compute_gradient_error(inputs, [4], outputs_1,
[4])
self.assertLess(error_1, 1e-4)
@@ -139,7 +139,7 @@ class ClipTest(test.TestCase):
def testClipByValueNonFinite(self):
# TODO(b/78016351): Enable test on GPU once the bug is fixed.
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant([float('NaN'), float('Inf'), -float('Inf')])
np_ans = [float('NaN'), 4.0, -4.0]
clip_value = 4.0
diff --git a/tensorflow/python/kernel_tests/concat_op_test.py b/tensorflow/python/kernel_tests/concat_op_test.py
index c22934ce47..0e59ce6972 100644
--- a/tensorflow/python/kernel_tests/concat_op_test.py
+++ b/tensorflow/python/kernel_tests/concat_op_test.py
@@ -383,7 +383,7 @@ class ConcatOpTest(test.TestCase):
np.random.random_sample(x_shape).astype(np.float64)
for x_shape in x_shapes
]
- with self.test_session():
+ with self.cached_session():
xs = [constant_op.constant(x_val) for x_val in x_vals]
output = array_ops.concat(xs, 0)
err = gradient_checker.compute_gradient_error(xs, x_shapes, output,
@@ -397,7 +397,7 @@ class ConcatOpTest(test.TestCase):
np.random.random_sample(x_shape).astype(np.float64)
for x_shape in x_shapes
]
- with self.test_session():
+ with self.cached_session():
xs = [constant_op.constant(x_val) for x_val in x_vals]
output = array_ops.concat(xs, 1)
err = gradient_checker.compute_gradient_error(xs, x_shapes, output,
@@ -411,7 +411,7 @@ class ConcatOpTest(test.TestCase):
np.random.random_sample(x_shape).astype(np.float64)
for x_shape in x_shapes
]
- with self.test_session():
+ with self.cached_session():
xs = [constant_op.constant(x_val) for x_val in x_vals]
x_concat = array_ops.concat(xs, 0)
output = array_ops.gather(x_concat, [1, 2, 0, 5])
@@ -426,7 +426,7 @@ class ConcatOpTest(test.TestCase):
np.random.random_sample(x_shape).astype(np.float64)
for x_shape in x_shapes
]
- with self.test_session():
+ with self.cached_session():
xs = [constant_op.constant(x_val) for x_val in x_vals]
x_concat = array_ops.concat(xs, 1)
output = array_ops.gather(x_concat, [1, 2, 0, 5])
@@ -441,7 +441,7 @@ class ConcatOpTest(test.TestCase):
np.random.random_sample(x_shape).astype(np.float64)
for x_shape in x_shapes
]
- with self.test_session():
+ with self.cached_session():
xs = [constant_op.constant(x_val) for x_val in x_vals]
x_concat = array_ops.concat(xs, 2)
output = array_ops.gather(x_concat, [1, 2, 0, 5])
@@ -452,7 +452,7 @@ class ConcatOpTest(test.TestCase):
def testIndexedSlicesConcatDim1Grad_UnknownInputDim(self):
x_shapes = [[20, 7, 3], [20, 3, 3], [20, 1, 3]]
output_shape = [4, 11, 3]
- with self.test_session():
+ with self.cached_session():
x_1 = array_ops.placeholder(dtypes.float64)
x_2 = array_ops.placeholder(dtypes.float64)
x_3 = array_ops.placeholder(dtypes.float64)
@@ -473,13 +473,13 @@ class ConcatOpTest(test.TestCase):
def testConcatTuple(self):
c1 = np.random.rand(4, 4)
c2 = np.random.rand(4, 4)
- with self.test_session():
+ with self.cached_session():
concat_list_t = array_ops.concat([c1, c2], 0)
concat_tuple_t = array_ops.concat((c1, c2), 0)
self.assertAllEqual(concat_list_t.eval(), concat_tuple_t.eval())
def testConcatNoScalars(self):
- with self.test_session():
+ with self.cached_session():
scalar = constant_op.constant(7)
dim = array_ops.placeholder(dtypes.int32)
with self.assertRaisesRegexp(
@@ -554,7 +554,7 @@ class ConcatOpTest(test.TestCase):
def _testGradientsForAxis(
self, inp_tensors, axis, output_shape, feed_dict=None):
- with self.test_session():
+ with self.cached_session():
c = array_ops.concat(inp_tensors, axis)
grad_inp = np.random.rand(*output_shape).astype("f")
grad_tensor = constant_op.constant(
@@ -566,7 +566,7 @@ class ConcatOpTest(test.TestCase):
def _testIndexedSlicesGradientsForAxis(
self, inp_tensors, axis, output_shape, gather_indexes, feed_dict=None):
- with self.test_session():
+ with self.cached_session():
c = array_ops.gather(
array_ops.concat(inp_tensors, axis), gather_indexes)
grad_inp = np.random.rand(*output_shape).astype("f")
@@ -631,7 +631,7 @@ class ConcatOffsetTest(test.TestCase):
self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]])
def testNotVector(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
cdim = constant_op.constant(1, dtypes.int32)
s0 = constant_op.constant([[2, 3, 5]], dtypes.int32)
s1 = constant_op.constant([[2, 7, 5]], dtypes.int32)
@@ -641,7 +641,7 @@ class ConcatOffsetTest(test.TestCase):
sess.run(off)
def testConcatDimOutOfRange(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
cdim = constant_op.constant(4, dtypes.int32)
s0 = constant_op.constant([2, 3, 5], dtypes.int32)
s1 = constant_op.constant([2, 7, 5], dtypes.int32)
@@ -651,7 +651,7 @@ class ConcatOffsetTest(test.TestCase):
sess.run(off)
def testDimMismatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
cdim = constant_op.constant(1, dtypes.int32)
s0 = constant_op.constant([2, 3, 5], dtypes.int32)
s1 = constant_op.constant([2, 7, 5, 10], dtypes.int32)
@@ -661,7 +661,7 @@ class ConcatOffsetTest(test.TestCase):
sess.run(off)
def testSizeMismatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
cdim = constant_op.constant(1, dtypes.int32)
s0 = constant_op.constant([2, 3, 5], dtypes.int32)
s1 = constant_op.constant([2, 7, 10], dtypes.int32)
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py
index 0dc3c53bc0..a1efecf28a 100644
--- a/tensorflow/python/kernel_tests/cond_v2_test.py
+++ b/tensorflow/python/kernel_tests/cond_v2_test.py
@@ -107,7 +107,7 @@ class CondV2Test(test.TestCase):
self._testCond(true_fn, false_fn, [y])
def testNoInputs(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
pred = array_ops.placeholder(dtypes.bool, name="pred")
def true_fn():
@@ -527,7 +527,7 @@ class CondV2Test(test.TestCase):
}), [5., 0.])
def testSecondDerivative(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
pred = array_ops.placeholder(dtypes.bool, name="pred")
x = constant_op.constant(3.0, name="x")
@@ -801,7 +801,6 @@ class CondV2ContainerTest(test.TestCase):
class CondV2ColocationGroupAndDeviceTest(test.TestCase):
def testColocateWithBeforeCond(self):
- self.skipTest("b/112414483")
with ops.Graph().as_default() as g:
with self.test_session(graph=g):
@@ -826,7 +825,6 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)
def testColocateWithInAndOutOfCond(self):
- self.skipTest("b/112414483")
with ops.Graph().as_default() as g:
with self.test_session(graph=g):
@@ -874,7 +872,6 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
self.assertTrue(len(run_metadata.partition_graphs) >= 2)
def testDeviceBeforeCond(self):
- self.skipTest("b/112166045")
with ops.Graph().as_default() as g:
with self.test_session(graph=g):
def fn():
@@ -895,11 +892,13 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
def testDeviceInAndOutOfCond(self):
with ops.Graph().as_default() as g:
- with self.test_session(graph=g):
+ with self.test_session(
+ graph=g, config=config_pb2.ConfigProto(device_count={"CPU": 2})):
+
def fn2():
- with ops.device("/device:GPU:0"):
+ with ops.device("/device:CPU:1"):
c = constant_op.constant(3.0)
- self.assertEqual("/device:GPU:0", c.op.device)
+ self.assertEqual("/device:CPU:1", c.op.device)
return c
with ops.device("/device:CPU:0"):
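The rewritten testDeviceInAndOutOfCond above drops the GPU requirement: the session is asked for two CPU devices via config_pb2.ConfigProto(device_count={"CPU": 2}) and the inner constant is placed on /device:CPU:1. A standalone sketch of that configuration trick, assuming the TF 1.x Session API; the constant is just a stand-in workload:

    import tensorflow as tf
    from tensorflow.core.protobuf import config_pb2

    config = config_pb2.ConfigProto(device_count={"CPU": 2})
    with tf.Graph().as_default():
      with tf.device("/device:CPU:1"):   # second in-process CPU device
        c = tf.constant(3.0)
      with tf.Session(config=config) as sess:
        print(sess.run(c))               # 3.0, placed on CPU:1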
diff --git a/tensorflow/python/kernel_tests/conditional_accumulator_test.py b/tensorflow/python/kernel_tests/conditional_accumulator_test.py
index 86802664d1..262352a9af 100644
--- a/tensorflow/python/kernel_tests/conditional_accumulator_test.py
+++ b/tensorflow/python/kernel_tests/conditional_accumulator_test.py
@@ -80,26 +80,26 @@ class ConditionalAccumulatorTest(test.TestCase):
""", q.accumulator_ref.op.node_def)
def testAccumulatorSizeEmpty(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.ConditionalAccumulator(dtypes_lib.float32, name="Q")
self.assertEqual(q.num_accumulated().eval(), 0)
def testAccumulatorSetGlobalStep(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
set_global_step_op = q.set_global_step(1)
set_global_step_op.run()
def testAccumulatorApplyGradFloat32(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
accum_op = q.apply_grad((10.0,))
accum_op.run()
def testDtypes(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dtypes = [dtypes_lib.float16, dtypes_lib.float32, dtypes_lib.float64]
for i in range(len(dtypes)):
@@ -116,7 +116,7 @@ class ConditionalAccumulatorTest(test.TestCase):
self.assertEqual(sum(elems) / len(elems), result)
def testAccumulatorMultipleAccumulators(self):
- with self.test_session():
+ with self.cached_session():
q_f32_0 = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
q_f32_1 = data_flow_ops.ConditionalAccumulator(
@@ -135,7 +135,7 @@ class ConditionalAccumulatorTest(test.TestCase):
self.assertEqual(result, i + 10.0)
def testAccumulatorApplyAndTakeGradWithShape(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=(3, 2))
elems = [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
@@ -166,7 +166,7 @@ class ConditionalAccumulatorTest(test.TestCase):
q.apply_grad([[1.0], [2.0], [3.0]])
def testAccumulatorDynamicShape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=None)
@@ -191,7 +191,7 @@ class ConditionalAccumulatorTest(test.TestCase):
self.assertTrue(is_all_equal)
def testAccumulatorWrongDynamicShape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=None)
@@ -209,7 +209,7 @@ class ConditionalAccumulatorTest(test.TestCase):
sess.run(accum_op, feed_dict={x: [[1.0], [2.0], [3.0]]})
def testAccumulatorSizeAfterApplyGrad(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
accum_op = q.apply_grad((10.0,))
@@ -220,7 +220,7 @@ class ConditionalAccumulatorTest(test.TestCase):
self.assertEqual(q.num_accumulated().eval(), 2)
def testAccumulatorSizeAfterApplyGradAndTakeGrad(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
accum_op = q.apply_grad((10.0,))
@@ -248,7 +248,7 @@ class ConditionalAccumulatorTest(test.TestCase):
self.assertEqual(q.num_accumulated().eval(), 0)
def testAccumulatorTakeGradMean(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
elems = [10.0, 20.0]
@@ -307,7 +307,7 @@ class ConditionalAccumulatorTest(test.TestCase):
reduction_type="Invalid")
def testAccumulatorInvalidTakeGrad(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
elems = [10.0, 20.0]
@@ -322,7 +322,7 @@ class ConditionalAccumulatorTest(test.TestCase):
takeg_t.eval()
def testAccumulatorRepeatedTakeGradMean(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
@@ -379,7 +379,7 @@ class ConditionalAccumulatorTest(test.TestCase):
self.assertEqual(elems_sum, val)
def testAccumulatorIncrementGlobalStep(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
@@ -395,7 +395,7 @@ class ConditionalAccumulatorTest(test.TestCase):
inc_global_step.eval()
def testAccumulatorSetGlobalStepPreventsAccumulation(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
@@ -416,7 +416,7 @@ class ConditionalAccumulatorTest(test.TestCase):
if x >= ls), val)
def testParallelApplyGrad(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
@@ -441,7 +441,7 @@ class ConditionalAccumulatorTest(test.TestCase):
self.assertEqual(val, sum(elems) / len(elems))
def testParallelTakeGrad(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
elems = [e for e in range(10)]
@@ -473,7 +473,7 @@ class ConditionalAccumulatorTest(test.TestCase):
self.assertItemsEqual(elems, results)
def testAccumulatorApplyAndBlockingTake(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
@@ -506,7 +506,7 @@ class ConditionalAccumulatorTest(test.TestCase):
sess.run(takeg_op)
def testAccumulatorCancel(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
takeg_t = q.take_grad(1)
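
The hunks in this file, and in most of the test files that follow, replace self.test_session() with self.cached_session(), which reuses one session per test instead of constructing a fresh one in every with-block while still installing it as the default session (so .run() and .eval() keep working). A minimal sketch of a migrated test, assuming tf.test.TestCase and the public tf.ConditionalAccumulator API from TF 1.x; illustration only, not part of the commit.

# Sketch (not part of the diff): a test written against cached_session().
import tensorflow as tf

class CachedSessionExampleTest(tf.test.TestCase):

  def testAccumulateTwice(self):
    with self.cached_session():  # session is cached for this test and set as default
      acc = tf.ConditionalAccumulator(tf.float32, shape=[1], name="Q")
      acc.apply_grad((10.0,)).run()
      acc.apply_grad((20.0,)).run()
      self.assertEqual(2, acc.num_accumulated().eval())

if __name__ == "__main__":
  tf.test.main()
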
diff --git a/tensorflow/python/kernel_tests/confusion_matrix_test.py b/tensorflow/python/kernel_tests/confusion_matrix_test.py
index 93f5323c41..bc24345261 100644
--- a/tensorflow/python/kernel_tests/confusion_matrix_test.py
+++ b/tensorflow/python/kernel_tests/confusion_matrix_test.py
@@ -37,7 +37,7 @@ class ConfusionMatrixTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testExample(self):
"""This is a test of the example provided in pydoc."""
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([
[0, 0, 0, 0, 0],
[0, 0, 1, 0, 0],
@@ -49,7 +49,7 @@ class ConfusionMatrixTest(test.TestCase):
def _testConfMatrix(self, labels, predictions, truth, weights=None,
num_classes=None):
- with self.test_session():
+ with self.cached_session():
dtype = predictions.dtype
ans = confusion_matrix.confusion_matrix(
labels, predictions, dtype=dtype, weights=weights,
@@ -78,7 +78,7 @@ class ConfusionMatrixTest(test.TestCase):
self._testBasic(dtype=np.int64)
def _testConfMatrixOnTensors(self, tf_dtype, np_dtype):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
m_neg = array_ops.placeholder(dtype=dtypes.float32)
m_pos = array_ops.placeholder(dtype=dtypes.float32)
s = array_ops.placeholder(dtype=dtypes.float32)
@@ -229,7 +229,7 @@ class ConfusionMatrixTest(test.TestCase):
def testOutputIsInt32(self):
labels = np.arange(2)
predictions = np.arange(2)
- with self.test_session():
+ with self.cached_session():
cm = confusion_matrix.confusion_matrix(
labels, predictions, dtype=dtypes.int32)
tf_cm = cm.eval()
@@ -238,7 +238,7 @@ class ConfusionMatrixTest(test.TestCase):
def testOutputIsInt64(self):
labels = np.arange(2)
predictions = np.arange(2)
- with self.test_session():
+ with self.cached_session():
cm = confusion_matrix.confusion_matrix(
labels, predictions, dtype=dtypes.int64)
tf_cm = cm.eval()
@@ -260,7 +260,7 @@ class RemoveSqueezableDimensionsTest(test.TestCase):
confusion_matrix.remove_squeezable_dimensions(
labels_placeholder, predictions_placeholder))
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(label_values, static_labels.eval())
self.assertAllEqual(prediction_values, static_predictions.eval())
feed_dict = {
@@ -285,7 +285,7 @@ class RemoveSqueezableDimensionsTest(test.TestCase):
confusion_matrix.remove_squeezable_dimensions(
labels_placeholder, predictions_placeholder))
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(label_values, static_labels.eval())
self.assertAllEqual(prediction_values, static_predictions.eval())
feed_dict = {
@@ -310,7 +310,7 @@ class RemoveSqueezableDimensionsTest(test.TestCase):
confusion_matrix.remove_squeezable_dimensions(
labels_placeholder, predictions_placeholder, expected_rank_diff=0))
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(label_values, static_labels.eval())
self.assertAllEqual(prediction_values, static_predictions.eval())
feed_dict = {
@@ -336,7 +336,7 @@ class RemoveSqueezableDimensionsTest(test.TestCase):
labels_placeholder, predictions_placeholder))
expected_label_values = np.reshape(label_values, newshape=(2, 3))
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_label_values, static_labels.eval())
self.assertAllEqual(prediction_values, static_predictions.eval())
feed_dict = {
@@ -362,7 +362,7 @@ class RemoveSqueezableDimensionsTest(test.TestCase):
labels_placeholder, predictions_placeholder, expected_rank_diff=1))
expected_label_values = np.reshape(label_values, newshape=(2, 3))
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_label_values, static_labels.eval())
self.assertAllEqual(prediction_values, static_predictions.eval())
feed_dict = {
@@ -388,7 +388,7 @@ class RemoveSqueezableDimensionsTest(test.TestCase):
labels_placeholder, predictions_placeholder))
expected_prediction_values = np.reshape(prediction_values, newshape=(2, 3))
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(label_values, static_labels.eval())
self.assertAllEqual(expected_prediction_values, static_predictions.eval())
feed_dict = {
@@ -415,7 +415,7 @@ class RemoveSqueezableDimensionsTest(test.TestCase):
labels_placeholder, predictions_placeholder, expected_rank_diff=-1))
expected_prediction_values = np.reshape(prediction_values, newshape=(2, 3))
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(label_values, static_labels.eval())
self.assertAllEqual(expected_prediction_values, static_predictions.eval())
feed_dict = {
@@ -441,7 +441,7 @@ class RemoveSqueezableDimensionsTest(test.TestCase):
confusion_matrix.remove_squeezable_dimensions(
labels_placeholder, predictions_placeholder))
- with self.test_session():
+ with self.cached_session():
feed_dict = {
labels_placeholder: label_values,
predictions_placeholder: prediction_values
@@ -466,7 +466,7 @@ class RemoveSqueezableDimensionsTest(test.TestCase):
confusion_matrix.remove_squeezable_dimensions(
labels_placeholder, predictions_placeholder))
- with self.test_session():
+ with self.cached_session():
feed_dict = {
labels_placeholder: label_values,
predictions_placeholder: prediction_values
diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py
index 107ee37fab..d1e4e5477f 100644
--- a/tensorflow/python/kernel_tests/constant_op_test.py
+++ b/tensorflow/python/kernel_tests/constant_op_test.py
@@ -162,18 +162,18 @@ class ConstantTest(test.TestCase):
logging_const_op.run()
def testStringWithNulls(self):
- with self.test_session():
+ with self.cached_session():
val = ops.convert_to_tensor(b"\0\0\0\0").eval()
self.assertEqual(len(val), 4)
self.assertEqual(val, b"\0\0\0\0")
- with self.test_session():
+ with self.cached_session():
val = ops.convert_to_tensor(b"xx\0xx").eval()
self.assertEqual(len(val), 5)
self.assertAllEqual(val, b"xx\0xx")
nested = [[b"\0\0\0\0", b"xx\0xx"], [b"\0_\0_\0_\0", b"\0"]]
- with self.test_session():
+ with self.cached_session():
val = ops.convert_to_tensor(nested).eval()
# NOTE(mrry): Do not use assertAllEqual, because it converts nested to a
# numpy array, which loses the null terminators.
@@ -279,7 +279,7 @@ class AsTensorTest(test.TestCase):
self.assertTrue(isinstance(x, ops.Tensor))
def testAsTensorForShapeInput(self):
- with self.test_session():
+ with self.cached_session():
x = ops.convert_to_tensor(tensor_shape.TensorShape([]))
self.assertEqual(dtypes_lib.int32, x.dtype)
self.assertAllEqual([], x.eval())
@@ -331,7 +331,7 @@ class AsTensorTest(test.TestCase):
tensor_shape.TensorShape([1, 2, 3]), dtype=dtypes_lib.float32)
def testAsTensorForDimensionInput(self):
- with self.test_session():
+ with self.cached_session():
x = ops.convert_to_tensor(tensor_shape.TensorShape([1, 2, 3])[1])
self.assertEqual(dtypes_lib.int32, x.dtype)
self.assertAllEqual(2, x.eval())
@@ -367,7 +367,7 @@ class IdentityOpTest(test.TestCase):
class ZerosTest(test.TestCase):
def _Zeros(self, shape):
- with self.test_session():
+ with self.cached_session():
ret = array_ops.zeros(shape)
self.assertEqual(shape, ret.get_shape())
return ret.eval()
@@ -379,13 +379,13 @@ class ZerosTest(test.TestCase):
def testScalar(self):
self.assertEqual(0, self._Zeros([]))
self.assertEqual(0, self._Zeros(()))
- with self.test_session():
+ with self.cached_session():
scalar = array_ops.zeros(constant_op.constant([], dtype=dtypes_lib.int32))
self.assertEqual(0, scalar.eval())
def testDynamicSizes(self):
np_ans = np.array([[0] * 3] * 2)
- with self.test_session():
+ with self.cached_session():
# Creates a tensor of 2 x 3.
d = array_ops.fill([2, 3], 12., name="fill")
# Constructs a tensor of zeros of the same dimensions as "d".
@@ -396,7 +396,7 @@ class ZerosTest(test.TestCase):
self.assertShapeEqual(np_ans, z)
def testDtype(self):
- with self.test_session():
+ with self.cached_session():
d = array_ops.fill([2, 3], 12., name="fill")
self.assertEqual(d.get_shape(), [2, 3])
# Test default type for both constant size and dynamic size
@@ -489,7 +489,7 @@ class ZerosLikeTest(test.TestCase):
def testZerosLikeDtype(self):
# Make sure zeros_like works even for dtypes that cannot be cast between
- with self.test_session():
+ with self.cached_session():
shape = (3, 5)
dtypes = np.float32, np.complex64
for in_type in dtypes:
@@ -533,7 +533,7 @@ class ZerosLikeTest(test.TestCase):
class OnesTest(test.TestCase):
def _Ones(self, shape):
- with self.test_session():
+ with self.cached_session():
ret = array_ops.ones(shape)
self.assertEqual(shape, ret.get_shape())
return ret.eval()
@@ -544,13 +544,13 @@ class OnesTest(test.TestCase):
def testScalar(self):
self.assertEqual(1, self._Ones([]))
self.assertEqual(1, self._Ones(()))
- with self.test_session():
+ with self.cached_session():
scalar = array_ops.ones(constant_op.constant([], dtype=dtypes_lib.int32))
self.assertEqual(1, scalar.eval())
def testDynamicSizes(self):
np_ans = np.array([[1] * 3] * 2)
- with self.test_session():
+ with self.cached_session():
# Creates a tensor of 2 x 3.
d = array_ops.fill([2, 3], 12., name="fill")
# Constructs a tensor of ones of the same dimensions as "d".
@@ -561,7 +561,7 @@ class OnesTest(test.TestCase):
self.assertShapeEqual(np_ans, z)
def testAutoPack(self):
- with self.test_session():
+ with self.cached_session():
h = array_ops.placeholder(dtypes_lib.int32, shape=[])
w = array_ops.placeholder(dtypes_lib.int32, shape=[])
z = array_ops.ones([h, w])
@@ -569,7 +569,7 @@ class OnesTest(test.TestCase):
self.assertAllEqual(out, np.array([[1] * 16] * 4))
def testDtype(self):
- with self.test_session():
+ with self.cached_session():
d = array_ops.fill([2, 3], 12., name="fill")
self.assertEqual(d.get_shape(), [2, 3])
# Test default type for both constant size and dynamic size
@@ -606,7 +606,7 @@ class OnesLikeTest(test.TestCase):
dtypes_lib.complex128
]:
numpy_dtype = dtype.as_numpy_dtype
- with self.test_session():
+ with self.cached_session():
# Creates a tensor of non-zero values with shape 2 x 3.
d = constant_op.constant(
np.ones(
@@ -672,7 +672,7 @@ class FillTest(test.TestCase):
self.assertAllEqual(np_ans, tf_ans)
def testFillNegative(self):
- with self.test_session():
+ with self.cached_session():
for shape in (-1,), (2, -1), (-1, 2), (-2), (-3):
with self.assertRaises(ValueError):
array_ops.fill(shape, 7)
@@ -703,7 +703,7 @@ class FillTest(test.TestCase):
self.assertEqual([None, 17], f.get_shape().as_list())
def testGradient(self):
- with self.test_session():
+ with self.cached_session():
in_v = constant_op.constant(5.0)
out_shape = [3, 2]
out_filled = array_ops.fill(out_shape, in_v)
@@ -715,7 +715,7 @@ class FillTest(test.TestCase):
class PlaceholderTest(test.TestCase):
def testDtype(self):
- with self.test_session():
+ with self.cached_session():
p = array_ops.placeholder(dtypes_lib.float32, shape=(10, 10), name="p")
p_identity = array_ops.identity(p)
feed_array = np.random.rand(10, 10)
@@ -727,7 +727,7 @@ class PlaceholderTest(test.TestCase):
p_identity.eval()
def testShape(self):
- with self.test_session():
+ with self.cached_session():
p = array_ops.placeholder(dtypes_lib.float32, shape=(10, 10), name="p")
p_identity = array_ops.identity(p)
feed_array = np.random.rand(10, 10)
@@ -744,7 +744,7 @@ class PlaceholderTest(test.TestCase):
p_identity.eval(feed_dict={p: feed_array[:5, :5]})
def testUnknownShape(self):
- with self.test_session():
+ with self.cached_session():
p = array_ops.placeholder(dtypes_lib.float32, shape=None, name="p")
p_identity = array_ops.identity(p)
# can feed anything
@@ -756,13 +756,13 @@ class PlaceholderTest(test.TestCase):
p_identity.eval(feed_dict={p: feed_array}), feed_array)
def testScalarShape(self):
- with self.test_session():
+ with self.cached_session():
p = array_ops.placeholder(dtypes_lib.float32, shape=[], name="p")
p_identity = array_ops.identity(p)
self.assertAllClose(p_identity.eval(feed_dict={p: 5}), 5)
def testPartialShape(self):
- with self.test_session():
+ with self.cached_session():
p = array_ops.placeholder(dtypes_lib.float32, shape=[None, 3], name="p")
p_identity = array_ops.identity(p)
feed_array = np.random.rand(10, 3)
@@ -774,7 +774,7 @@ class PlaceholderTest(test.TestCase):
p_identity.eval(feed_dict={p: feed_array[:5, :2]})
def testPartialShapeWhenNotFed(self):
- with self.test_session():
+ with self.cached_session():
p = array_ops.placeholder(dtypes_lib.float32, shape=[None, 3], name="p")
p_identity = array_ops.identity(p)
@@ -784,7 +784,7 @@ class PlaceholderTest(test.TestCase):
p_identity.eval()
def testControlDependency(self):
- with self.test_session():
+ with self.cached_session():
p = array_ops.placeholder(dtypes_lib.int32, shape=[], name="p")
with ops.control_dependencies([p]):
c = constant_op.constant(5, dtypes_lib.int32)
@@ -872,7 +872,7 @@ versions {
"""
gdef = graph_pb2.GraphDef()
text_format.Merge(graph, gdef)
- with self.test_session():
+ with self.cached_session():
p, ret = importer.import_graph_def(
gdef, return_elements=["Placeholder:0", "add:0"])
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index eac97af4ed..ebeabcfe1a 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -129,7 +129,7 @@ def isum(s, maximum_iterations=None):
class ControlFlowTest(test.TestCase):
def testRefIdentity(self):
- with self.test_session():
+ with self.cached_session():
v = variables.Variable(7)
v = control_flow_ops._Identity(v)
@@ -141,7 +141,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(9, v2.eval())
def testRefEnter(self):
- with self.test_session():
+ with self.cached_session():
v = variables.Variable(7)
enter_v = control_flow_ops._Enter(v, "foo_1", is_constant=True)
@@ -154,7 +154,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(9, v3.eval())
def testRefSwitch(self):
- with self.test_session():
+ with self.cached_session():
v = variables.Variable(7)
p = constant_op.constant(True)
@@ -164,7 +164,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(9, v2.eval())
def testEnterMulExit(self):
- with self.test_session():
+ with self.cached_session():
data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
enter_data = gen_control_flow_ops.enter(data, "foo_1", False)
five = constant_op.constant(5)
@@ -176,7 +176,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(np.array([x * 5 for x in [1, 2, 3, 4, 5, 6]]), result)
def testEnterShapePropagation(self):
- with self.test_session():
+ with self.cached_session():
v = variables.Variable([0.0, 0.0], dtype=dtypes.float32)
# If is_constant=True, the shape information should be propagated.
@@ -190,7 +190,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(enter_v_non_constant.shape, None)
def testSwitchMergeIndexedSlices(self):
- with self.test_session():
+ with self.cached_session():
values = constant_op.constant([1, 2, 3, 4, 5, 6])
indices = constant_op.constant([0, 2, 4, 6, 8, 10])
data = ops.IndexedSlices(values, indices)
@@ -204,7 +204,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(np.arange(0, 12, 2), ind)
def testSwitchDeadBranch(self):
- with self.test_session():
+ with self.cached_session():
data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
ports = ops.convert_to_tensor(True, name="ports")
switch_op = control_flow_ops.switch(data, ports)
@@ -216,7 +216,7 @@ class ControlFlowTest(test.TestCase):
dead_branch.eval()
def testSwitchMergeLess(self):
- with self.test_session():
+ with self.cached_session():
data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
zero = ops.convert_to_tensor(0)
one = ops.convert_to_tensor(1)
@@ -228,7 +228,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(np.arange(1, 7), result)
def testSwitchMergeAddIdentity(self):
- with self.test_session():
+ with self.cached_session():
data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
ports = ops.convert_to_tensor(False, name="ports")
switch_op = control_flow_ops.switch(data, ports)
@@ -241,7 +241,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(np.array([x + 1 for x in [1, 2, 3, 4, 5, 6]]), result)
def testSwitchMergeAddMul(self):
- with self.test_session():
+ with self.cached_session():
data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
ports = ops.convert_to_tensor(True, name="ports")
switch_op = control_flow_ops.switch(data, ports)
@@ -255,7 +255,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(np.array([x * 5 for x in [1, 2, 3, 4, 5, 6]]), result)
def testLoop_false(self):
- with self.test_session():
+ with self.cached_session():
false = ops.convert_to_tensor(False)
n = constant_op.constant(10)
@@ -272,7 +272,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, result)
def testLoop_1(self):
- with self.test_session():
+ with self.cached_session():
zero = constant_op.constant(0)
one = constant_op.constant(1)
n = constant_op.constant(10)
@@ -298,7 +298,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, result)
def testLoop_2(self):
- with self.test_session():
+ with self.cached_session():
zero = constant_op.constant(0)
one = constant_op.constant(1)
n = constant_op.constant(10)
@@ -324,7 +324,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, result)
def testDifferentFrame(self):
- with self.test_session():
+ with self.cached_session():
data = array_ops.placeholder(dtypes.float32, shape=[])
enter_1 = gen_control_flow_ops.enter(data, "foo_1", False)
enter_2 = gen_control_flow_ops.enter(data, "foo_2", False)
@@ -333,7 +333,7 @@ class ControlFlowTest(test.TestCase):
res.eval(feed_dict={data: 1.0})
def testCondBool(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113296297")
values = constant_op.constant(10)
@@ -352,7 +352,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual([None], grad)
def testFetchable(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.placeholder(dtypes.float32)
control_flow_ops.cond(
constant_op.constant(True), lambda: x + 2, lambda: x + 0)
@@ -367,7 +367,7 @@ class ControlFlowTest(test.TestCase):
sess.run(t, feed_dict={x: 3})
def testFeedable(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
c = constant_op.constant(2)
i0 = constant_op.constant(0)
r = control_flow_ops.while_loop(lambda i: i < 1000,
@@ -384,10 +384,10 @@ class ControlFlowTest(test.TestCase):
sess.run(r, feed_dict={t: 3})
def testCondIndexedSlices(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113296180")
- with self.test_session():
+ with self.cached_session():
values = constant_op.constant(10)
indices = constant_op.constant(0)
x = ops.IndexedSlices(values, indices)
@@ -402,10 +402,10 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(0, ind)
def testCondSparseTensor(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113296161 (SparseTensors)")
- with self.test_session():
+ with self.cached_session():
values = constant_op.constant([2.0, 4.0], name="values")
indices = constant_op.constant(
[[0], [3]], dtype=dtypes.int64, name="indices")
@@ -422,10 +422,10 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(r.values.get_shape(), (2,))
def testCondResource(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/111124878 (don't return tuple)")
- with self.test_session():
+ with self.cached_session():
rv = resource_variable_ops.ResourceVariable(True)
variables.global_variables_initializer().run()
t = ops.convert_to_tensor(1.0)
@@ -438,10 +438,10 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(1.0, control_flow_ops.cond(rv, case, lambda: t).eval())
def testCondIndexedSlicesDifferentTypes(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113293074")
- with self.test_session():
+ with self.cached_session():
values = constant_op.constant(10)
i_32 = ops.convert_to_tensor(0, name="one", dtype=dtypes.int32)
i_64 = ops.convert_to_tensor(0, name="one", dtype=dtypes.int64)
@@ -484,17 +484,17 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(11, result)
def testCond_1(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/111124878 (don't return tuple)")
self._testCond_1(use_gpu=False)
self._testCond_1(use_gpu=True)
def testCond_2(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/111124878 (don't return tuple)")
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(10)
r = control_flow_ops.cond(
math_ops.less(1, 0), lambda: math_ops.add(x, 1),
@@ -503,10 +503,10 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(9, result)
def testCond_3(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/111124878 (don't return tuple)")
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(10)
pred = math_ops.less(1, 2)
fn1 = lambda: math_ops.add(x, 1)
@@ -518,10 +518,10 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(12, result)
def testCond_4(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113324949 (ref vars)")
- with self.test_session():
+ with self.cached_session():
v1 = variables.Variable(7)
v2 = variables.Variable(7)
v3 = variables.Variable(7)
@@ -542,7 +542,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(7, v3.eval())
def testCond_5(self):
- with self.test_session():
+ with self.cached_session():
alive = constant_op.constant(True, name="alive")
count = constant_op.constant(0, name="count")
@@ -556,10 +556,10 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(4, count.eval())
def testCond_6(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/111124878 (don't return tuple)")
- with self.test_session():
+ with self.cached_session():
v1 = variables.Variable([7])
age = constant_op.constant(3)
@@ -573,7 +573,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(np.array([7]), result)
def testCond_7(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = constant_op.constant(10)
y = constant_op.constant(200)
pred = math_ops.less(1, 2)
@@ -583,10 +583,10 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual([11, 12], sess.run(r))
def testCondRef(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/111124878 (don't return tuple)")
- with self.test_session():
+ with self.cached_session():
x = gen_state_ops.variable(
shape=[1],
dtype=dtypes.float32,
@@ -599,10 +599,10 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual([2.0], r.eval())
def testCondWithControl(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/79881896")
- with self.test_session() as sess:
+ with self.cached_session():
control_holder = array_ops.placeholder(dtypes.float32, shape=())
a = constant_op.constant(3)
@@ -617,7 +617,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(5, r.eval())
def testUninitializedRefIdentity(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v = gen_state_ops.variable(
shape=[1],
dtype=dtypes.float32,
@@ -641,7 +641,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual([1.0], sess.run(merged_op.output))
def testCondSwitchIdentity(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/112477618 (Operation returned from cond)")
# Make sure the recv identity is not removed by optimization.
@@ -658,7 +658,7 @@ class ControlFlowTest(test.TestCase):
sess.run(r)
def testCondRecvIdentity(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/112477618 (Operation returned from cond)")
# Make sure the switch identity is not removed by optimization.
@@ -677,7 +677,7 @@ class ControlFlowTest(test.TestCase):
sess.run(r)
def testCondGrad_1(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113346829 (gpu failure)")
graph = ops.Graph()
@@ -689,11 +689,11 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.cond(pred, fn1, fn2)
grad = gradients_impl.gradients(r, [x])[0]
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(1.0, grad.eval())
def testCondGrad_2(self):
- with self.test_session():
+ with self.cached_session():
c = array_ops.placeholder(dtypes.int32, shape=[])
x = constant_op.constant(10.0)
pred = math_ops.less(c, 2)
@@ -706,10 +706,10 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(3.0, grad.eval(feed_dict={c: 3}))
def testCondGrad_3(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/110550782 (gradient w.r.t external variable)")
- with self.test_session():
+ with self.cached_session():
c = array_ops.placeholder(dtypes.int32, shape=[])
ox = constant_op.constant(10.0)
pred = math_ops.less(c, 2)
@@ -726,7 +726,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(30.0, r.eval(feed_dict={c: 3}))
def testNestedCond_Simple(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(0., name="X")
y = control_flow_ops.cond(
constant_op.constant(True), lambda: x,
@@ -741,10 +741,10 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(1.0, result.eval())
def testCondGrad_Gather(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113327884")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v1 = variables.Variable([1.0, 42.0])
c = array_ops.placeholder(dtypes.int32, shape=[])
pred = math_ops.less(c, 2)
@@ -768,7 +768,7 @@ class ControlFlowTest(test.TestCase):
# Microbenchmark: 256,000 iterations/s.
def testWhile_1(self):
- with self.test_session():
+ with self.cached_session():
n = constant_op.constant(0)
c = lambda x: math_ops.less(x, 10000)
b = lambda x: math_ops.add(x, 1)
@@ -776,7 +776,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(10000, r.eval())
def testWhileExternalControlDependencies(self):
- with self.test_session():
+ with self.cached_session():
v = variables.Variable(0.0)
v.initializer.run()
increment = v.assign_add(1.0)
@@ -791,7 +791,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(v.eval(), 1.0)
def testWhileExternalControlDependenciesNoInput(self):
- with self.test_session():
+ with self.cached_session():
v = variables.Variable(0.0)
v.initializer.run()
increment = v.assign_add(1.0)
@@ -806,7 +806,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(v.eval(), 1.0)
def testWhileWithRefs_1(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = variables.Variable(0)._ref() # pylint: disable=protected-access
i = constant_op.constant(0)
c = lambda i, x: math_ops.less(i, 100)
@@ -830,19 +830,19 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(0, value_x)
def testWhile_2(self):
- with self.test_session():
+ with self.cached_session():
s = constant_op.constant(0)
r = isum(s)
self.assertAllEqual(45, r.eval())
def testWhileWithMaximumIterations(self):
- with self.test_session():
+ with self.cached_session():
s = constant_op.constant([1, 2, 3, 4, 5])
r = isum(s, maximum_iterations=3)
self.assertAllEqual([1 + 3, 2 + 3, 3 + 3, 4 + 3, 5 + 3], r.eval())
def testWhileWithMaximumIterationsAndSingleArgument(self):
- with self.test_session():
+ with self.cached_session():
r = control_flow_ops.while_loop(
lambda i: i < 3, lambda i: i + 1, [0], maximum_iterations=1)
self.assertEqual(1, r.eval())
@@ -916,7 +916,7 @@ class ControlFlowTest(test.TestCase):
_ = gradients_impl.gradients(loop_with_maxiter, v)
def testInvalidMaximumIterationsFromSiblingContextWhileLoopInXLAContext(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113294340 (enable while_v2)")
v = constant_op.constant(1.0)
@@ -1019,7 +1019,7 @@ class ControlFlowTest(test.TestCase):
# Have more than 10 parallel iterations and hence exercise k-bound
# most of the time.
def testWhile_3(self):
- with self.test_session():
+ with self.cached_session():
def compute(i, m, c, o):
m, c = [math_ops.add(m, 1), math_ops.add(c, 1)]
@@ -1039,7 +1039,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10100, result)
def testWhile_4(self):
- with self.test_session():
+ with self.cached_session():
def compute(i, m, c, o):
m, c = [array_ops.gather(x, i), array_ops.gather(x, i)]
@@ -1060,7 +1060,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(42, result)
def testWhile_5(self):
- with self.test_session():
+ with self.cached_session():
def compute(i, c, o):
c = array_ops.strided_slice(x, array_ops.expand_dims(i, 0),
@@ -1088,7 +1088,7 @@ class ControlFlowTest(test.TestCase):
trace_level=config_pb2.RunOptions.FULL_TRACE)
run_metadata = config_pb2.RunMetadata()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with ops.device("/cpu:0"):
c = constant_op.constant(2)
i0 = constant_op.constant(0)
@@ -1134,7 +1134,7 @@ class ControlFlowTest(test.TestCase):
self._testWhile_Gpu_1(use_gpu=True)
def testWhileShape(self):
- with self.test_session():
+ with self.cached_session():
i = constant_op.constant(0)
m = array_ops.ones([2, 2])
c = lambda i, j: math_ops.less(i, 2)
@@ -1151,7 +1151,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(np.ones((8, 8)), r.eval())
def testWhileWithNonTensorInput_Scalar(self):
- with self.test_session():
+ with self.cached_session():
n = 0
c = lambda x: x < 10000
b = lambda x: x + 1
@@ -1159,7 +1159,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(10000, r.eval())
def testWhileWithNonTensorInput_Vector(self):
- with self.test_session():
+ with self.cached_session():
n = np.array([0]) # Note, [0] would not work here; that is a list
c = lambda x: x[0] < 10000
b = lambda x: array_ops.stack([x[0] + 1])
@@ -1167,7 +1167,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual([10000], r.eval())
def testWhileShapeInference(self):
- with self.test_session():
+ with self.cached_session():
i = constant_op.constant(0)
m = array_ops.ones([2, 2])
c = lambda i, j: math_ops.less(i, 2)
@@ -1192,7 +1192,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, b, [i, m])
def testWhileShapeInferenceSparseTensor(self):
- with self.test_session():
+ with self.cached_session():
values = constant_op.constant([2.0, 4.0], name="values")
indices = constant_op.constant(
[[0], [3]], dtype=dtypes.int64, name="indices")
@@ -1223,7 +1223,7 @@ class ControlFlowTest(test.TestCase):
[i.get_shape(), tensor_shape.TensorShape([5])])
def testWhileShapeInferenceIndexedSlices(self):
- with self.test_session():
+ with self.cached_session():
values = constant_op.constant([[2.0, 4.0], [3.0, 5.0]], name="values")
indices = constant_op.constant([0, 3], name="indices")
shape = constant_op.constant([10, 2], name="dense_shape")
@@ -1313,7 +1313,7 @@ class ControlFlowTest(test.TestCase):
self._testNestedWhile_2(use_gpu=True)
def testWhileWithControl_1(self):
- with self.test_session():
+ with self.cached_session():
n = constant_op.constant(0)
r = constant_op.constant(0)
condition = lambda n_, r_: math_ops.less(n_, 10)
@@ -1329,7 +1329,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(12, res[1].eval())
def testWhileWithControl_2(self):
- with self.test_session():
+ with self.cached_session():
r = constant_op.constant(0)
condition = lambda r_: math_ops.less(r_, 10)
@@ -1343,7 +1343,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(12, res.eval())
def testWhileWithControl_3(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
b = array_ops.placeholder(dtypes.bool)
c = constant_op.constant(1)
x0 = constant_op.constant(0)
@@ -1352,7 +1352,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(10, sess.run(r, {b: True}))
def testWhileWithControl_4(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
b = array_ops.placeholder(dtypes.bool)
c = constant_op.constant(1)
x0 = constant_op.constant(0)
@@ -1362,7 +1362,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(10, sess.run(r, {b: True}))
def testWhileWithControl_5(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
b = array_ops.placeholder(dtypes.bool)
c = constant_op.constant(1)
x0 = constant_op.constant(0)
@@ -1375,12 +1375,12 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(10, sess.run(r, {b: True}))
def testWhileCondWithControl(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113294377 (unknown shape)")
# Ensure that no control edges by an outer control dependency context are
# added to nodes inside cond/while contexts.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
const_true = lambda: constant_op.constant(True)
const_false = lambda: constant_op.constant(False)
cond = lambda i: control_flow_ops.cond(i > 0, const_true, const_false)
@@ -1392,10 +1392,10 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(0, sess.run(loop))
def testWhileCondWithControl_1(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113324949 (ref vars)")
- with self.test_session():
+ with self.cached_session():
v = variable_scope.get_variable(
"v", [], initializer=init_ops.constant_initializer(2))
i0 = constant_op.constant(0)
@@ -1417,10 +1417,10 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(65536.0, v.eval())
def testWhileCondExitControl(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113294340 (enable while_v2)")
- with self.test_session():
+ with self.cached_session():
v = variables.Variable(1)
def false_branch():
@@ -1443,10 +1443,10 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(99, v.eval())
def testCondWhile_1(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/111124878 (don't return tuple)")
- with self.test_session():
+ with self.cached_session():
n = ops.convert_to_tensor(0, name="n")
c = lambda x: math_ops.less(x, 10)
b = lambda x: math_ops.add(x, 1)
@@ -1456,10 +1456,10 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def testCondWhile_2(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/111124878 (don't return tuple)")
- with self.test_session():
+ with self.cached_session():
n = ops.convert_to_tensor(0)
c = lambda x: math_ops.less(x, 10)
b = lambda x: math_ops.add(x, 1)
@@ -1469,7 +1469,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def _testCondWhile_3(self, use_gpu):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113294340 (enable while_v2)")
with self.test_session(use_gpu=use_gpu) as sess:
@@ -1498,10 +1498,10 @@ class ControlFlowTest(test.TestCase):
self._testCondWhile_3(use_gpu=True)
def testWhileCond_1(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113294377 (unknown shape)")
- with self.test_session():
+ with self.cached_session():
i = ops.convert_to_tensor(0, name="i")
n = ops.convert_to_tensor(10, name="n")
one = ops.convert_to_tensor(1, name="one")
@@ -1516,10 +1516,10 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def testWhileCond_2(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113294377 (unknown shape)")
- with self.test_session():
+ with self.cached_session():
n = ops.convert_to_tensor(0, name="n")
c = lambda x: math_ops.less(x, 10)
b = lambda x: control_flow_ops.cond(constant_op.constant(True), lambda: math_ops.add(x, 1), lambda: n)
@@ -1527,10 +1527,10 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def testWhileCond_3(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113294377 (unknown shape)")
- with self.test_session():
+ with self.cached_session():
n = ops.convert_to_tensor(0)
c = lambda x: math_ops.less(x, 10)
# pylint: disable=undefined-variable
@@ -1544,7 +1544,7 @@ class ControlFlowTest(test.TestCase):
# NOTE: It is ok to have parallel_iterations > 1
def testWhileUpdateVariable_1(self):
- with self.test_session():
+ with self.cached_session():
select = variables.Variable([3.0, 4.0, 5.0])
n = constant_op.constant(0)
@@ -1566,7 +1566,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(np.array([10.0, 10.0, 10.0]), result)
def testWhileUpdateVariable_2(self):
- with self.test_session():
+ with self.cached_session():
select1 = variables.Variable([3.0, 4.0, 5.0])
select2 = variables.Variable([3.0, 4.0, 5.0])
n = constant_op.constant(0)
@@ -1592,7 +1592,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(np.array([10.0, 10.0, 10.0]), result2)
def testWhileUpdateVariable_3(self):
- with self.test_session():
+ with self.cached_session():
select = variables.Variable([3.0, 4.0, 5.0])
n = constant_op.constant(0)
@@ -1614,7 +1614,7 @@ class ControlFlowTest(test.TestCase):
# b/24814703
def testWhileUpdateVariable_4(self):
- with self.test_session():
+ with self.cached_session():
var_a = variables.Variable(0, name="a")
var_b = variables.Variable(0, name="b")
variables.global_variables_initializer().run()
@@ -1642,7 +1642,7 @@ class ControlFlowTest(test.TestCase):
# b/24736492
def testWhileUpdateVariable_5(self):
- with self.test_session():
+ with self.cached_session():
# Create some variables.
var_a = variables.Variable(0, name="a")
var_b = variables.Variable(0, name="b")
@@ -1672,7 +1672,7 @@ class ControlFlowTest(test.TestCase):
# b/24814668
def testWhileUpdateVariable_6(self):
- with self.test_session():
+ with self.cached_session():
# Create some variables.
var_a = variables.Variable(0, name="a")
var_b = variables.Variable(0, name="b")
@@ -1701,7 +1701,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(10, var_a.eval())
def testWhileQueue_1(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(-1, dtypes.int32)
i = constant_op.constant(0)
@@ -1719,7 +1719,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual([i], q.dequeue().eval())
def testWhileStack_1(self):
- with self.test_session():
+ with self.cached_session():
s = gen_data_flow_ops.stack_v2(-1, dtypes.int32, stack_name="foo")
i = constant_op.constant(0)
@@ -1753,7 +1753,7 @@ class ControlFlowTest(test.TestCase):
def _testWhileGrad_ColocateGradients(self, colocate):
gpu_dev_name = test.gpu_device_name() if test.is_gpu_available(
- ) else "/device:GPU:0"
+ ) else "/device:CPU:0"
graph = ops.Graph()
with graph.as_default():
@@ -1791,7 +1791,7 @@ class ControlFlowTest(test.TestCase):
self._testWhileGrad_ColocateGradients(colocate=True)
def testWhileGrad_Square(self):
- with self.test_session():
+ with self.cached_session():
v = constant_op.constant(2.0, name="v")
c = lambda v: math_ops.less(v, 100.0)
b = math_ops.square
@@ -1802,7 +1802,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(1024.0, r.eval())
def testWhileGrad_Shape(self):
- with self.test_session():
+ with self.cached_session():
x = array_ops.placeholder(dtypes.float32, shape=[None])
v = constant_op.constant([2.0], name="v")
n = constant_op.constant(0, name="n")
@@ -1819,7 +1819,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose([810.0, 2560.0], r.eval(feed_dict={x: [3.0, 4.0]}))
def testWhileGrad_BaseShape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.placeholder(dtypes.float32, [None])
v0 = constant_op.constant([2.0, 2.0], name="v")
c = lambda v: constant_op.constant(False)
@@ -1831,7 +1831,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose([2.0, 4.0], sess.run(r, feed_dict={x: [1.0, 2.0]}))
def testWhileGrad_MultipleUses(self):
- with self.test_session():
+ with self.cached_session():
v = constant_op.constant(2.0, name="v")
c = lambda v: math_ops.less(v, 100.0)
b = math_ops.square
@@ -1842,7 +1842,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(524288.0, r.eval())
def testWhileGrad_LoopAdd(self):
- with self.test_session():
+ with self.cached_session():
v = constant_op.constant(2.0, name="v")
c = lambda v: math_ops.less(v, 100.0)
b = math_ops.square
@@ -1872,7 +1872,7 @@ class ControlFlowTest(test.TestCase):
self._testWhileGrad_Mul(use_gpu=True, p_iters=10)
def _testNestedWhileCondWhileGrad(self, use_gpu):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113294377 (unknown shape)")
with self.test_session(use_gpu=use_gpu):
@@ -1901,7 +1901,7 @@ class ControlFlowTest(test.TestCase):
self._testNestedWhileCondWhileGrad(use_gpu=True)
def testWhileGrad_Variable(self):
- with self.test_session():
+ with self.cached_session():
a = variables.Variable(3.0)
v = constant_op.constant(2.0, name="v")
c = lambda v: math_ops.less(v, 100.0)
@@ -1913,10 +1913,10 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(216.0, r[0].eval())
def testWhileGradInCond(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/110550782 (gradient w.r.t external variable)")
- with self.test_session():
+ with self.cached_session():
n = ops.convert_to_tensor(1.0, name="n")
x = array_ops.placeholder(dtypes.float32, shape=None)
c = lambda n: math_ops.less(n, 10.0)
@@ -1931,7 +1931,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
def testGradInWhileWrtInitialLoopVal(self):
- with self.test_session():
+ with self.cached_session():
x = array_ops.placeholder(dtypes.float32, shape=(), name="x")
y = x + 1
@@ -1948,7 +1948,7 @@ class ControlFlowTest(test.TestCase):
control_flow_ops.while_loop(lambda i, x: i < 3, body, [0, y])
def testWhileGradInWhile(self):
- with self.test_session():
+ with self.cached_session():
n = ops.convert_to_tensor(1.0, name="n")
x = array_ops.placeholder(dtypes.float32, shape=None)
c = lambda n: math_ops.less(n, 10.0)
@@ -1964,7 +1964,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
def testCondGradInNestedWhiles(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113346829 (gpu failure)")
def outer_body(i, x):
@@ -1978,13 +1978,13 @@ class ControlFlowTest(test.TestCase):
i, x = control_flow_ops.while_loop(lambda i, x: i < 3, outer_body, [0, 0.0])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
i_val, x_val = sess.run([i, x])
self.assertEqual(i_val, 3)
self.assertAllClose(x_val, 1.0)
def testWhile_NestedInput(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
named = collections.namedtuple("named", ("a", "b"))
loop_vars = [
named(a=constant_op.constant(0.0), b=constant_op.constant(1.0)),
@@ -2011,7 +2011,7 @@ class ControlFlowTest(test.TestCase):
sess.run(r_flattened))
def testWhile_NestedBadArityFails(self):
- with self.test_session():
+ with self.cached_session():
named = collections.namedtuple("named", ("a", "b"))
loop_vars = [
named(a=constant_op.constant(0.0), b=constant_op.constant(1.0)),
@@ -2027,7 +2027,7 @@ class ControlFlowTest(test.TestCase):
control_flow_ops.while_loop(c, b, loop_vars)
def testWhileGrad_ys_xs(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(3.0, name="x")
y = constant_op.constant(2.0, name="y")
@@ -2050,7 +2050,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(120.0, r[0].eval())
def testWhileGrad_Dependency(self):
- with self.test_session():
+ with self.cached_session():
i = constant_op.constant(0, name="i")
x = constant_op.constant(2.0, name="x")
@@ -2069,7 +2069,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(1024.0, r[0].eval())
def testWhileGrad_NoGradient(self):
- with self.test_session():
+ with self.cached_session():
v = constant_op.constant(2.0, name="v")
c = lambda v: math_ops.less(v, 100.0)
b = math_ops.square
@@ -2079,7 +2079,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(1.0, r[0].eval())
def testWhileGrad_NoDependency(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variable = variables.Variable(array_ops.ones([2, 3]))
duration = array_ops.zeros([], dtype=dtypes.int32)
@@ -2099,7 +2099,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(np.ones([2, 3]), sess.run(grad[0]))
def testWhileGrad_Const(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
c0 = constant_op.constant(0.0, name="c0")
c1 = constant_op.constant(1.0, name="c1")
duration = constant_op.constant(0, name="t")
@@ -2118,7 +2118,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(0.0, sess.run(grad[0]))
def testWhileGrad_SerialTwoLoops(self):
- with self.test_session():
+ with self.cached_session():
i = constant_op.constant(0, name="i")
x = constant_op.constant(2.0, name="x")
@@ -2136,7 +2136,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(1024.0, r[0].eval())
def testWhileGrad_ParallelTwoLoops(self):
- with self.test_session():
+ with self.cached_session():
i = constant_op.constant(0, name="i")
x = constant_op.constant(2.0, name="x")
@@ -2155,7 +2155,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(64.0, r[0].eval())
def testWhileGrad_OneOutputWithControlDependencyOnSecond(self):
- with self.test_session():
+ with self.cached_session():
i = constant_op.constant(0, name="i")
x = constant_op.constant(1.0, name="x")
y = constant_op.constant(1.0, name="y")
@@ -2196,7 +2196,7 @@ class ControlFlowTest(test.TestCase):
self._testNestedWhileGrad_Simple(use_gpu=True)
def testNestedWhileGrad_SerialInner(self):
- with self.test_session():
+ with self.cached_session():
v = constant_op.constant(1.0)
def inner_loop1(s):
@@ -2219,7 +2219,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(256.0, r.eval())
def testNestedWhileGrad_ParallelInner(self):
- with self.test_session():
+ with self.cached_session():
v = constant_op.constant(1.0)
def inner_loop1(s):
@@ -2244,7 +2244,7 @@ class ControlFlowTest(test.TestCase):
def testNestedWhileGrad_ParallelIterations(self):
# Make sure the stack pushes and pops of an inner loop are executed in
# the sequential order of the iterations of its outer loop.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
def inner_loop(t):
fn = lambda n: n + math_ops.square(var)
@@ -2280,14 +2280,14 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(1024.0, r.eval())
def testWhileCondGrad_Simple(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113294377 (unknown shape)")
self._testWhileCondGrad_Simple(use_gpu=False)
self._testWhileCondGrad_Simple(use_gpu=True)
def testWhileCondGrad_UnknownShape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v = array_ops.placeholder(dtypes.float32)
n = ops.convert_to_tensor(100.0, name="n")
one = ops.convert_to_tensor(1.0, name="one")
@@ -2304,7 +2304,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(1024.0, r)
def testWhileGrad_Concat(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = variable_scope.get_variable("x", initializer=[[1., 2.]])
i0 = constant_op.constant(0)
h0 = array_ops.zeros([0, 2])
@@ -2327,7 +2327,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose([[0.98000002, 1.98000002]], sess.run(x))
def testWhileWithRefsWithGradients_1(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = variables.Variable(0.)._ref() # pylint: disable=protected-access
i = constant_op.constant(0)
c = lambda i, x: math_ops.less(i, 10)
@@ -2355,7 +2355,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(73, value_x_grad)
def testWhileGrad_IndexedSlices(self):
- with self.test_session():
+ with self.cached_session():
values = constant_op.constant([2.0, 4.0], name="values")
indices = constant_op.constant([0, 3], name="indices")
shape = constant_op.constant([10], name="dense_shape")
@@ -2376,7 +2376,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(np.array([1024.0, 1024.0]), r.eval())
def testWhileGrad_SparseTensor(self):
- with self.test_session():
+ with self.cached_session():
values = constant_op.constant([2.0, 4.0], name="values")
indices = constant_op.constant(
[[0], [3]], dtype=dtypes.int64, name="indices")
@@ -2398,7 +2398,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(np.array([1024.0, 1024.0]), r.eval())
def testCallGradInLoop(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
i0 = constant_op.constant(0)
params = constant_op.constant(5.0)
params_1 = math_ops.square(params)
@@ -2417,7 +2417,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(600.0, sess.run(output_grad)[1])
def testWhileAndTensorArray(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
param = constant_op.constant(2.0)
n0 = constant_op.constant(0)
y0 = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems")
@@ -2436,7 +2436,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(107520.0, sess.run(r))
def testWhileGrad_StopGrad(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(3.0, name="x")
y = constant_op.constant(2.0, name="y")
@@ -2479,7 +2479,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(32.0, r.eval())
def testWhileGrad_StopGradInside(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(3.0, name="x")
y = constant_op.constant(2.0, name="y")
@@ -2498,7 +2498,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(156.0, r.eval())
def testWhileGrad_StopGradInsideNoShape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.placeholder(dtypes.float32)
y = array_ops.placeholder(dtypes.float32)
@@ -2534,7 +2534,7 @@ class ControlFlowTest(test.TestCase):
gradients_impl.gradients(grad_theta_stopped, theta)
def testStopGradOnWhileGrad(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(2.0, name="x")
y = constant_op.constant(2.0, name="y")
@@ -2562,7 +2562,7 @@ class ControlFlowTest(test.TestCase):
_, y = control_flow_ops.while_loop(cond, body, (math_ops.argmin(q), 0.))
dy_dq, = gradients_impl.gradients(y, q)
self.assertIsNotNone(dy_dq)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(q.initializer)
self.assertAllClose([0., 0.], sess.run(dy_dq))
@@ -2579,7 +2579,7 @@ class ControlFlowTest(test.TestCase):
_, y = control_flow_ops.while_loop(cond, body, (math_ops.argmin(q), 0.))
dy_dq, = gradients_impl.gradients(y, q)
self.assertIsNotNone(dy_dq)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(q.initializer)
self.assertAllClose([1., 1.], sess.run(dy_dq))
@@ -2607,7 +2607,7 @@ class ControlFlowTest(test.TestCase):
self.assertIsNotNone(grad)
def testStopGradMultiFlows(self):
- with self.test_session():
+ with self.cached_session():
def body(i, y, r):
x = variable_scope.get_variable(
@@ -2633,10 +2633,10 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(5.0, result.eval())
def testOneValueCond(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/111124878 (don't return tuple)")
- with self.test_session():
+ with self.cached_session():
c = array_ops.placeholder(dtypes.int32, shape=[])
one = ops.convert_to_tensor(1, name="one")
two = ops.convert_to_tensor(2, name="two")
@@ -2651,10 +2651,10 @@ class ControlFlowTest(test.TestCase):
self.assertEqual([2], i.eval(feed_dict={c: 0}))
def testExampleCond(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/111124878 (don't return tuple)")
- with self.test_session():
+ with self.cached_session():
x = ops.convert_to_tensor([-2.0, 2.0], name="x")
d = array_ops.placeholder(dtypes.int32, shape=[])
@@ -2669,10 +2669,10 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(2.0 * math.sqrt(2), i.eval(feed_dict={d: 2}))
def testCase(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/112477618 (Operation returned from cond)")
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(1)
y = constant_op.constant(2)
z = constant_op.constant(3)
@@ -2724,10 +2724,10 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(r6.eval(), 0)
def testCaseSideEffects(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/112477618 (Operation returned from cond)")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v0 = variables.Variable(-1)
v1 = variables.Variable(-1)
v2 = variables.Variable(-1)
@@ -2762,10 +2762,10 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(sess.run([v0, v1, v2]), [0, -1, -1])
def testOneOpCond(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113324949 (ref vars)")
- with self.test_session():
+ with self.cached_session():
v = variables.Variable(0)
c = ops.convert_to_tensor(0)
one = ops.convert_to_tensor(1)
@@ -2793,7 +2793,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(2, v.eval())
def testWithOpsDependencies(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v = variables.Variable(0.0)
c = constant_op.constant(10)
@@ -2816,7 +2816,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(0.0, real_v_val)
def testWithTensorDependencies(self):
- with self.test_session():
+ with self.cached_session():
v = variables.Variable(0.0)
c1 = constant_op.constant(10)
c2 = constant_op.constant(20)
@@ -2842,7 +2842,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(0.0, v.eval())
def testWithIndexedSlicesDependencies(self):
- with self.test_session():
+ with self.cached_session():
v = variables.Variable(
np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype(np.float32))
v_at_1 = ops.IndexedSlices(v, constant_op.constant([1]))
@@ -2886,7 +2886,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual([b"loc:@vdef"], with_vdef_dep.op.colocation_groups())
def testGroup(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v1 = variables.Variable([0.0])
v2 = variables.Variable([1.0])
@@ -2997,7 +2997,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(None, s.get_shape())
def testRunLoopTensor(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tensor_list = []
def condition(t):
@@ -3021,7 +3021,7 @@ class ControlFlowTest(test.TestCase):
def func(x):
return np.square(x)
- with self.test_session():
+ with self.cached_session():
r = control_flow_ops.while_loop(
lambda i, v: i < 4,
lambda i, v: [i + 1, script_ops.py_func(func, [v], [dtypes.float32])[0]],
@@ -3035,7 +3035,7 @@ class ControlFlowTest(test.TestCase):
def func(x):
return math_ops.square(math_ops.square(x))
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(2.0, dtypes.float32)
r = control_flow_ops.while_loop(
lambda i, v: i < 2, lambda i, v: [i + 1, func(v)],
@@ -3174,7 +3174,7 @@ class TupleTest(test.TestCase):
def testTensors(self):
for v1_first in [True, False]:
- with self.test_session():
+ with self.cached_session():
v1 = variables.Variable([1.0])
add1 = math_ops.add(
control_flow_ops.with_dependencies([v1.initializer], v1._ref()), # pylint: disable=protected-access
@@ -3204,7 +3204,7 @@ class TupleTest(test.TestCase):
def testIndexedSlices(self):
for v1_first in [True, False]:
- with self.test_session():
+ with self.cached_session():
v1 = variables.Variable(
np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype(
np.float32))
@@ -3243,7 +3243,7 @@ class TupleTest(test.TestCase):
v1.eval())
def testAcceptTensorsAsControlInputs(self):
- with self.test_session():
+ with self.cached_session():
var = variables.Variable(0)
assign = state_ops.assign(var, 1)
t, = control_flow_ops.tuple(
@@ -3408,6 +3408,7 @@ class WhileOpBenchmark(test.Benchmark):
name="unroll_same_device", iters=iters, wall_time=duration)
+@test_util.with_cond_v2
class EagerTest(test.TestCase):
def testCond(self):
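The hunks above, and most of the kernel-test hunks that follow, apply one mechanical change: with self.test_session() becomes with self.cached_session(), which returns a TensorFlow session that is cached and reused across calls within the same test, so the test bodies are otherwise untouched. The remaining edits in this file rename the flag control_flow_ops._ENABLE_COND_V2 to control_flow_ops.ENABLE_COND_V2 and add the @test_util.with_cond_v2 decorator to EagerTest. A minimal before/after sketch of the session change, assuming the TF 1.x test.TestCase API used throughout these files (the class and test names are illustrative only, not part of the diff):

    from tensorflow.python.framework import constant_op
    from tensorflow.python.platform import test


    class SessionMigrationSketch(test.TestCase):

      def testOldPattern(self):
        # Deprecated spelling being replaced throughout this commit.
        with self.test_session():
          self.assertEqual(3.0, constant_op.constant(3.0).eval())

      def testNewPattern(self):
        # cached_session() hands back one session reused for this test,
        # so the migrated test bodies need no other changes.
        with self.cached_session():
          self.assertEqual(3.0, constant_op.constant(3.0).eval())


    if __name__ == "__main__":
      test.main()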
diff --git a/tensorflow/python/kernel_tests/conv1d_test.py b/tensorflow/python/kernel_tests/conv1d_test.py
index fcba456004..2d6d8a8051 100644
--- a/tensorflow/python/kernel_tests/conv1d_test.py
+++ b/tensorflow/python/kernel_tests/conv1d_test.py
@@ -53,7 +53,7 @@ class Conv1DTest(test.TestCase):
self.assertAllClose(output, [2 * 1 + 1 * 2, 2 * 3 + 1 * 4])
def testConv1DTranspose(self):
- with self.test_session():
+ with self.cached_session():
stride = 2
# Input, output: [batch, width, depth]
diff --git a/tensorflow/python/kernel_tests/conv2d_backprop_filter_grad_test.py b/tensorflow/python/kernel_tests/conv2d_backprop_filter_grad_test.py
index be299beee4..644a151710 100644
--- a/tensorflow/python/kernel_tests/conv2d_backprop_filter_grad_test.py
+++ b/tensorflow/python/kernel_tests/conv2d_backprop_filter_grad_test.py
@@ -32,7 +32,7 @@ from tensorflow.python.platform import test
class Conv2DBackpropFilterGradTest(test.TestCase):
def testGradient(self):
- with self.test_session():
+ with self.cached_session():
for padding in ["SAME", "VALID"]:
for stride in [1, 2]:
np.random.seed(1)
diff --git a/tensorflow/python/kernel_tests/conv2d_transpose_test.py b/tensorflow/python/kernel_tests/conv2d_transpose_test.py
index 27804be65c..cbdd2c5991 100644
--- a/tensorflow/python/kernel_tests/conv2d_transpose_test.py
+++ b/tensorflow/python/kernel_tests/conv2d_transpose_test.py
@@ -37,7 +37,7 @@ from tensorflow.python.platform import test
class Conv2DTransposeTest(test.TestCase):
def testConv2DTransposeSingleStride(self):
- with self.test_session():
+ with self.cached_session():
strides = [1, 1, 1, 1]
# Input, output: [batch, height, width, depth]
@@ -75,7 +75,7 @@ class Conv2DTransposeTest(test.TestCase):
self.assertAllClose(target, value[n, h, w, k])
def testConv2DTransposeSame(self):
- with self.test_session():
+ with self.cached_session():
strides = [1, 2, 2, 1]
# Input, output: [batch, height, width, depth]
@@ -108,7 +108,7 @@ class Conv2DTransposeTest(test.TestCase):
self.assertAllClose(target, value[n, h, w, k])
def testConv2DTransposeValid(self):
- with self.test_session():
+ with self.cached_session():
strides = [1, 2, 2, 1]
# Input, output: [batch, height, width, depth]
@@ -163,7 +163,7 @@ class Conv2DTransposeTest(test.TestCase):
np.random.seed(1) # Make it reproducible.
x_val = np.random.random_sample(x_shape).astype(np.float64)
f_val = np.random.random_sample(f_shape).astype(np.float64)
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(x_val, name="x", dtype=dtypes.float32)
f = constant_op.constant(f_val, name="f", dtype=dtypes.float32)
output = nn_ops.conv2d_transpose(
diff --git a/tensorflow/python/kernel_tests/conv3d_backprop_filter_v2_grad_test.py b/tensorflow/python/kernel_tests/conv3d_backprop_filter_v2_grad_test.py
index 85264ef876..89b64068ac 100644
--- a/tensorflow/python/kernel_tests/conv3d_backprop_filter_v2_grad_test.py
+++ b/tensorflow/python/kernel_tests/conv3d_backprop_filter_v2_grad_test.py
@@ -32,7 +32,7 @@ from tensorflow.python.platform import test
class Conv3DBackpropFilterV2GradTest(test.TestCase):
def testGradient(self):
- with self.test_session():
+ with self.cached_session():
for padding in ["SAME", "VALID"]:
for stride in [1, 2]:
np.random.seed(1)
diff --git a/tensorflow/python/kernel_tests/conv3d_transpose_test.py b/tensorflow/python/kernel_tests/conv3d_transpose_test.py
index 289ae29fce..2527b83769 100644
--- a/tensorflow/python/kernel_tests/conv3d_transpose_test.py
+++ b/tensorflow/python/kernel_tests/conv3d_transpose_test.py
@@ -32,7 +32,7 @@ from tensorflow.python.platform import test
class Conv3DTransposeTest(test.TestCase):
def testConv3DTransposeSingleStride(self):
- with self.test_session():
+ with self.cached_session():
strides = [1, 1, 1, 1, 1]
# Input, output: [batch, depth, height, width, channel]
@@ -82,7 +82,7 @@ class Conv3DTransposeTest(test.TestCase):
self.assertAllClose(target, value[n, d, h, w, k])
def testConv3DTransposeSame(self):
- with self.test_session():
+ with self.cached_session():
strides = [1, 2, 2, 2, 1]
# Input, output: [batch, depth, height, width, depth]
@@ -134,7 +134,7 @@ class Conv3DTransposeTest(test.TestCase):
def testConv3DTransposeOutputShapeType(self):
# Test case for GitHub issue 18887
for dtype in [dtypes.int32, dtypes.int64]:
- with self.test_session():
+ with self.cached_session():
x_shape = [2, 5, 6, 4, 3]
y_shape = [2, 5, 6, 4, 2]
f_shape = [3, 3, 3, 2, 3]
@@ -149,7 +149,7 @@ class Conv3DTransposeTest(test.TestCase):
output.eval()
def testConv3DTransposeValid(self):
- with self.test_session():
+ with self.cached_session():
strides = [1, 2, 2, 2, 1]
# Input, output: [batch, depth, height, width, depth]
@@ -209,7 +209,7 @@ class Conv3DTransposeTest(test.TestCase):
np.random.seed(1) # Make it reproducible.
x_val = np.random.random_sample(x_shape).astype(np.float64)
f_val = np.random.random_sample(f_shape).astype(np.float64)
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(x_val, name="x", dtype=dtypes.float32)
f = constant_op.constant(f_val, name="f", dtype=dtypes.float32)
output = nn_ops.conv3d_transpose(
diff --git a/tensorflow/python/kernel_tests/conv_ops_3d_test.py b/tensorflow/python/kernel_tests/conv_ops_3d_test.py
index 0b531125f3..6794464e3a 100644
--- a/tensorflow/python/kernel_tests/conv_ops_3d_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_3d_test.py
@@ -108,7 +108,7 @@ class Conv3DTest(test.TestCase):
use_gpu=use_gpu)
results.append(result)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = sess.run(results)
for value in values:
print("expected = ", expected)
@@ -183,7 +183,7 @@ class Conv3DTest(test.TestCase):
expected_results.append(expected)
computed_results.append(computed)
tolerance = 1e-2 if use_gpu else 1e-5
- with self.test_session() as sess:
+ with self.cached_session() as sess:
expected_values = sess.run(expected_results)
computed_values = sess.run(computed_results)
for e_value, c_value in zip(expected_values, computed_values):
diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py
index 00de94f004..ea611497d9 100644
--- a/tensorflow/python/kernel_tests/conv_ops_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_test.py
@@ -1474,7 +1474,7 @@ class Conv2DTest(test.TestCase):
padding="SAME")
def testOpEdgeCases(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Illegal strides.
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"strides in the batch and depth"):
@@ -1539,7 +1539,7 @@ class DepthwiseConv2DTest(test.TestCase):
# numbers from 1.
x1 = [f * 1.0 for f in range(1, total_size_1 + 1)]
x2 = [f * 1.0 for f in range(1, total_size_2 + 1)]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
t1 = constant_op.constant(x1, shape=tensor_in_sizes)
t1.set_shape(tensor_in_sizes)
t2 = constant_op.constant(x2, shape=filter_in_sizes)
diff --git a/tensorflow/python/kernel_tests/cross_grad_test.py b/tensorflow/python/kernel_tests/cross_grad_test.py
index f040ac6055..0bd4006d6a 100644
--- a/tensorflow/python/kernel_tests/cross_grad_test.py
+++ b/tensorflow/python/kernel_tests/cross_grad_test.py
@@ -27,7 +27,7 @@ from tensorflow.python.platform import test
class CrossOpTest(test.TestCase):
def testGradientRandomValues(self):
- with self.test_session():
+ with self.cached_session():
us = [2, 3]
u = array_ops.reshape(
[0.854, -0.616, 0.767, 0.725, -0.927, 0.159], shape=us)
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
index b61232cded..00d7f956c2 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -541,7 +541,7 @@ class UnaryOpTest(test.TestCase):
return x
for op, real_range in op_range:
- with self.test_session():
+ with self.cached_session():
for dtype, tol in dtype_tols:
x = constant_op.constant(rand(dtype))
y = constant_op.constant(rand(dtype))
@@ -604,7 +604,7 @@ class BinaryOpTest(test.TestCase):
numeric_gradient_type=None):
z = np_func(x, y)
zs = list(z.shape)
- with self.test_session():
+ with self.cached_session():
inx = ops.convert_to_tensor(x)
iny = ops.convert_to_tensor(y)
if x.dtype in (np.float32, np.float64):
@@ -634,7 +634,7 @@ class BinaryOpTest(test.TestCase):
numeric_gradient_type=None):
z = np_func(x, y)
zs = list(z.shape)
- with self.test_session():
+ with self.cached_session():
inx = ops.convert_to_tensor(x)
iny = ops.convert_to_tensor(y)
if x.dtype in (np.float32, np.float64):
@@ -720,7 +720,7 @@ class BinaryOpTest(test.TestCase):
def testFloatDifferentShapes(self):
x = np.array([1, 2, 3, 4]).reshape(2, 2).astype(np.float32)
y = np.array([1, 2]).reshape(2, 1).astype(np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inx = ops.convert_to_tensor(x)
iny = ops.convert_to_tensor(y)
s = math_ops.reduce_sum(inx * iny)
@@ -736,7 +736,7 @@ class BinaryOpTest(test.TestCase):
y = np.array([1, 2]).reshape(2, 1).astype(np.int32)
var_x = variables.Variable(x)
var_y = variables.Variable(y)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([var_x.initializer, var_y.initializer])
left_result = (var_x * y).eval()
right_result = (x * var_y).eval()
@@ -1168,7 +1168,7 @@ class BinaryOpTest(test.TestCase):
ops.convert_to_tensor([[40.0, 50.0], [60.0, 70.0]]))
def testZeroPowGrad(self):
- with self.test_session():
+ with self.cached_session():
for dtype in (np.float16, np.float32, np.float64, np.complex64,
np.complex128):
x = constant_op.constant(0.0, dtype=dtype)
@@ -1178,7 +1178,7 @@ class BinaryOpTest(test.TestCase):
self.assertEqual(error, 0)
def testComplexPowGrad(self):
- with self.test_session():
+ with self.cached_session():
for dtype in np.complex64, np.complex128:
for base in 2.0, -2.0:
x = constant_op.constant(base, dtype=dtype)
@@ -1470,7 +1470,7 @@ class SelectOpTest(test.TestCase):
self.assertShapeEqual(np_ans, out)
def _compareGradientX(self, c, x, y, numeric_gradient_type=None):
- with self.test_session():
+ with self.cached_session():
inx = ops.convert_to_tensor(x)
iny = ops.convert_to_tensor(y)
out = array_ops.where(c, inx, iny)
@@ -1494,7 +1494,7 @@ class SelectOpTest(test.TestCase):
self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5)
def _compareGradientY(self, c, x, y, numeric_gradient_type=None):
- with self.test_session():
+ with self.cached_session():
inx = ops.convert_to_tensor(x)
iny = ops.convert_to_tensor(y)
out = array_ops.where(c, inx, iny)
@@ -1582,7 +1582,7 @@ class SelectOpTest(test.TestCase):
x = np.random.rand(1, 3, 0) * 100
y = np.random.rand(1, 3, 0) * 100
z_expected = np.zeros((1, 3, 0), dtype=np.float32)
- with self.test_session():
+ with self.cached_session():
xt = x.astype(np.float32)
yt = y.astype(np.float32)
z = array_ops.where(c, xt, yt).eval()
@@ -1590,7 +1590,7 @@ class SelectOpTest(test.TestCase):
def testNan(self):
"""Verify that nans don't propagate where they shouldn't."""
- with self.test_session():
+ with self.cached_session():
for c in False, True:
for a in 7.0, np.nan:
for b in 5.0, np.nan:
@@ -1614,7 +1614,7 @@ class BatchSelectOpTest(test.TestCase):
self.assertShapeEqual(np_ans, out)
def _compareGradientX(self, c, x, y, numeric_gradient_type=None):
- with self.test_session():
+ with self.cached_session():
inx = ops.convert_to_tensor(x)
iny = ops.convert_to_tensor(y)
out = array_ops.where(c, inx, iny)
@@ -1638,7 +1638,7 @@ class BatchSelectOpTest(test.TestCase):
self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5)
def _compareGradientY(self, c, x, y, numeric_gradient_type=None):
- with self.test_session():
+ with self.cached_session():
inx = ops.convert_to_tensor(x)
iny = ops.convert_to_tensor(y)
out = array_ops.where(c, inx, iny)
@@ -1745,7 +1745,7 @@ class MinMaxOpTest(test.TestCase):
self._compare(x.astype(t), t(y), use_gpu=True)
def _compareGradientX(self, func, x, y):
- with self.test_session():
+ with self.cached_session():
inx = ops.convert_to_tensor(x)
iny = ops.convert_to_tensor(y)
out = func(inx, iny)
@@ -1760,7 +1760,7 @@ class MinMaxOpTest(test.TestCase):
self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5)
def _compareGradientY(self, func, x, y):
- with self.test_session():
+ with self.cached_session():
inx = ops.convert_to_tensor(x)
iny = ops.convert_to_tensor(y)
out = func(inx, iny)
@@ -1932,7 +1932,7 @@ class RoundingTest(test.TestCase):
def _compare_values(self, x, y=None):
y = np.rint(x) if y is None else np.asarray(y)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tf_rint = math_ops.rint(x)
np_rint = sess.run(tf_rint)
self.assertAllEqual(y, np_rint)
@@ -1940,7 +1940,7 @@ class RoundingTest(test.TestCase):
def _compare(self, x):
np_floor, np_ceil = np.floor(x), np.ceil(x)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inx = ops.convert_to_tensor(x)
ofloor, oceil = math_ops.floor(inx), math_ops.ceil(inx)
tf_floor, tf_ceil = sess.run([ofloor, oceil])
@@ -2099,7 +2099,7 @@ class ComplexMakeRealImagTest(test.TestCase):
# computes the squared sum. This is obviously the same as sum(real
# * real) + sum(imag * imag). We just want to make sure the
# gradient function is checked.
- with self.test_session():
+ with self.cached_session():
inx = ops.convert_to_tensor(x)
real, imag = array_ops.split(value=inx, num_or_size_splits=2, axis=1)
real, imag = array_ops.reshape(real, [-1]), array_ops.reshape(imag, [-1])
@@ -2116,7 +2116,7 @@ class ComplexMakeRealImagTest(test.TestCase):
def _compareBroadcastGradient(self, x):
x_ = ops.convert_to_tensor(x)
epsilon = 1e-3
- with self.test_session():
+ with self.cached_session():
for args in [(x_, 0.), (0., x_)]:
z = math_ops.reduce_sum(math_ops.abs(math_ops.complex(*args)))
jacob_t, jacob_n = gradient_checker.compute_gradient(
@@ -2136,7 +2136,7 @@ class ComplexMakeRealImagTest(test.TestCase):
# data is a float matrix of shape [n, 4]. data[:, 0], data[:, 1],
# data[:, 2], data[:, 3] are real parts of x, imaginary parts of
# x, real parts of y and imaginary parts of y.
- with self.test_session():
+ with self.cached_session():
inp = ops.convert_to_tensor(data)
xr, xi, yr, yi = array_ops.split(value=inp, num_or_size_splits=4, axis=1)
@@ -2166,7 +2166,7 @@ class ComplexMakeRealImagTest(test.TestCase):
class AccumulateTest(test.TestCase):
def testSimple(self):
- with self.test_session():
+ with self.cached_session():
random_arrays = [
np.random.rand(16, 16, 16, 16).astype(np.float32) for _ in range(20)
]
@@ -2181,20 +2181,20 @@ class AccumulateTest(test.TestCase):
self.assertAllClose(np_val, tf_val.eval())
def testZeroArgs(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
tf_val = math_ops.accumulate_n([])
tf_val.eval()
def testWrongShape(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
a = variables.Variable(0.2)
b = variables.Variable(0.1)
math_ops.accumulate_n([a, b], shape=[2, 2]) # Should be shape=[]
def testWrongType(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(TypeError):
a = variables.Variable(0.2, dtype=np.float32)
b = variables.Variable(0.1, dtype=np.float32)
@@ -2202,7 +2202,7 @@ class AccumulateTest(test.TestCase):
def testWrongTypeOneInput(self):
# Scenario that used to trigger a bug, even when testWrongType() worked
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(TypeError):
a = variables.Variable(0.2, dtype=np.float32)
math_ops.accumulate_n([a], tensor_dtype=np.int32)
@@ -2214,7 +2214,7 @@ class PolyvalTest(test.TestCase):
x = np.random.rand(2, 2).astype(dtype)
coeffs = [np.random.rand(2, 2).astype(dtype) for _ in range(degree + 1)]
np_val = np.polyval(coeffs, x)
- with self.test_session():
+ with self.cached_session():
tf_val = math_ops.polyval(coeffs, x)
self.assertAllClose(np_val, tf_val.eval())
@@ -2237,7 +2237,7 @@ class PolyvalTest(test.TestCase):
for _ in range(degree + 1)
]
np_val = np.polyval(coeffs, x)
- with self.test_session():
+ with self.cached_session():
tf_val = math_ops.polyval(coeffs, x)
self.assertAllClose(np_val, tf_val.eval())
@@ -2245,7 +2245,7 @@ class PolyvalTest(test.TestCase):
x = np.random.rand(2, 2).astype(np.float32)
coeffs = []
np_val = np.polyval(coeffs, x)
- with self.test_session():
+ with self.cached_session():
tf_val = math_ops.polyval(coeffs, x)
self.assertAllClose(np_val, tf_val.eval())
diff --git a/tensorflow/python/kernel_tests/decode_bmp_op_test.py b/tensorflow/python/kernel_tests/decode_bmp_op_test.py
index 35f8f76991..eebaffbe13 100644
--- a/tensorflow/python/kernel_tests/decode_bmp_op_test.py
+++ b/tensorflow/python/kernel_tests/decode_bmp_op_test.py
@@ -60,7 +60,7 @@ class DecodeBmpOpTest(test.TestCase):
img_in = constant_op.constant(byte_string, dtype=dtypes.string)
decode = array_ops.squeeze(image_ops.decode_bmp(img_in))
- with self.test_session():
+ with self.cached_session():
decoded = decode.eval()
self.assertAllEqual(decoded, img_bytes)
@@ -135,7 +135,7 @@ class DecodeBmpOpTest(test.TestCase):
img_in = constant_op.constant(byte_string, dtype=dtypes.string)
decode = image_ops.decode_bmp(img_in)
- with self.test_session():
+ with self.cached_session():
decoded = decode.eval()
self.assertAllEqual(decoded, img_bytes)
diff --git a/tensorflow/python/kernel_tests/decode_compressed_op_test.py b/tensorflow/python/kernel_tests/decode_compressed_op_test.py
index c9bda58ca7..1cc1c7da30 100644
--- a/tensorflow/python/kernel_tests/decode_compressed_op_test.py
+++ b/tensorflow/python/kernel_tests/decode_compressed_op_test.py
@@ -44,7 +44,7 @@ class DecodeCompressedOpTest(test.TestCase):
def testDecompress(self):
for compression_type in ["ZLIB", "GZIP", ""]:
- with self.test_session():
+ with self.cached_session():
in_bytes = array_ops.placeholder(dtypes.string, shape=[2])
decompressed = parsing_ops.decode_compressed(
in_bytes, compression_type=compression_type)
@@ -57,7 +57,7 @@ class DecodeCompressedOpTest(test.TestCase):
def testDecompressWithRaw(self):
for compression_type in ["ZLIB", "GZIP", ""]:
- with self.test_session():
+ with self.cached_session():
in_bytes = array_ops.placeholder(dtypes.string, shape=[None])
decompressed = parsing_ops.decode_compressed(
in_bytes, compression_type=compression_type)
diff --git a/tensorflow/python/kernel_tests/decode_csv_op_test.py b/tensorflow/python/kernel_tests/decode_csv_op_test.py
index 4f49d72676..e9307a6b2f 100644
--- a/tensorflow/python/kernel_tests/decode_csv_op_test.py
+++ b/tensorflow/python/kernel_tests/decode_csv_op_test.py
@@ -20,28 +20,30 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.eager import context
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
+@test_util.run_all_in_graph_and_eager_modes
class DecodeCSVOpTest(test.TestCase):
def _test(self, args, expected_out=None, expected_err_re=None):
- with self.test_session() as sess:
+ if expected_err_re is None:
decode = parsing_ops.decode_csv(**args)
-
- if expected_err_re is None:
- out = sess.run(decode)
-
- for i, field in enumerate(out):
- if field.dtype == np.float32 or field.dtype == np.float64:
- self.assertAllClose(field, expected_out[i])
- else:
- self.assertAllEqual(field, expected_out[i])
-
- else:
- with self.assertRaisesOpError(expected_err_re):
- sess.run(decode)
+ out = self.evaluate(decode)
+
+ for i, field in enumerate(out):
+ if field.dtype == np.float32 or field.dtype == np.float64:
+ self.assertAllClose(field, expected_out[i])
+ else:
+ self.assertAllEqual(field, expected_out[i])
+ else:
+ with self.assertRaisesOpError(expected_err_re):
+ decode = parsing_ops.decode_csv(**args)
+ self.evaluate(decode)
def testSimple(self):
args = {
@@ -53,6 +55,31 @@ class DecodeCSVOpTest(test.TestCase):
self._test(args, expected_out)
+ def testSimpleWithScalarDefaults(self):
+ args = {
+ "records": ["1,4", "2,5", "3,6"],
+ "record_defaults": [1, 2],
+ }
+
+ expected_out = [[1, 2, 3], [4, 5, 6]]
+
+ self._test(args, expected_out)
+
+ def testSimpleWith2DDefaults(self):
+ args = {
+ "records": ["1", "2", "3"],
+ "record_defaults": [[[0]]],
+ }
+
+ if context.executing_eagerly():
+ err_spec = errors.InvalidArgumentError, (
+ "Each record default should be at "
+ "most rank 1.")
+ else:
+ err_spec = ValueError, "Shape must be at most rank 1 but is rank 2"
+ with self.assertRaisesWithPredicateMatch(*err_spec):
+ self._test(args)
+
def testSimpleNoQuoteDelimiter(self):
args = {
"records": ["1", "2", '"3"'],
diff --git a/tensorflow/python/kernel_tests/decode_image_op_test.py b/tensorflow/python/kernel_tests/decode_image_op_test.py
index 58280432d6..7f73fbaa84 100644
--- a/tensorflow/python/kernel_tests/decode_image_op_test.py
+++ b/tensorflow/python/kernel_tests/decode_image_op_test.py
@@ -111,7 +111,7 @@ class DecodeImageOpTest(test.TestCase):
def testInvalidBytes(self):
image_bytes = b"ThisIsNotAnImage!"
decode = image_ops.decode_image(image_bytes)
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
decode.eval()
diff --git a/tensorflow/python/kernel_tests/decode_png_op_test.py b/tensorflow/python/kernel_tests/decode_png_op_test.py
index d2e03938ee..8f36343667 100644
--- a/tensorflow/python/kernel_tests/decode_png_op_test.py
+++ b/tensorflow/python/kernel_tests/decode_png_op_test.py
@@ -46,7 +46,7 @@ class DecodePngOpTest(test.TestCase):
image_ops.decode_png(
img_in, dtype=dtypes.uint16))
- with self.test_session():
+ with self.cached_session():
decoded = decode.eval()
self.assertAllEqual(decoded, img_bytes)
diff --git a/tensorflow/python/kernel_tests/decode_raw_op_test.py b/tensorflow/python/kernel_tests/decode_raw_op_test.py
index 122a9ed469..dcc984811c 100644
--- a/tensorflow/python/kernel_tests/decode_raw_op_test.py
+++ b/tensorflow/python/kernel_tests/decode_raw_op_test.py
@@ -29,7 +29,7 @@ from tensorflow.python.platform import test
class DecodeRawOpTest(test.TestCase):
def testToUint8(self):
- with self.test_session():
+ with self.cached_session():
in_bytes = array_ops.placeholder(dtypes.string, shape=[2])
decode = parsing_ops.decode_raw(in_bytes, out_type=dtypes.uint8)
self.assertEqual([2, None], decode.get_shape().as_list())
@@ -47,7 +47,7 @@ class DecodeRawOpTest(test.TestCase):
decode.eval(feed_dict={in_bytes: ["short", "longer"]})
def testToInt16(self):
- with self.test_session():
+ with self.cached_session():
in_bytes = array_ops.placeholder(dtypes.string, shape=[None])
decode = parsing_ops.decode_raw(in_bytes, out_type=dtypes.int16)
self.assertEqual([None, None], decode.get_shape().as_list())
@@ -62,7 +62,7 @@ class DecodeRawOpTest(test.TestCase):
decode.eval(feed_dict={in_bytes: ["123", "456"]})
def testEndianness(self):
- with self.test_session():
+ with self.cached_session():
in_bytes = array_ops.placeholder(dtypes.string, shape=[None])
decode_le = parsing_ops.decode_raw(
in_bytes, out_type=dtypes.int32, little_endian=True)
@@ -74,18 +74,18 @@ class DecodeRawOpTest(test.TestCase):
self.assertAllEqual([[0x01020304]], result)
def testToFloat16(self):
- with self.test_session():
+ with self.cached_session():
in_bytes = array_ops.placeholder(dtypes.string, shape=[None])
decode = parsing_ops.decode_raw(in_bytes, out_type=dtypes.float16)
self.assertEqual([None, None], decode.get_shape().as_list())
- expected_result = np.matrix([[1, -2, -3, 4]], dtype=np.float16)
+ expected_result = np.matrix([[1, -2, -3, 4]], dtype="<f2")
result = decode.eval(feed_dict={in_bytes: [expected_result.tostring()]})
self.assertAllEqual(expected_result, result)
def testEmptyStringInput(self):
- with self.test_session():
+ with self.cached_session():
in_bytes = array_ops.placeholder(dtypes.string, shape=[None])
decode = parsing_ops.decode_raw(in_bytes, out_type=dtypes.float16)
@@ -94,7 +94,7 @@ class DecodeRawOpTest(test.TestCase):
self.assertEqual((num_inputs, 0), result.shape)
def testToUInt16(self):
- with self.test_session():
+ with self.cached_session():
in_bytes = array_ops.placeholder(dtypes.string, shape=[None])
decode = parsing_ops.decode_raw(in_bytes, out_type=dtypes.uint16)
self.assertEqual([None, None], decode.get_shape().as_list())
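In the decode_raw hunk above, the expected float16 buffer is now built with the NumPy dtype string "<f2" rather than np.float16. "<f2" is explicitly little-endian, so the serialized bytes are identical on any host and line up with decode_raw's little_endian=True default, whereas np.float16 follows the host's native byte order. A small standalone NumPy illustration of the difference (not part of the diff):

    import numpy as np

    le_half = np.dtype("<f2")            # 16-bit float, forced little-endian
    native_half = np.dtype(np.float16)   # 16-bit float in host byte order

    buf = np.array([[1, -2, -3, 4]], dtype=le_half).tostring()
    print(len(buf))                      # 8 bytes, in a fixed order on any host
    print(le_half == native_half)        # True only on little-endian hosts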
diff --git a/tensorflow/python/kernel_tests/dense_update_ops_no_tsan_test.py b/tensorflow/python/kernel_tests/dense_update_ops_no_tsan_test.py
index d33bf1ba12..affbaf159d 100644
--- a/tensorflow/python/kernel_tests/dense_update_ops_no_tsan_test.py
+++ b/tensorflow/python/kernel_tests/dense_update_ops_no_tsan_test.py
@@ -33,7 +33,7 @@ class AssignOpTest(test.TestCase):
# contain benign and deliberate data races when multiple threads update
# the same parameters without a lock.
def testParallelUpdateWithoutLocking(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ones_t = array_ops.fill([1024, 1024], 1.0)
p = variables.Variable(array_ops.zeros([1024, 1024]))
adds = [
@@ -60,7 +60,7 @@ class AssignOpTest(test.TestCase):
self.assertTrue((vals <= ones * 20).all())
def testParallelAssignWithoutLocking(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ones_t = array_ops.fill([1024, 1024], float(1))
p = variables.Variable(array_ops.zeros([1024, 1024]))
assigns = [
@@ -92,7 +92,7 @@ class AssignOpTest(test.TestCase):
# returning the output tensors. This issue will be resolved with the new
# resource variables.
def testParallelUpdateWithLocking(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
zeros_t = array_ops.fill([1024, 1024], 0.0)
ones_t = array_ops.fill([1024, 1024], 1.0)
p = variables.Variable(zeros_t)
@@ -119,7 +119,7 @@ class AssignOpTest(test.TestCase):
self.assertAllEqual(vals, ones * 20)
def testParallelAssignWithLocking(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
zeros_t = array_ops.fill([1024, 1024], 0.0)
ones_t = array_ops.fill([1024, 1024], 1.0)
p = variables.Variable(zeros_t)
diff --git a/tensorflow/python/kernel_tests/dense_update_ops_test.py b/tensorflow/python/kernel_tests/dense_update_ops_test.py
index 4dda9f093b..06c3271850 100644
--- a/tensorflow/python/kernel_tests/dense_update_ops_test.py
+++ b/tensorflow/python/kernel_tests/dense_update_ops_test.py
@@ -85,7 +85,7 @@ class AssignOpTest(test.TestCase):
self._testTypes(np.arange(0, 20).reshape([4, 5]))
def testAssignNonStrictShapeChecking(self):
- with self.test_session():
+ with self.cached_session():
data = array_ops.fill([1024, 1024], 0)
p = variables.Variable([1])
a = state_ops.assign(p, data, validate_shape=False)
@@ -99,14 +99,14 @@ class AssignOpTest(test.TestCase):
self.assertAllEqual(p.eval(), data2.eval())
def testInitRequiredAssignAdd(self):
- with self.test_session():
+ with self.cached_session():
p = variables.Variable(array_ops.fill([1024, 1024], 1), dtypes.int32)
a = state_ops.assign_add(p, array_ops.fill([1024, 1024], 0))
with self.assertRaisesOpError("use uninitialized"):
a.op.run()
def testInitRequiredAssignSub(self):
- with self.test_session():
+ with self.cached_session():
p = variables.Variable(array_ops.fill([1024, 1024], 1), dtypes.int32)
a = state_ops.assign_sub(p, array_ops.fill([1024, 1024], 0))
with self.assertRaisesOpError("use uninitialized"):
diff --git a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
index 58845552db..5741f2ec64 100644
--- a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
+++ b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
@@ -205,6 +205,19 @@ class DepthwiseConv2DTest(test.TestCase):
use_gpu=True,
grouped_conv=True)
+ def testDepthwiseConv2DWithUnknownShape(self):
+ # GitHub issue 22110.
+ if not test.is_gpu_available():
+ return
+ with self.test_session(use_gpu=True):
+ x = array_ops.placeholder(dtypes.float32)
+ f = np.ones([1, 1, 1, 1], np.float32)
+ v = nn_impl.depthwise_conv2d(
+ x, f, [1, 1, 1, 1], "VALID", rate=[2, 1], data_format="NCHW")
+ self.assertAllEqual(
+ np.ones([1, 1, 1, 1], np.float32),
+ v.eval(feed_dict={x: np.ones([1, 1, 1, 1], np.float32)}))
+
def testDepthwiseConv2DFormat(self):
if not test.is_gpu_available():
return
diff --git a/tensorflow/python/kernel_tests/division_future_test.py b/tensorflow/python/kernel_tests/division_future_test.py
index e681b32856..e477bdc73b 100644
--- a/tensorflow/python/kernel_tests/division_future_test.py
+++ b/tensorflow/python/kernel_tests/division_future_test.py
@@ -50,7 +50,7 @@ class DivisionTestCase(test.TestCase):
self.assertEqual(x, y)
checks.append(f)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for dtype in dtypes:
for x in map(dtype, values):
for y in map(dtype, values):
diff --git a/tensorflow/python/kernel_tests/division_past_test.py b/tensorflow/python/kernel_tests/division_past_test.py
index 9ddd62e63c..63951b5b38 100644
--- a/tensorflow/python/kernel_tests/division_past_test.py
+++ b/tensorflow/python/kernel_tests/division_past_test.py
@@ -49,7 +49,7 @@ class DivisionTestCase(test.TestCase):
self.assertEqual(x, y)
checks.append(f)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for dtype in dtypes:
for x in map(dtype, values):
for y in map(dtype, values):
diff --git a/tensorflow/python/kernel_tests/duplicate_op_test.py b/tensorflow/python/kernel_tests/duplicate_op_test.py
index 529d3dd0b3..654267a582 100644
--- a/tensorflow/python/kernel_tests/duplicate_op_test.py
+++ b/tensorflow/python/kernel_tests/duplicate_op_test.py
@@ -34,7 +34,7 @@ class DuplicateOpTest(test.TestCase):
self.assertEqual(len(duplicate.OP_LIST.op), 0)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(math_ops.add(1, 41).eval(), 42)
diff --git a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
index 5e8937ad2c..9557e30993 100644
--- a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
+++ b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
@@ -288,7 +288,7 @@ class DynamicPartitionTest(test.TestCase):
self.assertAllEqual([], partition_vals[i])
def testErrorIndexOutOfRange(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
data = constant_op.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11],
[12, 13, 14]])
indices = constant_op.constant([0, 2, 99, 2, 2])
@@ -298,7 +298,7 @@ class DynamicPartitionTest(test.TestCase):
sess.run(partitions)
def testScalarIndexOutOfRange(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
bad = 17
data = np.zeros(5)
partitions = data_flow_ops.dynamic_partition(data, bad, num_partitions=7)
@@ -306,7 +306,7 @@ class DynamicPartitionTest(test.TestCase):
sess.run(partitions)
def testHigherRankIndexOutOfRange(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape = (2, 3)
indices = array_ops.placeholder(shape=shape, dtype=np.int32)
data = np.zeros(shape + (5,))
@@ -334,7 +334,7 @@ class DynamicPartitionTest(test.TestCase):
inds += [13]*194 + [14]*194 + [15]*192
self.assertEqual(len(inds), x.shape[0])
partitioned = data_flow_ops.dynamic_partition(x, inds, 16)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
res = sess.run(partitioned)
self.assertEqual(res[-1].shape[0], 192)
diff --git a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
index 49b9569e2b..3a1036e52a 100644
--- a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
+++ b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
@@ -252,7 +252,7 @@ class ParallelDynamicStitchTest(DynamicStitchTestBase, test.TestCase):
# GPU version unit tests
def testScalarGPU(self):
- with self.test_session():
+ with self.cached_session():
indices = [constant_op.constant(0), constant_op.constant(1)]
data = [constant_op.constant(40.0), constant_op.constant(60.0)]
for step in -1, 1:
@@ -263,7 +263,7 @@ class ParallelDynamicStitchTest(DynamicStitchTestBase, test.TestCase):
self.assertEqual([2], stitched_t.get_shape().as_list())
def testHigherRankGPU(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
indices = [
constant_op.constant(6),
constant_op.constant([4, 1]),
diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py
index dcd435e1ff..40b8548cea 100644
--- a/tensorflow/python/kernel_tests/embedding_ops_test.py
+++ b/tensorflow/python/kernel_tests/embedding_ops_test.py
@@ -242,7 +242,7 @@ class EmbeddingLookupTest(test.TestCase):
# vector is going to be empty. The subsequent DivOp fails because of that.
# TODO(keveman): Disabling the test until the underlying problem is fixed.
def testSimpleSharded(self):
- with self.test_session():
+ with self.cached_session():
num_shards = 2
vocab_size = 4
p, params, feed_dict = _EmbeddingParams(num_shards, vocab_size)
@@ -258,7 +258,7 @@ class EmbeddingLookupTest(test.TestCase):
self.assertShapeEqual(np_result, embedding)
def testMaxNorm(self):
- with self.test_session():
+ with self.cached_session():
embeddings = constant_op.constant([[2.0]])
ids = constant_op.constant([0], dtype=dtypes.int32)
@@ -268,7 +268,7 @@ class EmbeddingLookupTest(test.TestCase):
self.assertAllEqual(embedding.eval(), [[1.0]])
def testMaxNormNontrivial(self):
- with self.test_session():
+ with self.cached_session():
embeddings = constant_op.constant([[2.0, 4.0], [3.0, 1.0]])
ids = constant_op.constant([0, 1], dtype=dtypes.int32)
@@ -281,7 +281,7 @@ class EmbeddingLookupTest(test.TestCase):
self.assertAllEqual(embedding.eval(), 2 * normalized.eval())
def testSimpleShardedPartitionedVariable(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_shards = 2
vocab_size = 4
p, p_variable, params, feed_dict = _EmbeddingParamsAsPartitionedVariable(
@@ -303,7 +303,7 @@ class EmbeddingLookupTest(test.TestCase):
self.assertShapeEqual(np_result, embedding)
def testSimpleShardedPartitionedResourceVariable(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_shards = 2
vocab_size = 4
p, p_variable, params, _ = _EmbeddingParamsAsPartitionedVariable(
@@ -326,7 +326,7 @@ class EmbeddingLookupTest(test.TestCase):
self.assertShapeEqual(np_result, embedding)
def testShardedModPartitioningInt32Ids(self):
- with self.test_session():
+ with self.cached_session():
num_shards = 5
vocab_size = 13
# Embedding dimensions is 10. The vocab_size x 10 embedding
@@ -348,7 +348,7 @@ class EmbeddingLookupTest(test.TestCase):
self.assertShapeEqual(np_result, embedding)
def testShardedModPartitioningInt64Ids(self):
- with self.test_session():
+ with self.cached_session():
num_shards = 5
vocab_size = 13
# Embedding dimensions is 10. The vocab_size x 10 embedding
@@ -370,7 +370,7 @@ class EmbeddingLookupTest(test.TestCase):
self.assertShapeEqual(np_result, embedding)
def testShardedDivPartitioningInt32Ids(self):
- with self.test_session():
+ with self.cached_session():
num_shards = 5
vocab_size = 13
# Embedding dimensions is 10. The vocab_size x 10 embedding
@@ -394,7 +394,7 @@ class EmbeddingLookupTest(test.TestCase):
self.assertShapeEqual(np_result, embedding)
def testShardedDivPartitioningInt32IdsPartitionedVariable(self):
- with self.test_session():
+ with self.cached_session():
num_shards = 5
vocab_size = 13
# Embedding dimensions is 10. The vocab_size x 10 embedding
@@ -419,7 +419,7 @@ class EmbeddingLookupTest(test.TestCase):
self.assertShapeEqual(np_result, embedding)
def testShardedDivPartitioningInt64Ids(self):
- with self.test_session():
+ with self.cached_session():
num_shards = 5
vocab_size = 13
# Embedding dimensions is 10. The vocab_size x 10 embedding
@@ -443,7 +443,7 @@ class EmbeddingLookupTest(test.TestCase):
self.assertShapeEqual(np_result, embedding)
def testShardedDivPartitioningUnknownParamShape(self):
- with self.test_session():
+ with self.cached_session():
num_shards = 5
vocab_size = 13
# Embedding dimensions is 10. The vocab_size x 10 embedding
@@ -475,7 +475,7 @@ class EmbeddingLookupTest(test.TestCase):
tf_logging.vlog(1, id_vals)
for ids_shape in [(10,), (2, 5)]:
for num_shards in [1, 3]:
- with self.test_session():
+ with self.cached_session():
ids = constant_op.constant(
id_vals, shape=ids_shape, dtype=dtypes.int32)
x, params, _ = _EmbeddingParams(num_shards, vocab_size, shape=[2])
@@ -494,7 +494,7 @@ class EmbeddingLookupTest(test.TestCase):
id_vals = list(np.random.randint(vocab_size, size=num_ids))
tf_logging.vlog(1, id_vals)
for num_shards in [1, 3]:
- with self.test_session():
+ with self.cached_session():
ids = constant_op.constant(id_vals, dtype=dtypes.int32)
x, params, _ = _EmbeddingParams(num_shards, vocab_size, shape=[2])
# This will force a conversion from IndexedSlices to Tensor.
@@ -528,7 +528,7 @@ class EmbeddingLookupTest(test.TestCase):
def testHigherRank(self):
np.random.seed(8)
- with self.test_session():
+ with self.cached_session():
for params_shape in (12,), (6, 3):
params = np.random.randn(*params_shape)
for ids_shape in (3, 2), (4, 3):
@@ -548,7 +548,7 @@ class EmbeddingLookupTest(test.TestCase):
def testHigherRankMaxNorm(self):
np.random.seed(8)
- with self.test_session():
+ with self.cached_session():
for params_shape in (12,), (6, 3), (6, 2, 3):
# Test embedding rank 0, 1, 2.
# Note: the first dimension must be a common multiple of procs below.
@@ -581,7 +581,7 @@ class EmbeddingLookupTest(test.TestCase):
# It always applies max_norm.
np.random.seed(8)
l2_norm = 2.
- with self.test_session():
+ with self.cached_session():
# Param values are in [l2_norm, l2_norm+1) so it will always clip.
params = np.random.rand(6, 3) + l2_norm
params_norm = l2_norm * params / np.sqrt(
@@ -667,7 +667,7 @@ class EmbeddingLookupSparseTest(test.TestCase):
[dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64],
[True, False]):
- with self.test_session():
+ with self.cached_session():
p, params, feed_dict = _EmbeddingParams(
num_shards, vocab_size, shape=param_shape, dtype=dtype)
embedding_sum = embedding_ops.embedding_lookup_sparse(
@@ -716,7 +716,7 @@ class EmbeddingLookupSparseTest(test.TestCase):
for num_shards, combiner, dtype, ignore_weights in itertools.product(
[1, 3], ["sum", "mean", "sqrtn"], [dtypes.float32,
dtypes.float64], [True, False]):
- with self.test_session():
+ with self.cached_session():
x, params, _ = _EmbeddingParams(
num_shards, vocab_size, shape=param_shape, dtype=dtype)
@@ -734,7 +734,7 @@ class EmbeddingLookupSparseTest(test.TestCase):
self.assertLess(err, 1e-5 if dtype == dtypes.float64 else 2e-3)
def testIncompatibleShapes(self):
- with self.test_session():
+ with self.cached_session():
x, _, _ = _EmbeddingParams(1, 10, dtype=dtypes.float32)
sp_ids = sparse_tensor.SparseTensor(
constant_op.constant([[0, 0], [0, 1], [1, 0]], dtypes.int64),
@@ -819,7 +819,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
return sparse_ids, sparse_weights
def test_safe_embedding_lookup_sparse_return_zero_vector(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, sparse_weights = self._ids_and_weights_2d()
@@ -832,7 +832,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
3.0, [0] * 4, [0] * 4, embedding_weights[0][2], [0] * 4])
def test_safe_embedding_lookup_sparse_return_special_vector(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, sparse_weights = self._ids_and_weights_2d()
@@ -846,7 +846,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
embedding_weights[0][2], embedding_weights[0][3]])
def test_safe_embedding_lookup_sparse_no_weights(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, _ = self._ids_and_weights_2d()
@@ -860,7 +860,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
embedding_weights[0][0] + embedding_weights[0][1]) / 2.0])
def test_safe_embedding_lookup_sparse_partitioned(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights(num_shards=3)
sparse_ids, _ = self._ids_and_weights_2d()
@@ -874,7 +874,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
(embedding_weights[0] + embedding_weights[1]) / 2.0])
def test_safe_embedding_lookup_sparse_partitioned_inconsistent_weights(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights(num_shards=3)
sparse_ids, sparse_weights = self._ids_and_weights_2d()
@@ -889,7 +889,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
embedding_weights, sparse_ids, sparse_weights)
def test_safe_embedding_lookup_sparse_3d_return_zero_vector(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, sparse_weights = self._ids_and_weights_3d()
@@ -902,7 +902,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
], [embedding_weights[0][2], [0] * 4, [0] * 4]])
def test_safe_embedding_lookup_sparse_3d_return_special_vector(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, sparse_weights = self._ids_and_weights_3d()
@@ -918,7 +918,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
]])
def test_safe_embedding_lookup_sparse_3d_no_weights(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, _ = self._ids_and_weights_3d()
@@ -934,7 +934,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
]])
def test_safe_embedding_lookup_sparse_3d_partitioned(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights(num_shards=3)
sparse_ids, _ = self._ids_and_weights_3d()
@@ -951,7 +951,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
def test_safe_embedding_lookup_sparse_3d_partitioned_inconsistent_weights(
self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights(num_shards=3)
sparse_ids, sparse_weights = self._ids_and_weights_3d()
@@ -1035,7 +1035,7 @@ class DynamicStitchOpTest(test.TestCase):
# We expect that the values are merged in order.
def testStitchOrder(self):
- with self.test_session():
+ with self.cached_session():
indices = []
np_values = []
values = []
diff --git a/tensorflow/python/kernel_tests/extract_image_patches_grad_test.py b/tensorflow/python/kernel_tests/extract_image_patches_grad_test.py
index e1f5a6b620..7d9d4e5175 100644
--- a/tensorflow/python/kernel_tests/extract_image_patches_grad_test.py
+++ b/tensorflow/python/kernel_tests/extract_image_patches_grad_test.py
@@ -83,7 +83,7 @@ class ExtractImagePatchesGradTest(test.TestCase):
random_seed = 42
random_seed_lib.set_random_seed(random_seed)
- with self.test_session():
+ with self.cached_session():
for test_case in self._TEST_CASES:
np.random.seed(random_seed)
in_shape = test_case['in_shape']
diff --git a/tensorflow/python/kernel_tests/fft_ops_test.py b/tensorflow/python/kernel_tests/fft_ops_test.py
index 629acedda5..f117934e4b 100644
--- a/tensorflow/python/kernel_tests/fft_ops_test.py
+++ b/tensorflow/python/kernel_tests/fft_ops_test.py
@@ -496,7 +496,7 @@ class RFFTOpsTest(BaseFFTOpsTest):
"Input dimension .* must have length of at least 6 but got: 5"):
x = np.zeros((5,) * rank).astype(np.float32)
fft_length = [6] * rank
- with self.test_session():
+ with self.cached_session():
rfft_fn(x, fft_length).eval()
with self.assertRaisesWithPredicateMatch(
@@ -504,7 +504,7 @@ class RFFTOpsTest(BaseFFTOpsTest):
"Input dimension .* must have length of at least .* but got: 3"):
x = np.zeros((3,) * rank).astype(np.complex64)
fft_length = [6] * rank
- with self.test_session():
+ with self.cached_session():
irfft_fn(x, fft_length).eval()
def testGrad_Simple(self):
diff --git a/tensorflow/python/kernel_tests/fifo_queue_test.py b/tensorflow/python/kernel_tests/fifo_queue_test.py
index 9e7b528338..a5f8f64e0c 100644
--- a/tensorflow/python/kernel_tests/fifo_queue_test.py
+++ b/tensorflow/python/kernel_tests/fifo_queue_test.py
@@ -99,19 +99,19 @@ class FIFOQueueTest(test.TestCase):
""", q.queue_ref.op.node_def)
def testEnqueue(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
enqueue_op = q.enqueue((10.0,))
enqueue_op.run()
def testEnqueueHalf(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float16)
enqueue_op = q.enqueue((10.0,))
enqueue_op.run()
def testEnqueueWithShape(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shapes=(3, 2))
enqueue_correct_op = q.enqueue(([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],))
enqueue_correct_op.run()
@@ -120,7 +120,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual(1, q.size().eval())
def testEnqueueManyWithShape(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(
10, [dtypes_lib.int32, dtypes_lib.int32], shapes=[(), (2,)])
q.enqueue_many([[1, 2, 3, 4], [[1, 1], [2, 2], [3, 3], [4, 4]]]).run()
@@ -143,7 +143,7 @@ class FIFOQueueTest(test.TestCase):
self.assertAllEqual(self.evaluate(q.dequeue()), 1)
def testEnqueueDictWithoutNames(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
with self.assertRaisesRegexp(ValueError, "must have names"):
q.enqueue({"a": 12.0})
@@ -151,7 +151,7 @@ class FIFOQueueTest(test.TestCase):
q.enqueue_many({"a": [12.0, 13.0]})
def testParallelEnqueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -177,7 +177,7 @@ class FIFOQueueTest(test.TestCase):
self.assertItemsEqual(elems, results)
def testParallelDequeue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -201,7 +201,7 @@ class FIFOQueueTest(test.TestCase):
self.assertItemsEqual(elems, results)
def testDequeue(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -215,7 +215,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual([elems[i]], vals)
def testDequeueHalf(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float16)
elems = [10.0, 20.0, 30.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -229,7 +229,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual([elems[i]], vals)
def testEnqueueAndBlockingDequeue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(3, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -259,7 +259,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual([elem], result)
def testMultiEnqueueAndDequeue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, (dtypes_lib.int32, dtypes_lib.float32))
elems = [(5, 10.0), (10, 20.0), (15, 30.0)]
enqueue_ops = [q.enqueue((x, y)) for x, y in elems]
@@ -275,12 +275,12 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual([y], y_val)
def testQueueSizeEmpty(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
self.assertEqual([0], q.size().eval())
def testQueueSizeAfterEnqueueAndDequeue(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
enqueue_op = q.enqueue((10.0,))
dequeued_t = q.dequeue()
@@ -293,7 +293,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual(0, size.eval())
def testEnqueueMany(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -306,7 +306,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual([elems[i % 4]], vals)
def testEmptyEnqueueMany(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
empty_t = constant_op.constant(
[], dtype=dtypes_lib.float32, shape=[0, 2, 3])
@@ -318,7 +318,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual([0], size_t.eval())
def testEmptyDequeueMany(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shapes=())
enqueue_op = q.enqueue((10.0,))
dequeued_t = q.dequeue_many(0)
@@ -328,7 +328,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual([], dequeued_t.eval().tolist())
def testEmptyDequeueUpTo(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shapes=())
enqueue_op = q.enqueue((10.0,))
dequeued_t = q.dequeue_up_to(0)
@@ -338,14 +338,14 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual([], dequeued_t.eval().tolist())
def testEmptyDequeueManyWithNoShape(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
# Expect the operation to fail due to the shape not being constrained.
with self.assertRaisesOpError("specified shapes"):
q.dequeue_many(0).eval()
def testMultiEnqueueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, (dtypes_lib.float32, dtypes_lib.int32))
float_elems = [10.0, 20.0, 30.0, 40.0]
int_elems = [[1, 2], [3, 4], [5, 6], [7, 8]]
@@ -361,7 +361,7 @@ class FIFOQueueTest(test.TestCase):
self.assertAllEqual(int_elems[i % 4], int_val)
def testDequeueMany(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, ())
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
enqueue_op = q.enqueue_many((elems,))
@@ -373,7 +373,7 @@ class FIFOQueueTest(test.TestCase):
self.assertAllEqual(elems[4:8], dequeued_t.eval())
def testDequeueUpToNoBlocking(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, ())
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
enqueue_op = q.enqueue_many((elems,))
@@ -385,7 +385,7 @@ class FIFOQueueTest(test.TestCase):
self.assertAllEqual(elems[4:8], dequeued_t.eval())
def testMultiDequeueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(
10, (dtypes_lib.float32, dtypes_lib.int32), shapes=((), (2,)))
float_elems = [
@@ -416,7 +416,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual(int_val.shape, dequeued_single_t[1].get_shape())
def testMultiDequeueUpToNoBlocking(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(
10, (dtypes_lib.float32, dtypes_lib.int32), shapes=((), (2,)))
float_elems = [
@@ -440,7 +440,7 @@ class FIFOQueueTest(test.TestCase):
self.assertAllEqual(int_elems[4:8], int_val)
def testHighDimension(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.int32, (4, 4, 4, 4))
elems = np.array([[[[[x] * 4] * 4] * 4] * 4 for x in range(10)], np.int32)
enqueue_op = q.enqueue_many((elems,))
@@ -494,7 +494,7 @@ class FIFOQueueTest(test.TestCase):
array_ops.placeholder(dtypes_lib.int32)))
def testEnqueueWrongShapeAtRuntime(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, (dtypes_lib.int32, dtypes_lib.int32), (
(2, 2), (3, 3)))
elems_ok = np.array([1] * 4).reshape((2, 2)).astype(np.int32)
@@ -506,7 +506,7 @@ class FIFOQueueTest(test.TestCase):
feed_dict={elems_bad: np.array([1] * 12).reshape((3, 4))})
def testEnqueueDequeueManyWrongShape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, (dtypes_lib.int32, dtypes_lib.int32), (
(2, 2), (3, 3)))
elems_ok = np.array([1] * 8).reshape((2, 2, 2)).astype(np.int32)
@@ -521,7 +521,7 @@ class FIFOQueueTest(test.TestCase):
dequeued_t.eval()
def testParallelEnqueueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(1000, dtypes_lib.float32, shapes=())
elems = [10.0 * x for x in range(100)]
enqueue_op = q.enqueue_many((elems,))
@@ -540,7 +540,7 @@ class FIFOQueueTest(test.TestCase):
self.assertItemsEqual(dequeued_t.eval(), elems * 10)
def testParallelDequeueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(1000, dtypes_lib.float32, shapes=())
elems = [10.0 * x for x in range(1000)]
enqueue_op = q.enqueue_many((elems,))
@@ -562,7 +562,7 @@ class FIFOQueueTest(test.TestCase):
self.assertItemsEqual(elems, dequeued_elems)
def testParallelDequeueUpTo(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(1000, dtypes_lib.float32, shapes=())
elems = [10.0 * x for x in range(1000)]
enqueue_op = q.enqueue_many((elems,))
@@ -586,7 +586,7 @@ class FIFOQueueTest(test.TestCase):
self.assertItemsEqual(elems, dequeued_elems)
def testParallelEnqueueAndDequeue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(50, dtypes_lib.float32, shapes=())
initial_elements = [10.0] * 49
q.enqueue_many((initial_elements,)).run()
@@ -619,7 +619,7 @@ class FIFOQueueTest(test.TestCase):
self.assertTrue(elem in (10.0, 20.0))
def testMixtureOfEnqueueAndEnqueueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, dtypes_lib.int32, shapes=())
enqueue_placeholder = array_ops.placeholder(dtypes_lib.int32, shape=())
enqueue_op = q.enqueue((enqueue_placeholder,))
@@ -655,7 +655,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual(0, q.size().eval())
def testMixtureOfDequeueAndDequeueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, dtypes_lib.int32, shapes=())
enqueue_op = q.enqueue_many((np.arange(250, dtype=np.int32),))
dequeued_t = q.dequeue()
@@ -689,7 +689,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual(0, q.size().eval())
def testBlockingDequeueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, ())
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -716,7 +716,7 @@ class FIFOQueueTest(test.TestCase):
self.assertAllEqual(elems, dequeued_elems)
def testBlockingDequeueUpTo(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, ())
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -743,7 +743,7 @@ class FIFOQueueTest(test.TestCase):
self.assertAllEqual(elems, dequeued_elems)
def testDequeueManyWithTensorParameter(self):
- with self.test_session():
+ with self.cached_session():
# Define a first queue that contains integer counts.
dequeue_counts = [random.randint(1, 10) for _ in range(100)]
count_q = data_flow_ops.FIFOQueue(100, dtypes_lib.int32, ())
@@ -768,7 +768,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual(elems, dequeued_elems)
def testDequeueFromClosedQueue(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -786,7 +786,7 @@ class FIFOQueueTest(test.TestCase):
dequeued_t.eval()
def testBlockingDequeueFromClosedQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -812,7 +812,7 @@ class FIFOQueueTest(test.TestCase):
dequeue_thread.join()
def testBlockingDequeueFromClosedEmptyQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
close_op = q.close()
dequeued_t = q.dequeue()
@@ -832,7 +832,7 @@ class FIFOQueueTest(test.TestCase):
dequeue_thread.join()
def testBlockingDequeueManyFromClosedQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, ())
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -857,7 +857,7 @@ class FIFOQueueTest(test.TestCase):
dequeue_thread.join()
def testBlockingDequeueManyButNotAllFromClosedQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, ())
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -882,7 +882,7 @@ class FIFOQueueTest(test.TestCase):
dequeue_thread.join()
def testDequeueUpToFromClosedQueueReturnsRemainder(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, ())
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -904,7 +904,7 @@ class FIFOQueueTest(test.TestCase):
dequeue_thread.join()
def testEnqueueManyLargerThanCapacityWithConcurrentDequeueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(4, dtypes_lib.float32, ())
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -941,7 +941,7 @@ class FIFOQueueTest(test.TestCase):
close_thread.join()
def testClosedBlockingDequeueManyRestoresPartialBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(4, (dtypes_lib.float32, dtypes_lib.float32), (
(), ()))
elems_a = [1.0, 2.0, 3.0]
@@ -974,7 +974,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual(0, q.size().eval())
def testBlockingDequeueManyFromClosedEmptyQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, ())
close_op = q.close()
dequeued_t = q.dequeue_many(4)
@@ -994,7 +994,7 @@ class FIFOQueueTest(test.TestCase):
dequeue_thread.join()
def testBlockingDequeueUpToFromClosedEmptyQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, ())
close_op = q.close()
dequeued_t = q.dequeue_up_to(4)
@@ -1014,7 +1014,7 @@ class FIFOQueueTest(test.TestCase):
dequeue_thread.join()
def testEnqueueToClosedQueue(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
enqueue_op = q.enqueue((10.0,))
close_op = q.close()
@@ -1027,7 +1027,7 @@ class FIFOQueueTest(test.TestCase):
enqueue_op.run()
def testEnqueueManyToClosedQueue(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -1041,7 +1041,7 @@ class FIFOQueueTest(test.TestCase):
enqueue_op.run()
def testBlockingEnqueueToFullQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(4, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -1064,7 +1064,7 @@ class FIFOQueueTest(test.TestCase):
thread.join()
def testBlockingEnqueueManyToFullQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(4, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -1091,7 +1091,7 @@ class FIFOQueueTest(test.TestCase):
thread.join()
def testBlockingEnqueueBeforeClose(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(4, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -1128,7 +1128,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual(0, q.size().eval())
def testBlockingEnqueueManyBeforeClose(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(4, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0]
enqueue_op = q.enqueue_many((elems,))
@@ -1161,7 +1161,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual(elem, dequeued_t.eval())
def testDoesNotLoseValue(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(1, dtypes_lib.float32)
enqueue_op = q.enqueue((10.0,))
size_t = q.size()
@@ -1171,7 +1171,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual(size_t.eval(), [1])
def testSharedQueueSameSession(self):
- with self.test_session():
+ with self.cached_session():
q1 = data_flow_ops.FIFOQueue(
1, dtypes_lib.float32, shared_name="shared_queue")
q1.enqueue((10.0,)).run()
@@ -1201,7 +1201,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual(q2_size_t.eval(), [0])
def testIncompatibleSharedQueueErrors(self):
- with self.test_session():
+ with self.cached_session():
q_a_1 = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shared_name="q_a")
q_a_2 = data_flow_ops.FIFOQueue(15, dtypes_lib.float32, shared_name="q_a")
q_a_1.queue_ref.op.run()
@@ -1244,7 +1244,7 @@ class FIFOQueueTest(test.TestCase):
q_f_2.queue_ref.op.run()
def testSelectQueue(self):
- with self.test_session():
+ with self.cached_session():
num_queues = 10
qlist = list()
for _ in xrange(num_queues):
@@ -1257,7 +1257,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual(q.dequeue().eval(), 10.0)
def testSelectQueueOutOfRange(self):
- with self.test_session():
+ with self.cached_session():
q1 = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
q2 = data_flow_ops.FIFOQueue(15, dtypes_lib.float32)
enq_q = data_flow_ops.FIFOQueue.from_list(3, [q1, q2])
@@ -1281,7 +1281,7 @@ class FIFOQueueTest(test.TestCase):
sess.run(enqueue_many_op)
def testResetOfBlockingOperation(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q_empty = data_flow_ops.FIFOQueue(5, dtypes_lib.float32, ())
dequeue_op = q_empty.dequeue()
dequeue_many_op = q_empty.dequeue_many(1)
@@ -1309,7 +1309,7 @@ class FIFOQueueTest(test.TestCase):
t.join()
def testBigEnqueueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(5, dtypes_lib.int32, ((),))
elem = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
enq = q.enqueue_many((elem,))
@@ -1354,7 +1354,7 @@ class FIFOQueueTest(test.TestCase):
self.assertAllEqual(elem, results)
def testBigDequeueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(2, dtypes_lib.int32, ((),))
elem = np.arange(4, dtype=np.int32)
enq_list = [q.enqueue((e,)) for e in elem]
@@ -1380,7 +1380,7 @@ class FIFOQueueTest(test.TestCase):
self.assertAllEqual(elem, results)
def testDtypes(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dtypes = [
dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int32,
dtypes_lib.uint8, dtypes_lib.int16, dtypes_lib.int8, dtypes_lib.int64,
@@ -1411,7 +1411,7 @@ class FIFOQueueTest(test.TestCase):
self.assertAllEqual(input_elem, output_elem)
def testDequeueEnqueueFail(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
a = q.dequeue()
b = control_flow_ops.Assert(False, ["Before enqueue"])
@@ -1474,7 +1474,7 @@ class FIFOQueueDictTest(test.TestCase):
self.assertEqual(["i", "f"], q.names)
def testEnqueueDequeueOneComponent(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(
10, dtypes_lib.float32, shapes=((),), names="f")
# Verify that enqueue() checks that when using names we must enqueue a
@@ -1519,7 +1519,7 @@ class FIFOQueueDictTest(test.TestCase):
self.assertEqual([40.0, 50.0], list(f))
def testEnqueueDequeueMultipleComponent(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(
10, (dtypes_lib.float32, dtypes_lib.int32, dtypes_lib.string),
shapes=((), (), ()),
@@ -1600,7 +1600,7 @@ class FIFOQueueWithTimeoutTest(test.TestCase):
sess.run(dequeued_t)
def testReusableAfterTimeout(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
dequeued_t = q.dequeue()
enqueue_op = q.enqueue(37)
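Note on the pattern above: nearly every hunk in these kernel-test files swaps the deprecated test_session() for cached_session(), which caches and reuses a single session within a test instead of building a fresh one on every call. A minimal sketch of the pattern, with illustrative names (ExampleQueueTest is not part of this change), assuming the TF 1.x graph-mode test API:

import tensorflow as tf


class ExampleQueueTest(tf.test.TestCase):

  def testEnqueueThenDequeue(self):
    # cached_session() installs a default session that is reused if the
    # test asks for a session again, so repeated with-blocks stay cheap.
    with self.cached_session() as sess:
      q = tf.FIFOQueue(10, tf.float32, shapes=())
      enqueue_op = q.enqueue_many(([10.0, 20.0],))
      dequeued_t = q.dequeue()
      enqueue_op.run()
      self.assertEqual(10.0, sess.run(dequeued_t))


if __name__ == "__main__":
  tf.test.main()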
diff --git a/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py b/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py
index faac7d8365..f89d2062f1 100644
--- a/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py
+++ b/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py
@@ -127,7 +127,7 @@ class FractionalAvgTest(test.TestCase):
Returns:
None
"""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
p, r, c = nn_ops.fractional_avg_pool(
input_tensor,
pooling_ratio,
@@ -160,7 +160,7 @@ class FractionalAvgTest(test.TestCase):
overlapping))
rand_mat = self._PRNG.randint(10, size=tensor_shape)
pooling_ratio = [1, math.sqrt(2), math.sqrt(2), 1]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
p, r, c = nn_ops.fractional_avg_pool(
rand_mat.astype(np.float32),
pooling_ratio,
@@ -234,7 +234,7 @@ class FractionalAvgTest(test.TestCase):
[4, 4, 5, 9, 7, 2]
])
# pyformat: enable
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Since deterministic = True, seed and seed2 are fixed. Therefore r, and c
# are the same each time. We can have an expected result precomputed.
# r = [0, 2, 4, 6]
@@ -314,7 +314,7 @@ class FractionalAvgTest(test.TestCase):
def testDifferentInputTensorShape(self):
"""Runs the operation in one session with different input tensor shapes."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
input_holder = array_ops.placeholder(dtypes.float32,
[None, None, None, 3])
pooling_ratio = [1, 1.5, 1.5, 1]
@@ -389,7 +389,7 @@ class FractionalAvgPoolGradTest(test.TestCase):
num_cols = col_window_size * 7
for num_channels in [1, 2]:
input_shape = (num_batches, num_rows, num_cols, num_channels)
- with self.test_session() as _:
+ with self.cached_session() as _:
input_tensor = constant_op.constant(
self._GenerateRandomInputTensor(input_shape).astype(
np.float32))
@@ -428,7 +428,7 @@ class FractionalAvgPoolGradTest(test.TestCase):
num_cols = (col_window_size - 1) * 7 + 1
for num_channels in [1, 2]:
input_shape = (num_batches, num_rows, num_cols, num_channels)
- with self.test_session() as _:
+ with self.cached_session() as _:
input_tensor = constant_op.constant(
self._GenerateRandomInputTensor(input_shape).astype(
np.float32))
@@ -468,7 +468,7 @@ class FractionalAvgPoolGradTest(test.TestCase):
for pseudo_random in True, False:
for overlapping in True, False:
- with self.test_session() as _:
+ with self.cached_session() as _:
input_tensor = constant_op.constant(input_data, shape=input_shape)
output_tensor, unused_a, unused_b = nn_ops.fractional_avg_pool(
input_tensor,
@@ -501,7 +501,7 @@ class FractionalAvgPoolGradTest(test.TestCase):
for num_channels in [1, 3]:
input_shape = (num_batches, num_rows, num_cols, num_channels)
input_data = self._GenerateRandomInputTensor(input_shape)
- with self.test_session() as _:
+ with self.cached_session() as _:
input_tensor = constant_op.constant(input_data, shape=input_shape)
output_tensor, unused_a, unused_b = nn_ops.fractional_avg_pool(
input_tensor,
@@ -532,7 +532,7 @@ class FractionalAvgPoolGradTest(test.TestCase):
overlapping = True
pseudo_random = False
- with self.test_session() as _:
+ with self.cached_session() as _:
input_tensor = constant_op.constant(input_data, shape=input_shape)
output_tensor, unused_a, unused_b = nn_ops.fractional_avg_pool(
input_tensor,
diff --git a/tensorflow/python/kernel_tests/fractional_max_pool_op_test.py b/tensorflow/python/kernel_tests/fractional_max_pool_op_test.py
index 6477c9ebc4..9b94ca8554 100644
--- a/tensorflow/python/kernel_tests/fractional_max_pool_op_test.py
+++ b/tensorflow/python/kernel_tests/fractional_max_pool_op_test.py
@@ -127,7 +127,7 @@ class FractionalMaxPoolTest(test.TestCase):
Returns:
None
"""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
p, r, c = nn_ops.fractional_max_pool(
input_tensor,
pooling_ratio,
@@ -160,7 +160,7 @@ class FractionalMaxPoolTest(test.TestCase):
overlapping))
rand_mat = self._PRNG.randint(10, size=tensor_shape)
pooling_ratio = [1, math.sqrt(2), math.sqrt(2), 1]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
p, r, c = nn_ops.fractional_max_pool(
rand_mat,
pooling_ratio,
@@ -285,7 +285,7 @@ class FractionalMaxPoolTest(test.TestCase):
def testDifferentInputTensorShape(self):
"""Runs the operation in one session with different input tensor shapes."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
input_holder = array_ops.placeholder(dtypes.float32,
[None, None, None, 3])
pooling_ratio = [1, 1.5, 1.5, 1]
@@ -374,7 +374,7 @@ class FractionalMaxPoolGradTest(test.TestCase):
num_cols = col_window_size * 7
for num_channels in [1, 2]:
input_shape = (num_batches, num_rows, num_cols, num_channels)
- with self.test_session() as _:
+ with self.cached_session() as _:
input_tensor = constant_op.constant(
self._GenerateUniqueRandomInputTensor(input_shape))
window_size = [1, row_window_size, col_window_size, 1]
@@ -409,7 +409,7 @@ class FractionalMaxPoolGradTest(test.TestCase):
num_cols = (col_window_size - 1) * 7 + 1
for num_channels in [1, 2]:
input_shape = (num_batches, num_rows, num_cols, num_channels)
- with self.test_session() as _:
+ with self.cached_session() as _:
input_tensor = constant_op.constant(
self._GenerateUniqueRandomInputTensor(input_shape))
window_size = [1, row_window_size, col_window_size, 1]
@@ -447,7 +447,7 @@ class FractionalMaxPoolGradTest(test.TestCase):
for pseudo_random in True, False:
for overlapping in True, False:
- with self.test_session() as _:
+ with self.cached_session() as _:
input_tensor = constant_op.constant(input_data, shape=input_shape)
output_tensor, unused_a, unused_b = nn_ops.fractional_max_pool(
input_tensor,
@@ -482,7 +482,7 @@ class FractionalMaxPoolGradTest(test.TestCase):
input_data = self._GenerateUniqueRandomInputTensor(input_shape)
# Add some randomness to make input_data not so 'integer'
input_data += self._PRNG.random_sample(input_shape)
- with self.test_session() as _:
+ with self.cached_session() as _:
input_tensor = constant_op.constant(input_data, shape=input_shape)
output_tensor, unused_a, unused_b = nn_ops.fractional_max_pool(
input_tensor,
@@ -515,7 +515,7 @@ class FractionalMaxPoolGradTest(test.TestCase):
overlapping = True
pseudo_random = False
- with self.test_session() as _:
+ with self.cached_session() as _:
input_tensor = constant_op.constant(input_data, shape=input_shape)
output_tensor, unused_a, unused_b = nn_ops.fractional_max_pool(
input_tensor,
@@ -579,7 +579,7 @@ class FractionalMaxPoolGradTest(test.TestCase):
0.0, 0.0, 0.0, 0.0,
6.0, 0.0, 21.0, 0.0],
input_size) # pyformat: disable
- with self.test_session() as _:
+ with self.cached_session() as _:
# Test when overlapping is False
input_tensor = constant_op.constant(input_data, shape=input_size)
output_tensor = constant_op.constant(
diff --git a/tensorflow/python/kernel_tests/gather_op_test.py b/tensorflow/python/kernel_tests/gather_op_test.py
index 033fa95935..85bf969068 100644
--- a/tensorflow/python/kernel_tests/gather_op_test.py
+++ b/tensorflow/python/kernel_tests/gather_op_test.py
@@ -147,7 +147,7 @@ class GatherTest(test.TestCase):
def testString(self):
params = np.array([[b"asdf", b"zxcv"], [b"qwer", b"uiop"]])
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([b"qwer", b"uiop"],
array_ops.gather(params, 1, axis=0).eval())
self.assertAllEqual([b"asdf", b"qwer"],
@@ -157,7 +157,7 @@ class GatherTest(test.TestCase):
for unsigned_type in (dtypes.uint32, dtypes.uint64):
params = self._buildParams(
np.array([[1, 2, 3], [7, 8, 9]]), unsigned_type)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([7, 8, 9],
array_ops.gather(params, 1, axis=0).eval())
self.assertAllEqual([1, 7], array_ops.gather(params, 0, axis=1).eval())
diff --git a/tensorflow/python/kernel_tests/gradient_correctness_test.py b/tensorflow/python/kernel_tests/gradient_correctness_test.py
index e93c6235f7..291a69ebac 100644
--- a/tensorflow/python/kernel_tests/gradient_correctness_test.py
+++ b/tensorflow/python/kernel_tests/gradient_correctness_test.py
@@ -30,7 +30,7 @@ from tensorflow.python.platform import test
class GradientCorrectnessTest(test.TestCase):
def testMultipleOutputChainedGradients(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = constant_op.constant(1.0, dtype=dtypes.float32)
yexp = math_ops.exp(x)
yexplog = math_ops.log(yexp)
@@ -43,13 +43,13 @@ class GradientCorrectnessTest(test.TestCase):
def testIdentityGradient(self):
x = constant_op.constant(3.)
dx_dx, = gradients_impl.gradients(x, x)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllClose(1., sess.run(dx_dx))
def testIntegerIdentityGradient(self):
x = constant_op.constant(3)
dx_dx, = gradients_impl.gradients(x, x)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllClose(1, sess.run(dx_dx))
def testGradientWithIntegerPath(self):
@@ -57,7 +57,7 @@ class GradientCorrectnessTest(test.TestCase):
k = math_ops.to_float(math_ops.to_int32(x))
y = x * k
dy_dx, = gradients_impl.gradients(y, x)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllClose([3., 4.], sess.run(dy_dx))
def testNoIntegerGradient1(self):
diff --git a/tensorflow/python/kernel_tests/identity_n_op_py_test.py b/tensorflow/python/kernel_tests/identity_n_op_py_test.py
index 408b173981..518733cd8e 100644
--- a/tensorflow/python/kernel_tests/identity_n_op_py_test.py
+++ b/tensorflow/python/kernel_tests/identity_n_op_py_test.py
@@ -28,7 +28,7 @@ from tensorflow.python.platform import test
class IdentityNOpTest(test.TestCase):
def testInt32String_6(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
[value0, value1] = sess.run(
array_ops.identity_n([[1, 2, 3, 4, 5, 6],
[b"a", b"b", b"C", b"d", b"E", b"f", b"g"]]))
@@ -37,7 +37,7 @@ class IdentityNOpTest(test.TestCase):
np.array([b"a", b"b", b"C", b"d", b"E", b"f", b"g"]), value1)
def testInt32_shapes(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp0 = constant_op.constant([10, 20, 30, 40, 50, 60], shape=[2, 3])
inp1 = constant_op.constant([11, 21, 31, 41, 51, 61], shape=[3, 2])
inp2 = constant_op.constant(
@@ -52,12 +52,12 @@ class IdentityNOpTest(test.TestCase):
def testString(self):
source = [b"A", b"b", b"C", b"d", b"E", b"f"]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
[value] = sess.run(array_ops.identity_n([source]))
self.assertAllEqual(source, value)
def testIdentityShape(self):
- with self.test_session():
+ with self.cached_session():
shape = [2, 3]
array_2x3 = [[1, 2, 3], [6, 5, 4]]
tensor = constant_op.constant(array_2x3)
diff --git a/tensorflow/python/kernel_tests/identity_op_py_test.py b/tensorflow/python/kernel_tests/identity_op_py_test.py
index 49fb76d5b4..37f9f716f8 100644
--- a/tensorflow/python/kernel_tests/identity_op_py_test.py
+++ b/tensorflow/python/kernel_tests/identity_op_py_test.py
@@ -31,24 +31,24 @@ from tensorflow.python.platform import test
class IdentityOpTest(test.TestCase):
def testInt32_6(self):
- with self.test_session():
+ with self.cached_session():
value = array_ops.identity([1, 2, 3, 4, 5, 6]).eval()
self.assertAllEqual(np.array([1, 2, 3, 4, 5, 6]), value)
def testInt32_2_3(self):
- with self.test_session():
+ with self.cached_session():
inp = constant_op.constant([10, 20, 30, 40, 50, 60], shape=[2, 3])
value = array_ops.identity(inp).eval()
self.assertAllEqual(np.array([[10, 20, 30], [40, 50, 60]]), value)
def testString(self):
source = [b"A", b"b", b"C", b"d", b"E", b"f"]
- with self.test_session():
+ with self.cached_session():
value = array_ops.identity(source).eval()
self.assertAllEqual(source, value)
def testIdentityShape(self):
- with self.test_session():
+ with self.cached_session():
shape = [2, 3]
array_2x3 = [[1, 2, 3], [6, 5, 4]]
tensor = constant_op.constant(array_2x3)
@@ -59,7 +59,7 @@ class IdentityOpTest(test.TestCase):
array_ops.identity(np.array(array_2x3)).get_shape())
def testRefIdentityShape(self):
- with self.test_session():
+ with self.cached_session():
shape = [2, 3]
tensor = variables.Variable(
constant_op.constant(
diff --git a/tensorflow/python/kernel_tests/in_topk_op_test.py b/tensorflow/python/kernel_tests/in_topk_op_test.py
index fafeea8ec0..6fdb497bc6 100644
--- a/tensorflow/python/kernel_tests/in_topk_op_test.py
+++ b/tensorflow/python/kernel_tests/in_topk_op_test.py
@@ -30,7 +30,7 @@ class InTopKTest(test.TestCase):
def _validateInTopK(self, predictions, target, k, expected):
np_ans = np.array(expected)
- with self.test_session():
+ with self.cached_session():
precision = nn_ops.in_top_k(predictions, target, k)
out = precision.eval()
self.assertAllClose(np_ans, out)
@@ -65,7 +65,7 @@ class InTopKTest(test.TestCase):
def testBadTarget(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
target = [0, 80000]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"target.*out of range"):
nn_ops.in_top_k(predictions, target, 2).eval()
@@ -75,7 +75,7 @@ class InTopKTest(test.TestCase):
target = [0, 2]
k = constant_op.constant(3)
np_ans = np.array([False, True])
- with self.test_session():
+ with self.cached_session():
precision = nn_ops.in_top_k(predictions, target, k)
out = precision.eval()
self.assertAllClose(np_ans, out)
diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py
index f6097ad489..79ce965242 100644
--- a/tensorflow/python/kernel_tests/init_ops_test.py
+++ b/tensorflow/python/kernel_tests/init_ops_test.py
@@ -343,7 +343,7 @@ class UniformUnitScalingInitializationTest(test.TestCase):
def testZeroSize(self):
shape = [0, 2]
- with self.test_session():
+ with self.cached_session():
x = variable_scope.get_variable(
"x",
shape=shape,
diff --git a/tensorflow/python/kernel_tests/inplace_ops_test.py b/tensorflow/python/kernel_tests/inplace_ops_test.py
index 6e894365af..90759c23ae 100644
--- a/tensorflow/python/kernel_tests/inplace_ops_test.py
+++ b/tensorflow/python/kernel_tests/inplace_ops_test.py
@@ -153,7 +153,7 @@ class InplaceOpsTest(test_util.TensorFlowTestCase):
self.assertAllClose(vy, vz)
def testError(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"must be a vector"):
_ = inplace_ops.inplace_update([[1.]], [[0]], [[10]]).eval()
diff --git a/tensorflow/python/kernel_tests/io_ops_test.py b/tensorflow/python/kernel_tests/io_ops_test.py
index 61944f7e31..afa24195cb 100644
--- a/tensorflow/python/kernel_tests/io_ops_test.py
+++ b/tensorflow/python/kernel_tests/io_ops_test.py
@@ -37,7 +37,7 @@ class IoOpsTest(test.TestCase):
with tempfile.NamedTemporaryFile(
prefix='ReadFileTest', dir=self.get_temp_dir(), delete=False) as temp:
temp.write(contents)
- with self.test_session():
+ with self.cached_session():
read = io_ops.read_file(temp.name)
self.assertEqual([], read.get_shape())
self.assertEqual(read.eval(), contents)
@@ -51,7 +51,7 @@ class IoOpsTest(test.TestCase):
prefix='WriteFileTest', dir=self.get_temp_dir(),
delete=False) as temp:
pass
- with self.test_session() as sess:
+ with self.cached_session() as sess:
w = io_ops.write_file(temp.name, contents)
sess.run(w)
with open(temp.name, 'rb') as f:
@@ -65,7 +65,7 @@ class IoOpsTest(test.TestCase):
contents = compat.as_bytes(contents)
subdir = os.path.join(self.get_temp_dir(), 'subdir1')
filepath = os.path.join(subdir, 'subdir2', 'filename')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
w = io_ops.write_file(filepath, contents)
sess.run(w)
with open(filepath, 'rb') as f:
@@ -88,7 +88,7 @@ class IoOpsTest(test.TestCase):
prefix=c, dir=self.get_temp_dir(), delete=True) for c in cases
]
- with self.test_session():
+ with self.cached_session():
# Test exact match without wildcards.
for f in files:
self.assertEqual(
diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD
index f4ec3e3996..be2e31cb5a 100644
--- a/tensorflow/python/kernel_tests/linalg/BUILD
+++ b/tensorflow/python/kernel_tests/linalg/BUILD
@@ -25,6 +25,22 @@ cuda_py_test(
)
cuda_py_test(
+ name = "linear_operator_addition_test",
+ size = "small",
+ srcs = ["linear_operator_addition_test.py"],
+ additional_deps = [
+ "//tensorflow/python/ops/linalg",
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cuda_py_test(
name = "linear_operator_block_diag_test",
size = "medium",
srcs = ["linear_operator_block_diag_test.py"],
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_addition_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_addition_test.py
new file mode 100644
index 0000000000..7c79fedf65
--- /dev/null
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_addition_test.py
@@ -0,0 +1,412 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import random_seed
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops.linalg import linalg as linalg_lib
+from tensorflow.python.ops.linalg import linear_operator_addition
+from tensorflow.python.platform import test
+
+linalg = linalg_lib
+random_seed.set_random_seed(23)
+rng = np.random.RandomState(0)
+
+add_operators = linear_operator_addition.add_operators
+
+
+# pylint: disable=unused-argument
+class _BadAdder(linear_operator_addition._Adder):
+ """Adder that will fail if used."""
+
+ def can_add(self, op1, op2):
+ raise AssertionError("BadAdder.can_add called!")
+
+ def _add(self, op1, op2, operator_name, hints):
+ raise AssertionError("This line should not be reached")
+
+
+# pylint: enable=unused-argument
+
+
+class LinearOperatorAdditionCorrectnessTest(test.TestCase):
+ """Tests correctness of addition with combinations of a few Adders.
+
+ Tests here are done with the _DEFAULT_ADDITION_TIERS, which means
+ add_operators should reduce all of the operators to a single operator.
+
+ This shows that we are able to correctly combine adders using the tiered
+ system. All Adders should be tested separately, and there is no need to test
+ every Adder within this class.
+ """
+
+ def test_one_operator_is_returned_unchanged(self):
+ op_a = linalg.LinearOperatorDiag([1., 1.])
+ op_sum = add_operators([op_a])
+ self.assertEqual(1, len(op_sum))
+ self.assertIs(op_sum[0], op_a)
+
+ def test_at_least_one_operators_required(self):
+ with self.assertRaisesRegexp(ValueError, "must contain at least one"):
+ add_operators([])
+
+ def test_attempting_to_add_numbers_raises(self):
+ with self.assertRaisesRegexp(TypeError, "contain only LinearOperator"):
+ add_operators([1, 2])
+
+ def test_two_diag_operators(self):
+ op_a = linalg.LinearOperatorDiag(
+ [1., 1.], is_positive_definite=True, name="A")
+ op_b = linalg.LinearOperatorDiag(
+ [2., 2.], is_positive_definite=True, name="B")
+ with self.test_session():
+ op_sum = add_operators([op_a, op_b])
+ self.assertEqual(1, len(op_sum))
+ op = op_sum[0]
+ self.assertIsInstance(op, linalg_lib.LinearOperatorDiag)
+ self.assertAllClose([[3., 0.], [0., 3.]], op.to_dense().eval())
+ # Adding positive definite operators produces positive def.
+ self.assertTrue(op.is_positive_definite)
+ # Real diagonal ==> self-adjoint.
+ self.assertTrue(op.is_self_adjoint)
+ # Positive definite ==> non-singular
+ self.assertTrue(op.is_non_singular)
+ # Enforce particular name for this simple case
+ self.assertEqual("Add/B__A/", op.name)
+
+ def test_three_diag_operators(self):
+ op1 = linalg.LinearOperatorDiag(
+ [1., 1.], is_positive_definite=True, name="op1")
+ op2 = linalg.LinearOperatorDiag(
+ [2., 2.], is_positive_definite=True, name="op2")
+ op3 = linalg.LinearOperatorDiag(
+ [3., 3.], is_positive_definite=True, name="op3")
+ with self.test_session():
+ op_sum = add_operators([op1, op2, op3])
+ self.assertEqual(1, len(op_sum))
+ op = op_sum[0]
+ self.assertIsInstance(op, linalg_lib.LinearOperatorDiag)
+ self.assertAllClose([[6., 0.], [0., 6.]], op.to_dense().eval())
+ # Adding positive definite operators produces positive def.
+ self.assertTrue(op.is_positive_definite)
+ # Real diagonal ==> self-adjoint.
+ self.assertTrue(op.is_self_adjoint)
+ # Positive definite ==> non-singular
+ self.assertTrue(op.is_non_singular)
+
+ def test_diag_tril_diag(self):
+ op1 = linalg.LinearOperatorDiag(
+ [1., 1.], is_non_singular=True, name="diag_a")
+ op2 = linalg.LinearOperatorLowerTriangular(
+ [[2., 0.], [0., 2.]],
+ is_self_adjoint=True,
+ is_non_singular=True,
+ name="tril")
+ op3 = linalg.LinearOperatorDiag(
+ [3., 3.], is_non_singular=True, name="diag_b")
+ with self.test_session():
+ op_sum = add_operators([op1, op2, op3])
+ self.assertEqual(1, len(op_sum))
+ op = op_sum[0]
+ self.assertIsInstance(op, linalg_lib.LinearOperatorLowerTriangular)
+ self.assertAllClose([[6., 0.], [0., 6.]], op.to_dense().eval())
+
+ # The diag operators will be self-adjoint (because real and diagonal).
+ # The TriL operator has the self-adjoint hint set.
+ self.assertTrue(op.is_self_adjoint)
+
+ # Even though op1/2/3 are non-singular, this does not imply op is.
+ # Since no custom hint was provided, we default to None (unknown).
+ self.assertEqual(None, op.is_non_singular)
+
+ def test_matrix_diag_tril_diag_uses_custom_name(self):
+ op0 = linalg.LinearOperatorFullMatrix(
+ [[-1., -1.], [-1., -1.]], name="matrix")
+ op1 = linalg.LinearOperatorDiag([1., 1.], name="diag_a")
+ op2 = linalg.LinearOperatorLowerTriangular(
+ [[2., 0.], [1.5, 2.]], name="tril")
+ op3 = linalg.LinearOperatorDiag([3., 3.], name="diag_b")
+ with self.test_session():
+ op_sum = add_operators([op0, op1, op2, op3], operator_name="my_operator")
+ self.assertEqual(1, len(op_sum))
+ op = op_sum[0]
+ self.assertIsInstance(op, linalg_lib.LinearOperatorFullMatrix)
+ self.assertAllClose([[5., -1.], [0.5, 5.]], op.to_dense().eval())
+ self.assertEqual("my_operator", op.name)
+
+ def test_incompatible_domain_dimensions_raises(self):
+ op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3))
+ op2 = linalg.LinearOperatorDiag(rng.rand(2, 4))
+ with self.assertRaisesRegexp(ValueError, "must.*same domain dimension"):
+ add_operators([op1, op2])
+
+ def test_incompatible_range_dimensions_raises(self):
+ op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3))
+ op2 = linalg.LinearOperatorDiag(rng.rand(3, 3))
+ with self.assertRaisesRegexp(ValueError, "must.*same range dimension"):
+ add_operators([op1, op2])
+
+ def test_non_broadcastable_batch_shape_raises(self):
+ op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3))
+ op2 = linalg.LinearOperatorDiag(rng.rand(4, 3, 3))
+ with self.assertRaisesRegexp(ValueError, "Incompatible shapes"):
+ add_operators([op1, op2])
+
+
+class LinearOperatorOrderOfAdditionTest(test.TestCase):
+ """Test that the order of addition is done as specified by tiers."""
+
+ def test_tier_0_additions_done_in_tier_0(self):
+ diag1 = linalg.LinearOperatorDiag([1.])
+ diag2 = linalg.LinearOperatorDiag([1.])
+ diag3 = linalg.LinearOperatorDiag([1.])
+ addition_tiers = [
+ [linear_operator_addition._AddAndReturnDiag()],
+ [_BadAdder()],
+ ]
+ # Should not raise since all were added in tier 0, and tier 1 (with the
+ # _BadAdder) was never reached.
+ op_sum = add_operators([diag1, diag2, diag3], addition_tiers=addition_tiers)
+ self.assertEqual(1, len(op_sum))
+ self.assertIsInstance(op_sum[0], linalg.LinearOperatorDiag)
+
+ def test_tier_1_additions_done_by_tier_1(self):
+ diag1 = linalg.LinearOperatorDiag([1.])
+ diag2 = linalg.LinearOperatorDiag([1.])
+ tril = linalg.LinearOperatorLowerTriangular([[1.]])
+ addition_tiers = [
+ [linear_operator_addition._AddAndReturnDiag()],
+ [linear_operator_addition._AddAndReturnTriL()],
+ [_BadAdder()],
+ ]
+ # Should not raise since all were added by tier 1, and tier 2 (with the
+ # _BadAdder) was never reached.
+ op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
+ self.assertEqual(1, len(op_sum))
+ self.assertIsInstance(op_sum[0], linalg.LinearOperatorLowerTriangular)
+
+ def test_tier_1_additions_done_by_tier_1_with_order_flipped(self):
+ diag1 = linalg.LinearOperatorDiag([1.])
+ diag2 = linalg.LinearOperatorDiag([1.])
+ tril = linalg.LinearOperatorLowerTriangular([[1.]])
+ addition_tiers = [
+ [linear_operator_addition._AddAndReturnTriL()],
+ [linear_operator_addition._AddAndReturnDiag()],
+ [_BadAdder()],
+ ]
+ # Tier 0 could convert to TriL, and this converted everything to TriL,
+ # including the Diags.
+ # Tier 1 was never used.
+ # Tier 2 was never used (therefore, _BadAdder didn't raise).
+ op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
+ self.assertEqual(1, len(op_sum))
+ self.assertIsInstance(op_sum[0], linalg.LinearOperatorLowerTriangular)
+
+ def test_cannot_add_everything_so_return_more_than_one_operator(self):
+ diag1 = linalg.LinearOperatorDiag([1.])
+ diag2 = linalg.LinearOperatorDiag([2.])
+ tril5 = linalg.LinearOperatorLowerTriangular([[5.]])
+ addition_tiers = [
+ [linear_operator_addition._AddAndReturnDiag()],
+ ]
+ # Tier 0 (the only tier) can only convert to Diag, so it combines the two
+ # diags, but the TriL is unchanged.
+ # Result should contain two operators, one Diag, one TriL.
+ op_sum = add_operators([diag1, diag2, tril5], addition_tiers=addition_tiers)
+ self.assertEqual(2, len(op_sum))
+ found_diag = False
+ found_tril = False
+ with self.test_session():
+ for op in op_sum:
+ if isinstance(op, linalg.LinearOperatorDiag):
+ found_diag = True
+ self.assertAllClose([[3.]], op.to_dense().eval())
+ if isinstance(op, linalg.LinearOperatorLowerTriangular):
+ found_tril = True
+ self.assertAllClose([[5.]], op.to_dense().eval())
+ self.assertTrue(found_diag and found_tril)
+
+ def test_intermediate_tier_is_not_skipped(self):
+ diag1 = linalg.LinearOperatorDiag([1.])
+ diag2 = linalg.LinearOperatorDiag([1.])
+ tril = linalg.LinearOperatorLowerTriangular([[1.]])
+ addition_tiers = [
+ [linear_operator_addition._AddAndReturnDiag()],
+ [_BadAdder()],
+ [linear_operator_addition._AddAndReturnTriL()],
+ ]
+ # tril cannot be added in tier 0, and the intermediate tier 1 with the
+ # BadAdder will catch it and raise.
+ with self.assertRaisesRegexp(AssertionError, "BadAdder.can_add called"):
+ add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
+
+
+class AddAndReturnScaledIdentityTest(test.TestCase):
+
+ def setUp(self):
+ self._adder = linear_operator_addition._AddAndReturnScaledIdentity()
+
+ def test_identity_plus_identity(self):
+ id1 = linalg.LinearOperatorIdentity(num_rows=2)
+ id2 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3])
+ hints = linear_operator_addition._Hints(
+ is_positive_definite=True, is_non_singular=True)
+
+ self.assertTrue(self._adder.can_add(id1, id2))
+ operator = self._adder.add(id1, id2, "my_operator", hints)
+ self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity)
+
+ with self.test_session():
+ self.assertAllClose(2 *
+ linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
+ operator.to_dense().eval())
+ self.assertTrue(operator.is_positive_definite)
+ self.assertTrue(operator.is_non_singular)
+ self.assertEqual("my_operator", operator.name)
+
+ def test_identity_plus_scaled_identity(self):
+ id1 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3])
+ id2 = linalg.LinearOperatorScaledIdentity(num_rows=2, multiplier=2.2)
+ hints = linear_operator_addition._Hints(
+ is_positive_definite=True, is_non_singular=True)
+
+ self.assertTrue(self._adder.can_add(id1, id2))
+ operator = self._adder.add(id1, id2, "my_operator", hints)
+ self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity)
+
+ with self.test_session():
+ self.assertAllClose(3.2 *
+ linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
+ operator.to_dense().eval())
+ self.assertTrue(operator.is_positive_definite)
+ self.assertTrue(operator.is_non_singular)
+ self.assertEqual("my_operator", operator.name)
+
+ def test_scaled_identity_plus_scaled_identity(self):
+ id1 = linalg.LinearOperatorScaledIdentity(
+ num_rows=2, multiplier=[2.2, 2.2, 2.2])
+ id2 = linalg.LinearOperatorScaledIdentity(num_rows=2, multiplier=-1.0)
+ hints = linear_operator_addition._Hints(
+ is_positive_definite=True, is_non_singular=True)
+
+ self.assertTrue(self._adder.can_add(id1, id2))
+ operator = self._adder.add(id1, id2, "my_operator", hints)
+ self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity)
+
+ with self.test_session():
+ self.assertAllClose(1.2 *
+ linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
+ operator.to_dense().eval())
+ self.assertTrue(operator.is_positive_definite)
+ self.assertTrue(operator.is_non_singular)
+ self.assertEqual("my_operator", operator.name)
+
+
+class AddAndReturnDiagTest(test.TestCase):
+
+ def setUp(self):
+ self._adder = linear_operator_addition._AddAndReturnDiag()
+
+ def test_identity_plus_identity_returns_diag(self):
+ id1 = linalg.LinearOperatorIdentity(num_rows=2)
+ id2 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3])
+ hints = linear_operator_addition._Hints(
+ is_positive_definite=True, is_non_singular=True)
+
+ self.assertTrue(self._adder.can_add(id1, id2))
+ operator = self._adder.add(id1, id2, "my_operator", hints)
+ self.assertIsInstance(operator, linalg.LinearOperatorDiag)
+
+ with self.test_session():
+ self.assertAllClose(2 *
+ linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
+ operator.to_dense().eval())
+ self.assertTrue(operator.is_positive_definite)
+ self.assertTrue(operator.is_non_singular)
+ self.assertEqual("my_operator", operator.name)
+
+ def test_diag_plus_diag(self):
+ diag1 = rng.rand(2, 3, 4)
+ diag2 = rng.rand(4)
+ op1 = linalg.LinearOperatorDiag(diag1)
+ op2 = linalg.LinearOperatorDiag(diag2)
+ hints = linear_operator_addition._Hints(
+ is_positive_definite=True, is_non_singular=True)
+
+ self.assertTrue(self._adder.can_add(op1, op2))
+ operator = self._adder.add(op1, op2, "my_operator", hints)
+ self.assertIsInstance(operator, linalg.LinearOperatorDiag)
+
+ with self.test_session():
+ self.assertAllClose(
+ linalg.LinearOperatorDiag(diag1 + diag2).to_dense().eval(),
+ operator.to_dense().eval())
+ self.assertTrue(operator.is_positive_definite)
+ self.assertTrue(operator.is_non_singular)
+ self.assertEqual("my_operator", operator.name)
+
+
+class AddAndReturnTriLTest(test.TestCase):
+
+ def setUp(self):
+ self._adder = linear_operator_addition._AddAndReturnTriL()
+
+ def test_diag_plus_tril(self):
+ diag = linalg.LinearOperatorDiag([1., 2.])
+ tril = linalg.LinearOperatorLowerTriangular([[10., 0.], [30., 0.]])
+ hints = linear_operator_addition._Hints(
+ is_positive_definite=True, is_non_singular=True)
+
+ self.assertTrue(self._adder.can_add(diag, diag))
+ self.assertTrue(self._adder.can_add(diag, tril))
+ operator = self._adder.add(diag, tril, "my_operator", hints)
+ self.assertIsInstance(operator, linalg.LinearOperatorLowerTriangular)
+
+ with self.test_session():
+ self.assertAllClose([[11., 0.], [30., 2.]], operator.to_dense().eval())
+ self.assertTrue(operator.is_positive_definite)
+ self.assertTrue(operator.is_non_singular)
+ self.assertEqual("my_operator", operator.name)
+
+
+class AddAndReturnMatrixTest(test.TestCase):
+
+ def setUp(self):
+ self._adder = linear_operator_addition._AddAndReturnMatrix()
+
+ def test_diag_plus_diag(self):
+ diag1 = linalg.LinearOperatorDiag([1., 2.])
+ diag2 = linalg.LinearOperatorDiag([-1., 3.])
+ hints = linear_operator_addition._Hints(
+ is_positive_definite=False, is_non_singular=False)
+
+ self.assertTrue(self._adder.can_add(diag1, diag2))
+ operator = self._adder.add(diag1, diag2, "my_operator", hints)
+ self.assertIsInstance(operator, linalg.LinearOperatorFullMatrix)
+
+ with self.test_session():
+ self.assertAllClose([[0., 0.], [0., 5.]], operator.to_dense().eval())
+ self.assertFalse(operator.is_positive_definite)
+ self.assertFalse(operator.is_non_singular)
+ self.assertEqual("my_operator", operator.name)
+
+
+if __name__ == "__main__":
+ test.main()
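For context on the new file above: add_operators takes a list of LinearOperators and, driven by a list of addition tiers, folds them into as few operators as the tiers allow. A minimal usage sketch distilled from the tests above (a sketch only; _AddAndReturnDiag is the same private tier hook the tests themselves reach into):

from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.linalg import linear_operator_addition

# With the default tiers, two diagonal operators collapse into one Diag.
op_a = linalg_lib.LinearOperatorDiag([1., 1.], is_positive_definite=True)
op_b = linalg_lib.LinearOperatorDiag([2., 2.], is_positive_definite=True)
op_sum = linear_operator_addition.add_operators([op_a, op_b])
assert len(op_sum) == 1
assert isinstance(op_sum[0], linalg_lib.LinearOperatorDiag)

# With a single Diag-only tier, the lower-triangular operator cannot be
# absorbed, so two operators come back: one Diag and the untouched TriL.
tril = linalg_lib.LinearOperatorLowerTriangular([[5.]])
partial = linear_operator_addition.add_operators(
    [linalg_lib.LinearOperatorDiag([1.]),
     linalg_lib.LinearOperatorDiag([2.]),
     tril],
    addition_tiers=[[linear_operator_addition._AddAndReturnDiag()]])
assert len(partial) == 2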
diff --git a/tensorflow/python/kernel_tests/linalg_grad_test.py b/tensorflow/python/kernel_tests/linalg_grad_test.py
index 0e4e58409e..cd6a34d657 100644
--- a/tensorflow/python/kernel_tests/linalg_grad_test.py
+++ b/tensorflow/python/kernel_tests/linalg_grad_test.py
@@ -40,7 +40,7 @@ def _AddTest(test, op_name, testcase_name, fn):
class ShapeTest(test_lib.TestCase):
def testBatchGradientUnknownSize(self):
- with self.test_session():
+ with self.cached_session():
batch_size = constant_op.constant(3)
matrix_size = constant_op.constant(4)
batch_identity = array_ops.tile(
diff --git a/tensorflow/python/kernel_tests/linalg_ops_test.py b/tensorflow/python/kernel_tests/linalg_ops_test.py
index 2f28d37eff..aa17f727d0 100644
--- a/tensorflow/python/kernel_tests/linalg_ops_test.py
+++ b/tensorflow/python/kernel_tests/linalg_ops_test.py
@@ -128,7 +128,7 @@ class AdjointTest(test.TestCase):
matrix_np = np.array([[1 + 1j, 2 + 2j, 3 + 3j], [4 + 4j, 5 + 5j,
6 + 6j]]).astype(dtype)
expected_transposed = np.conj(matrix_np.T)
- with self.test_session():
+ with self.cached_session():
matrix = ops.convert_to_tensor(matrix_np)
transposed = linalg.adjoint(matrix)
self.assertEqual((3, 2), transposed.get_shape())
diff --git a/tensorflow/python/kernel_tests/listdiff_op_test.py b/tensorflow/python/kernel_tests/listdiff_op_test.py
index ee86cf0b24..baeb40dd63 100644
--- a/tensorflow/python/kernel_tests/listdiff_op_test.py
+++ b/tensorflow/python/kernel_tests/listdiff_op_test.py
@@ -42,7 +42,7 @@ class ListDiffTest(test.TestCase):
out = [compat.as_bytes(str(a)) for a in out]
for diff_func in [array_ops.setdiff1d]:
for index_dtype in [dtypes.int32, dtypes.int64]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x_tensor = ops.convert_to_tensor(x, dtype=dtype)
y_tensor = ops.convert_to_tensor(y, dtype=dtype)
out_tensor, idx_tensor = diff_func(x_tensor, y_tensor,
diff --git a/tensorflow/python/kernel_tests/logging_ops_test.py b/tensorflow/python/kernel_tests/logging_ops_test.py
index e635a71c78..82729b9e27 100644
--- a/tensorflow/python/kernel_tests/logging_ops_test.py
+++ b/tensorflow/python/kernel_tests/logging_ops_test.py
@@ -31,7 +31,7 @@ from tensorflow.python.platform import test
class LoggingOpsTest(test.TestCase):
def testAssertDivideByZero(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
epsilon = ops.convert_to_tensor(1e-20)
x = ops.convert_to_tensor(0.0)
y = ops.convert_to_tensor(1.0)
@@ -66,7 +66,7 @@ class PrintGradientTest(test.TestCase):
self.assertEqual(inp.get_shape(), inp_printed.get_shape())
def testPrintGradient(self):
- with self.test_session():
+ with self.cached_session():
inp = constant_op.constant(2.0, shape=[100, 32], name="in")
w = constant_op.constant(4.0, shape=[10, 100], name="w")
wx = math_ops.matmul(w, inp, name="wx")
diff --git a/tensorflow/python/kernel_tests/lookup_ops_test.py b/tensorflow/python/kernel_tests/lookup_ops_test.py
index 5f08339fe5..38b14e34cc 100644
--- a/tensorflow/python/kernel_tests/lookup_ops_test.py
+++ b/tensorflow/python/kernel_tests/lookup_ops_test.py
@@ -36,7 +36,7 @@ from tensorflow.python.training import server_lib
class HashTableOpTest(test.TestCase):
def testHashTable(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -54,7 +54,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testHashTableFindHighRank(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -72,7 +72,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([[0, 1], [-1, -1]], result)
def testHashTableInitWithPythonArrays(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = ["brain", "salad", "surgery"]
values = [0, 1, 2]
@@ -90,7 +90,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testHashTableInitWithNumPyArrays(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = np.array(["brain", "salad", "surgery"], dtype=np.str)
values = np.array([0, 1, 2], dtype=np.int64)
@@ -107,7 +107,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testMultipleHashTables(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -135,7 +135,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], out3)
def testHashTableWithTensorDefault(self):
- with self.test_session():
+ with self.cached_session():
default_val = constant_op.constant(-1, dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -150,7 +150,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testHashTableWithSparseTensorInput(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_val = constant_op.constant(-1, dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -173,7 +173,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual(sp_shape, out_shape)
def testSignatureMismatch(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -190,7 +190,7 @@ class HashTableOpTest(test.TestCase):
lookup_ops.KeyValueTensorInitializer(keys, values), "UNK")
def testDTypes(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
with self.assertRaises(TypeError):
lookup_ops.HashTable(
@@ -198,7 +198,7 @@ class HashTableOpTest(test.TestCase):
dtypes.int64), default_val)
def testNotInitialized(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
table = lookup_ops.HashTable(
lookup_ops.KeyValueTensorInitializer(
@@ -211,7 +211,7 @@ class HashTableOpTest(test.TestCase):
output.eval()
def testInitializeTwice(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -223,7 +223,7 @@ class HashTableOpTest(test.TestCase):
table.init.run()
def testInitializationWithInvalidDimensions(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64)
@@ -272,7 +272,7 @@ class IndexTableFromFile(test.TestCase):
def test_string_index_table_from_file(self):
vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -284,7 +284,7 @@ class IndexTableFromFile(test.TestCase):
def test_string_index_table_from_multicolumn_file(self):
vocabulary_file = self._createVocabFile(
"f2i_vocab1.txt", values=("brain\t300", "salad\t20", "surgery\t1"))
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file,
num_oov_buckets=1,
@@ -299,7 +299,7 @@ class IndexTableFromFile(test.TestCase):
def test_string_index_table_from_multicolumn_file_custom_delimiter(self):
vocabulary_file = self._createVocabFile(
"f2i_vocab1.txt", values=("brain 300", "salad 20", "surgery 1"))
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file,
num_oov_buckets=1,
@@ -314,7 +314,7 @@ class IndexTableFromFile(test.TestCase):
def test_string_index_table_from_file_tensor_filename(self):
vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
- with self.test_session():
+ with self.cached_session():
vocabulary_file = constant_op.constant(vocabulary_file)
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1)
@@ -328,7 +328,7 @@ class IndexTableFromFile(test.TestCase):
def test_string_index_table_from_file_placeholder_filename(self):
vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
- with self.test_session():
+ with self.cached_session():
vocabulary_placeholder = array_ops.placeholder(dtypes.string, [])
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_placeholder, num_oov_buckets=1)
@@ -344,7 +344,7 @@ class IndexTableFromFile(test.TestCase):
def test_int32_index_table_from_file(self):
vocabulary_file = self._createVocabFile(
"f2i_vocab2.txt", values=("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file,
num_oov_buckets=1,
@@ -359,7 +359,7 @@ class IndexTableFromFile(test.TestCase):
def test_int64_index_table_from_file(self):
vocabulary_file = self._createVocabFile(
"f2i_vocab3.txt", values=("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file,
num_oov_buckets=1,
@@ -374,7 +374,7 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_with_default_value(self):
default_value = -42
vocabulary_file = self._createVocabFile("f2i_vocab4.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file, default_value=default_value)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -385,7 +385,7 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_with_oov_buckets(self):
vocabulary_file = self._createVocabFile("f2i_vocab5.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1000)
ids = table.lookup(
@@ -432,7 +432,7 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_with_vocab_size_too_small(self):
vocabulary_file = self._createVocabFile("f2i_vocab6.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=2)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -444,7 +444,7 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_with_vocab_size_too_large(self):
vocabulary_file = self._createVocabFile("f2i_vocab7.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=4)
self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
@@ -459,7 +459,7 @@ class IndexTableFromFile(test.TestCase):
vocabulary_file=vocabulary_file,
vocab_size=0)
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=3)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -471,7 +471,7 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_with_invalid_hashers(self):
vocabulary_file = self._createVocabFile("invalid_hasher.txt")
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(TypeError):
lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file,
@@ -490,14 +490,14 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_table_ref_with_oov_buckets(self):
vocabulary_file = self._createVocabFile("f2i_vocab9.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1)
self.assertIsNotNone(table.table_ref)
def test_index_table_from_file_table_ref_without_oov_buckets(self):
vocabulary_file = self._createVocabFile("f2i_vocab10.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=0)
self.assertIsNotNone(table.table_ref)
@@ -506,21 +506,21 @@ class IndexTableFromFile(test.TestCase):
class KeyValueTensorInitializerTest(test.TestCase):
def test_string(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
init = lookup_ops.KeyValueTensorInitializer(
("brain", "salad", "surgery"), (0, 1, 2), dtypes.string, dtypes.int64)
table = lookup_ops.HashTable(init, default_value=-1)
table.init.run()
def test_int64(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
init = lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2),
dtypes.int64, dtypes.int64)
table = lookup_ops.HashTable(init, default_value=-1)
table.init.run()
def test_int32(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
init = lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2),
dtypes.int32, dtypes.int64)
table = lookup_ops.HashTable(init, default_value=-1)
@@ -532,7 +532,7 @@ class KeyValueTensorInitializerTest(test.TestCase):
class IndexTableFromTensor(test.TestCase):
def test_index_table_from_tensor_with_tensor_init(self):
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_tensor(
vocabulary_list=("brain", "salad", "surgery"), num_oov_buckets=1)
ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus")))
@@ -542,7 +542,7 @@ class IndexTableFromTensor(test.TestCase):
self.assertAllEqual((1, 2, 3), ids.eval())
def test_int32_index_table_from_tensor_with_tensor_init(self):
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_tensor(
vocabulary_list=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int32)
ids = table.lookup(
@@ -553,7 +553,7 @@ class IndexTableFromTensor(test.TestCase):
self.assertAllEqual((1, 2, 3), ids.eval())
def test_int64_index_table_from_tensor_with_tensor_init(self):
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_tensor(
vocabulary_list=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int64)
ids = table.lookup(
@@ -565,7 +565,7 @@ class IndexTableFromTensor(test.TestCase):
def test_index_table_from_tensor_with_default_value(self):
default_value = -42
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_tensor(
vocabulary_list=["brain", "salad", "surgery"],
default_value=default_value)
@@ -576,14 +576,14 @@ class IndexTableFromTensor(test.TestCase):
self.assertAllEqual((1, 2, default_value), ids.eval())
def test_index_table_from_tensor_missing_vocabulary_list(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError,
"vocabulary_list must be specified"):
lookup_ops.index_table_from_tensor(
vocabulary_list=None, num_oov_buckets=1)
def test_index_table_from_tensor_empty_vocabulary_list(self):
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_tensor(
vocabulary_list=np.array([], dtype=np.str_), num_oov_buckets=1)
ids = table.lookup(constant_op.constant(["salad", "surgery", "brain"]))
@@ -593,7 +593,7 @@ class IndexTableFromTensor(test.TestCase):
lookup_ops.tables_initializer().run()
def test_index_table_from_tensor_with_invalid_hashers(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(TypeError):
lookup_ops.index_table_from_tensor(
vocabulary_list=["brain", "salad", "surgery"],
@@ -623,7 +623,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
type_funcs = [str, constant_op.constant]
for type_func in type_funcs:
vocabulary_file = type_func(vocabulary_path)
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_to_string_table_from_file(
vocabulary_file=vocabulary_file)
features = table.lookup(
@@ -636,7 +636,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table_from_multicolumn_file(self):
vocabulary_file = self._createVocabFile(
"f2i_vocab1.txt", values=("brain\t300", "salad\t20", "surgery\t1"))
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_to_string_table_from_file(
vocabulary_file=vocabulary_file,
key_column_index=lookup_ops.TextFileIndex.LINE_NUMBER,
@@ -650,7 +650,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table_from_multicolumn_file_custom_delimiter(self):
vocabulary_file = self._createVocabFile(
"f2i_vocab1.txt", values=("brain 300", "salad 20", "surgery 1"))
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_to_string_table_from_file(
vocabulary_file=vocabulary_file,
key_column_index=lookup_ops.TextFileIndex.LINE_NUMBER,
@@ -665,7 +665,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table_with_default_value(self):
default_value = b"NONE"
vocabulary_file = self._createVocabFile("f2i_vocab2.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_to_string_table_from_file(
vocabulary_file=vocabulary_file, default_value=default_value)
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
@@ -677,7 +677,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table_with_vocab_size_too_small(self):
default_value = b"NONE"
vocabulary_file = self._createVocabFile("f2i_vocab2.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_to_string_table_from_file(
vocabulary_file=vocabulary_file,
vocab_size=2,
@@ -690,7 +690,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table_with_vocab_size_too_large(self):
vocabulary_file = self._createVocabFile("f2i_vocab6.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_to_string_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=4)
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
@@ -702,7 +702,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table_with_vocab_size(self):
vocabulary_file = self._createVocabFile("f2i_vocab7.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_to_string_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=3)
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
@@ -715,7 +715,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
class IndexToStringTableFromTensorTest(test.TestCase):
def test_index_to_string_table_from_tensor(self):
- with self.test_session():
+ with self.cached_session():
vocabulary_list = constant_op.constant(["brain", "salad", "surgery"])
table = lookup_ops.index_to_string_table_from_tensor(
vocabulary_list=vocabulary_list)
@@ -729,7 +729,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
features.eval())
def test_duplicate_entries(self):
- with self.test_session():
+ with self.cached_session():
vocabulary_list = constant_op.constant(["hello", "hello"])
table = lookup_ops.index_to_string_table_from_tensor(
vocabulary_list=vocabulary_list)
@@ -740,7 +740,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
def test_index_to_string_with_default_value(self):
default_value = b"NONE"
- with self.test_session():
+ with self.cached_session():
vocabulary_list = constant_op.constant(["brain", "salad", "surgery"])
table = lookup_ops.index_to_string_table_from_tensor(
vocabulary_list=vocabulary_list, default_value=default_value)
@@ -764,7 +764,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInitializeStringTable(self):
vocabulary_file = self._createVocabFile("one_column_1.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
table = lookup_ops.HashTable(
lookup_ops.TextFileInitializer(
@@ -782,7 +782,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
vocabulary_file = self._createVocabFile(
"one_column_int64.txt", values=("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
default_value = -1
table = lookup_ops.HashTable(
lookup_ops.TextFileInitializer(
@@ -800,7 +800,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInitializeIndexTable(self):
vocabulary_file = self._createVocabFile("one_column_2.txt")
- with self.test_session():
+ with self.cached_session():
default_value = "UNK"
key_index = lookup_ops.TextFileIndex.LINE_NUMBER
value_index = lookup_ops.TextFileIndex.WHOLE_LINE
@@ -821,7 +821,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
with open(vocabulary_file, "w") as f:
f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", "2\tsurgery\t6"]) + "\n")
- with self.test_session():
+ with self.cached_session():
default_value = -1
key_index = 1
value_index = 2
@@ -843,7 +843,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
with open(vocabulary_file, "w") as f:
f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", "2\tsurgery\t6"]) + "\n")
- with self.test_session():
+ with self.cached_session():
default_value = -1
key_index = 2
value_index = 1
@@ -857,7 +857,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInvalidDataType(self):
vocabulary_file = self._createVocabFile("one_column_3.txt")
- with self.test_session():
+ with self.cached_session():
default_value = "UNK"
key_index = lookup_ops.TextFileIndex.WHOLE_LINE
value_index = lookup_ops.TextFileIndex.LINE_NUMBER
@@ -870,7 +870,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInvalidIndex(self):
vocabulary_file = self._createVocabFile("one_column_4.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
key_index = 1 # second column of the line
value_index = lookup_ops.TextFileIndex.LINE_NUMBER
@@ -885,7 +885,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInitializeSameTableWithMultipleNodes(self):
vocabulary_file = self._createVocabFile("one_column_5.txt")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shared_name = "shared-one-columm"
default_value = -1
table1 = lookup_ops.HashTable(
@@ -924,7 +924,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], out3)
def testInitializeTableWithNoFilename(self):
- with self.test_session():
+ with self.cached_session():
default_value = -1
with self.assertRaises(ValueError):
lookup_ops.HashTable(
@@ -934,7 +934,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
default_value)
def testInitializeWithVocabSize(self):
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
vocabulary_file1 = self._createVocabFile("one_column6.txt")
@@ -982,7 +982,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testFeedVocabularyName(self):
vocabulary_file = self._createVocabFile("feed_vocabulary.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
table = lookup_ops.HashTable(
lookup_ops.TextFileInitializer(
@@ -1008,7 +1008,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInvalidFilenames(self):
vocabulary_file = self._createVocabFile("filename_shape.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
# Invalid data type
@@ -1031,7 +1031,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testIdToStringTable(self):
vocab_file = self._createVocabFile("feat_to_id_1.txt")
- with self.test_session():
+ with self.cached_session():
default_value = "UNK"
vocab_size = 3
table = lookup_ops.HashTable(
@@ -1048,7 +1048,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testStringToIdTable(self):
vocab_file = self._createVocabFile("feat_to_id_2.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
table = lookup_ops.HashTable(
@@ -1065,7 +1065,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInt64ToIdTable(self):
vocab_file = self._createVocabFile(
"feat_to_id_3.txt", values=("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
table = lookup_ops.HashTable(
@@ -1090,7 +1090,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testStringIdTableWithHashBuckets(self):
vocab_file = self._createVocabFile("feat_to_id_1.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -1110,7 +1110,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testInt32IdTableWithHashBuckets(self):
vocab_file = self._createVocabFile("feat_to_id_2.txt", ("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -1132,7 +1132,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testInt64IdTableWithHashBuckets(self):
vocab_file = self._createVocabFile("feat_to_id_3.txt", ("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -1151,7 +1151,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
self.assertEquals(vocab_size + oov_buckets, table.size().eval())
def testStringIdTableWithOnlyHashBucket(self):
- with self.test_session():
+ with self.cached_session():
oov_buckets = 5
# Set a table that only uses hash buckets, for each input value returns
@@ -1172,7 +1172,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
self.assertEquals(oov_buckets, table.size().eval())
def testInt32IdTableWithOnlyHashBucket(self):
- with self.test_session():
+ with self.cached_session():
oov_buckets = 5
# Set a table that only uses hash buckets, for each input value returns
@@ -1194,20 +1194,20 @@ class IdTableWithHashBucketsTest(test.TestCase):
self.assertEquals(oov_buckets, table.size().eval())
def testFloat64IdTableWithOnlyHashBucket(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, "Invalid key_dtype"):
lookup_ops.IdTableWithHashBuckets(
None, num_oov_buckets=5, key_dtype=dtypes.float64)
def testBoolIdTableWithOnlyHashBucket(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, "Invalid key_dtype"):
lookup_ops.IdTableWithHashBuckets(
None, num_oov_buckets=5, key_dtype=dtypes.bool)
def testIdTableWithHashBucketsWithMultipleInitializers(self):
vocab_file = self._createVocabFile("feat_to_id_4.txt")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_value = -1
vocab_size = 3
oov_buckets = 3
@@ -1248,7 +1248,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testIdTableWithHashBucketsInitializationAcrossSessions(self):
vocab_file = self._createVocabFile("feat_to_id_5.txt")
shared_name = "across-sessions"
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -1269,7 +1269,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
self.assertAllEqual([0, 1, 2, 3], out1.eval())
self.assertEquals(vocab_size + oov_buckets, table1.size().eval())
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -1292,7 +1292,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testIdTableWithHashBucketsWithMultipleInitializersDifferentDefault(self):
vocab_file = self._createVocabFile("feat_to_id_6.txt")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_value1 = -1
vocab_size = 3
oov_buckets = 0
@@ -1328,7 +1328,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
vocab_file = self._createVocabFile("feat_to_id_7.txt")
input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
input_shape = [4, 4]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sp_features = sparse_tensor.SparseTensor(
constant_op.constant(input_indices, dtypes.int64),
constant_op.constant(["brain", "salad", "brain", "surgery", "tarkus"],
@@ -1355,7 +1355,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testInt32SparseTensor(self):
input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
input_shape = [4, 4]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sp_features = sparse_tensor.SparseTensor(
constant_op.constant(input_indices, dtypes.int64),
constant_op.constant([42, 1, 42, -1000, 11], dtypes.int32),
@@ -1383,7 +1383,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testInt64SparseTensor(self):
input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
input_shape = [4, 4]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sp_features = sparse_tensor.SparseTensor(
constant_op.constant(input_indices, dtypes.int64),
constant_op.constant([42, 1, 42, -1000, 11], dtypes.int64),
@@ -1410,7 +1410,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testIdTableWithHashBucketsWithInvalidHashers(self):
vocab_file = self._createVocabFile("feat_to_id_4.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -1451,7 +1451,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
hasher_spec=lookup_ops.StrongHashSpec([None, 2]))
def testIdTableWithHashBucketsNoInnerTable(self):
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.IdTableWithHashBuckets(None, num_oov_buckets=1)
self.assertIsNone(table.table_ref)
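The lookup_ops_test.py hunks above are one mechanical swap of test_session() for cached_session(), repeated per test. A minimal, self-contained sketch of the resulting pattern (hypothetical class and method names, not part of this diff; it only uses the TF 1.x graph-mode test APIs the hunks themselves exercise):

from tensorflow.python.framework import constant_op
from tensorflow.python.ops import lookup_ops
from tensorflow.python.platform import test


class CachedSessionLookupSketch(test.TestCase):

  def test_index_table_lookup(self):
    # cached_session() enters a default session that is reused for the whole
    # test, so .run()/.eval() behave as they did under test_session().
    with self.cached_session():
      table = lookup_ops.index_table_from_tensor(
          vocabulary_list=("brain", "salad", "surgery"), num_oov_buckets=1)
      ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus")))
      lookup_ops.tables_initializer().run()
      self.assertAllEqual((1, 2, 3), ids.eval())


if __name__ == "__main__":
  test.main()

The swap reads as behaviour-preserving for these tests: cached_session() reuses one session per test graph, which is the caching the argument-free test_session() calls being replaced already relied on.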
diff --git a/tensorflow/python/kernel_tests/losses_test.py b/tensorflow/python/kernel_tests/losses_test.py
index 87fc715783..3ce0b74263 100644
--- a/tensorflow/python/kernel_tests/losses_test.py
+++ b/tensorflow/python/kernel_tests/losses_test.py
@@ -61,62 +61,62 @@ class AbsoluteDifferenceLossTest(test.TestCase):
self._labels = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
def testValueErrorThrownWhenWeightIsNone(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
losses.absolute_difference(
self._predictions, self._predictions, weights=None)
def testAllCorrectNoLossWeight(self):
loss = losses.absolute_difference(self._predictions, self._predictions)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
def testNonZeroLoss(self):
loss = losses.absolute_difference(self._labels, self._predictions)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(5.5, loss.eval(), 3)
def testNonZeroLossWithPythonScalarWeight(self):
weights = 2.3
loss = losses.absolute_difference(self._labels, self._predictions, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(5.5 * weights, loss.eval(), 3)
def testNonZeroLossWithScalarTensorWeight(self):
weights = 2.3
loss = losses.absolute_difference(self._labels, self._predictions,
constant_op.constant(weights))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(5.5 * weights, loss.eval(), 3)
def testNonZeroLossWithOneDimBatchSpecificWeights(self):
weights = constant_op.constant((1.2, 0.0), shape=(2, 1))
loss = losses.absolute_difference(self._labels, self._predictions, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(5.6, loss.eval(), 3)
def testNonZeroLossWithTwoDimBatchSpecificWeights(self):
weights = constant_op.constant([1.2, 0.0], shape=[2, 1])
loss = losses.absolute_difference(self._labels, self._predictions, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(5.6, loss.eval(), 3)
def testNonZeroLossWithSampleSpecificWeights(self):
weights = constant_op.constant([3, 6, 5, 0, 4, 2], shape=[2, 3])
loss = losses.absolute_difference(self._labels, self._predictions, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(16.6, loss.eval(), 3)
def testNonZeroLossWithSampleSpecificWeightsMostZero(self):
weights = constant_op.constant([0, 0, 0, 0, 0, 2], shape=[2, 3])
loss = losses.absolute_difference(self._labels, self._predictions, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(6.0, loss.eval(), 3)
def testLossWithSampleSpecificWeightsAllZero(self):
weights = array_ops.zeros((2, 3))
loss = losses.absolute_difference(self._labels, self._predictions, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
@test_util.assert_no_new_pyobjects_executing_eagerly
@@ -134,12 +134,12 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
losses.softmax_cross_entropy(labels, logits, weights=None)
def testAllCorrect(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
@@ -152,7 +152,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
- with self.test_session():
+ with self.cached_session():
loss = losses.softmax_cross_entropy(labels, logits)
self.assertEquals(loss.op.name, 'softmax_cross_entropy_loss/value')
self.assertAlmostEqual(loss.eval(), 10.0, 3)
@@ -162,7 +162,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
weights = 2.3
- with self.test_session():
+ with self.cached_session():
loss = losses.softmax_cross_entropy(labels, logits, weights)
self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
@@ -171,7 +171,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
weights = 2.3
- with self.test_session():
+ with self.cached_session():
loss = losses.softmax_cross_entropy(labels, logits,
constant_op.constant(weights))
self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
@@ -181,7 +181,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
weights = constant_op.constant((1.2, 3.4, 5.6))
- with self.test_session():
+ with self.cached_session():
loss = losses.softmax_cross_entropy(labels, logits, weights)
self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3)
@@ -190,7 +190,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
weights = constant_op.constant([0, 0, 0], shape=[3])
- with self.test_session():
+ with self.cached_session():
loss = losses.softmax_cross_entropy(labels, logits, weights)
self.assertAlmostEqual(0.0, loss.eval(), 3)
@@ -199,12 +199,12 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
weights = constant_op.constant([1.2, 0, 0], shape=[3])
- with self.test_session():
+ with self.cached_session():
loss = losses.softmax_cross_entropy(labels, logits, weights)
self.assertAlmostEqual(12.0, loss.eval(), 3)
def testSoftmaxWithMeasurementSpecificWeightsRaisesException(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -215,7 +215,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
losses.softmax_cross_entropy(labels, logits, weights=weights).eval()
def testSoftmaxLabelSmoothing(self):
- with self.test_session():
+ with self.cached_session():
# Softmax Cross Entropy Loss is:
# -\sum_i p_i \log q_i
# where for a softmax activation
@@ -242,12 +242,12 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[0], [1], [2]])
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
losses.sparse_softmax_cross_entropy(labels, logits, weights=None)
def testAllCorrectInt32Labels(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[0], [1], [2]], dtype=dtypes.int32)
@@ -263,7 +263,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
losses.sparse_softmax_cross_entropy(labels, logits)
def testAllCorrectInt64Labels(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[0], [1], [2]], dtype=dtypes.int64)
@@ -272,7 +272,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
self.assertAlmostEqual(loss.eval(), 0.0, 3)
def testAllCorrectNonColumnLabels(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
labels = constant_op.constant([0, 1, 2])
@@ -285,7 +285,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]], dtype=dtypes.int32)
- with self.test_session():
+ with self.cached_session():
loss = losses.sparse_softmax_cross_entropy(labels, logits)
self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
self.assertAlmostEqual(loss.eval(), 10.0, 3)
@@ -295,7 +295,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]], dtype=dtypes.int64)
- with self.test_session():
+ with self.cached_session():
loss = losses.sparse_softmax_cross_entropy(labels, logits)
self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
self.assertAlmostEqual(loss.eval(), 10.0, 3)
@@ -305,7 +305,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([2, 0, 1])
- with self.test_session():
+ with self.cached_session():
loss = losses.sparse_softmax_cross_entropy(labels, logits)
self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
self.assertAlmostEqual(loss.eval(), 10.0, 3)
@@ -315,7 +315,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = 2.3
- with self.test_session():
+ with self.cached_session():
loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
@@ -324,7 +324,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = 2.3
- with self.test_session():
+ with self.cached_session():
loss = losses.sparse_softmax_cross_entropy(labels, logits,
constant_op.constant(weights))
self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
@@ -334,7 +334,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = 2.3
- with self.test_session():
+ with self.cached_session():
loss = losses.sparse_softmax_cross_entropy(
labels, logits, constant_op.constant((weights,)))
self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
@@ -345,7 +345,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = array_ops.placeholder(dtypes.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
loss_val = sess.run(loss,
feed_dict={weights: ((1.2,), (3.4,), (5.6,))})
@@ -355,7 +355,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
logits = array_ops.placeholder(dtypes.float32)
labels = array_ops.placeholder(dtypes.int32)
weights = 1.0
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
loss_val = sess.run(loss,
feed_dict={
@@ -370,7 +370,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
logits = array_ops.placeholder(dtypes.float32, shape=(None, 3))
labels = array_ops.placeholder(dtypes.int32, shape=(None, 1))
weights = array_ops.placeholder(dtypes.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
loss_val = sess.run(loss,
feed_dict={
@@ -387,7 +387,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = constant_op.constant([1.2, 3.4, 5.6], shape=(3, 1))
- with self.test_session():
+ with self.cached_session():
loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3)
@@ -396,7 +396,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = constant_op.constant([[1.2], [3.4], [5.6]])
- with self.test_session():
+ with self.cached_session():
loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3)
@@ -405,7 +405,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = constant_op.constant([0, 0, 0], shape=(3, 1))
- with self.test_session():
+ with self.cached_session():
loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
self.assertAlmostEqual(0.0, loss.eval(), 3)
@@ -414,12 +414,12 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = constant_op.constant([1.2, 0, 0], shape=(3, 1))
- with self.test_session():
+ with self.cached_session():
loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
self.assertAlmostEqual(12.0, loss.eval(), 3)
def testMeasurementSpecificWeightsRaisesException(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -432,7 +432,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
def testInconsistentWeightSizeRaisesException(self):
"""The weight tensor has incorrect number of elements."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -445,7 +445,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
def testInconsistentLabelSizeRaisesException(self):
"""The label tensor has incorrect number of elements."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -458,7 +458,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
def testInconsistentWeightShapeRaisesException(self):
"""The weight tensor has incorrect shape."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0, -100.0],
[-100.0, -100.0, 100.0, -100.0],
@@ -472,7 +472,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
def testInconsistentLabelShapeRaisesException(self):
"""The label tensor has incorrect shape."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0, -100.0],
[-100.0, -100.0, 100.0, -100.0],
@@ -488,7 +488,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
class SigmoidCrossEntropyLossTest(test.TestCase):
def testAllCorrectSigmoid(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -506,7 +506,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
loss = losses.sigmoid_cross_entropy(labels, logits, weights)
self.assertEquals(logits.dtype, loss.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss,
feed_dict={
logits: np.ones((32, 1)),
@@ -522,7 +522,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
loss = losses.sigmoid_cross_entropy(labels, logits, weights)
self.assertEquals(logits.dtype, loss.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss,
feed_dict={
logits: np.ones((32, 2)),
@@ -531,7 +531,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
self.assertAlmostEqual(0.313, loss, 3)
def testAllWrongSigmoid(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -542,7 +542,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
self.assertAlmostEqual(loss.eval(), 600.0 / 9.0, 3)
def testAllWrongSigmoidWithMeasurementSpecificWeights(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -562,7 +562,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
self.assertEquals(logits.dtype, loss.dtype)
self.assertEquals('sigmoid_cross_entropy_loss/value', loss.op.name)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
def testSigmoidFloat64(self):
@@ -577,7 +577,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
loss = losses.sigmoid_cross_entropy(labels, logits)
self.assertEquals(logits.dtype, loss.dtype)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(44.444, loss.eval(), 3)
def testSigmoidNoReduction(self):
@@ -590,7 +590,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
labels, logits, reduction=losses.Reduction.NONE)
self.assertEquals(logits.dtype, loss.dtype)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose((
(0., 0., 0.),
(0., 100., 100.),
@@ -598,7 +598,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
), loss.eval(), 3)
def testSigmoidLabelSmoothingCorrect(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0]])
labels = constant_op.constant([[1, 0, 1]])
# Sigmoid cross entropy loss is:
@@ -621,7 +621,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
self.assertAlmostEqual(loss.eval(), expected_value, 3)
def testSigmoidLabelSmoothingEqualsSoftmaxTwoLabel(self):
- with self.test_session():
+ with self.cached_session():
label_smoothing = 0.1
sigmoid_logits = constant_op.constant([[100.0, -100.0, -100.0]])
sigmoid_labels = constant_op.constant([[1, 0, 1]])
@@ -656,33 +656,33 @@ class LogLossTest(test.TestCase):
self._labels = constant_op.constant(labels)
def testValueErrorThrownWhenWeightIsNone(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
losses.log_loss(self._labels, self._labels, weights=None)
def testAllCorrectNoLossWeight(self):
loss = losses.log_loss(self._labels, self._labels)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
def testAllCorrectNoLossWeightWithPlaceholder(self):
tf_predictions = array_ops.placeholder(
dtypes.float32, shape=self._np_labels.shape)
loss = losses.log_loss(self._labels, tf_predictions)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(
0.0, loss.eval(feed_dict={tf_predictions: self._np_labels}), 3)
def testNonZeroLoss(self):
loss = losses.log_loss(self._labels, self._predictions)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(self._expected_losses) / 6.0,
loss.eval(), 3)
def testNonZeroLossWithPythonScalarWeight(self):
weights = 2.3
loss = losses.log_loss(self._labels, self._predictions, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
loss.eval(), 3)
@@ -690,7 +690,7 @@ class LogLossTest(test.TestCase):
weights = 2.3
loss = losses.log_loss(self._labels, self._predictions,
constant_op.constant(weights))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
loss.eval(), 3)
@@ -700,7 +700,7 @@ class LogLossTest(test.TestCase):
weights = 2.3
loss = losses.log_loss(self._labels, tf_predictions,
constant_op.constant(weights))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
loss, 3)
@@ -710,7 +710,7 @@ class LogLossTest(test.TestCase):
weights = 2.3
loss = losses.log_loss(self._labels, tf_predictions,
constant_op.constant(weights))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
loss, 3)
@@ -721,7 +721,7 @@ class LogLossTest(test.TestCase):
self._expected_losses,
np.asarray([1.2, 1.2, 1.2, 3.4, 3.4, 3.4]).reshape((2, 3)))
loss = losses.log_loss(self._labels, self._predictions, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(expected_losses) / 6.0, loss.eval(), 3)
def testNonZeroLossWithOneDimBatchSpecificWeightsSomeZero(self):
@@ -730,7 +730,7 @@ class LogLossTest(test.TestCase):
np.asarray([1.2, 1.2, 1.2, 0, 0, 0]).reshape(
(2, 3)))
loss = losses.log_loss(self._labels, self._predictions, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(expected_losses) / 3.0, loss.eval(), 3)
def testNonZeroLossWithTwoDimBatchSpecificWeightsSomeZero(self):
@@ -739,12 +739,12 @@ class LogLossTest(test.TestCase):
np.asarray([1.2, 1.2, 1.2, 0, 0, 0]).reshape(
(2, 3)))
loss = losses.log_loss(self._labels, self._predictions, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(expected_losses) / 3.0, loss.eval(), 3)
def testWeightsWithSameNumDimsButWrongShapeThrowsException(self):
weights = constant_op.constant(np.random.normal(size=(2, 4)), shape=[2, 4])
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
losses.log_loss(self._labels, self._predictions, weights)
@@ -757,7 +757,7 @@ class LogLossTest(test.TestCase):
self._predictions,
constant_op.constant(
weights, shape=(2, 3)))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(expected_losses) / 5.0, loss.eval(), 3)
def testNonZeroLossWithMeasurementSpecificWeightsWithPlaceholder(self):
@@ -771,7 +771,7 @@ class LogLossTest(test.TestCase):
constant_op.constant(
weights, shape=(2, 3)))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
self.assertAlmostEqual(-np.sum(expected_losses) / 5.0, loss, 3)
@@ -784,7 +784,7 @@ class LogLossTest(test.TestCase):
self._predictions,
constant_op.constant(
weights, shape=(2, 3)))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(expected_losses), loss.eval(), 3)
def testNonZeroLossWithSampleSpecificWeightsMostZeroWithPlaceholder(self):
@@ -795,35 +795,35 @@ class LogLossTest(test.TestCase):
tf_weights = constant_op.constant(weights, shape=(2, 3))
loss = losses.log_loss(self._labels, tf_predictions, tf_weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
self.assertAlmostEqual(-np.sum(expected_losses), loss, 3)
def testLossWithSampleSpecificWeightsAllZero(self):
tf_weights = array_ops.zeros(shape=(2, 3))
loss = losses.log_loss(self._labels, self._predictions, tf_weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
class HingeLossTest(test.TestCase):
def testIncompatibleShapes(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[-1.0], [2.1]])
labels = constant_op.constant([0.0, 1.0])
with self.assertRaises(ValueError):
_ = losses.hinge_loss(labels, logits).eval()
def testAllOutsideMargin(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([1.2, -1.4, -1.0, 2.1])
labels = constant_op.constant([1.0, 0.0, 0.0, 1.0])
loss = losses.hinge_loss(labels, logits)
self.assertAllClose(loss.eval(), 0.0, atol=1e-3)
def testSomeInsideMargin(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[-0.7], [-1.4], [1.4], [0.6]])
labels = constant_op.constant([[0.0], [0.0], [1.0], [1.0]])
loss = losses.hinge_loss(labels, logits)
@@ -832,7 +832,7 @@ class HingeLossTest(test.TestCase):
self.assertAllClose(loss.eval(), 0.175, atol=1e-3)
def testSomeMisclassified(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[[1.2], [0.4], [-1.0], [-1.1]]])
labels = constant_op.constant([[[1.0], [0.0], [0.0], [1.0]]])
loss = losses.hinge_loss(labels, logits)
@@ -844,14 +844,14 @@ class HingeLossTest(test.TestCase):
class HuberLossTest(test.TestCase):
def testIncompatibleShapes(self):
- with self.test_session():
+ with self.cached_session():
predictions = constant_op.constant([[-1.0], [2.1]])
labels = constant_op.constant([0.0, 1.0])
with self.assertRaises(ValueError):
_ = losses.huber_loss(labels, predictions).eval()
def testAllQuadratic(self):
- with self.test_session():
+ with self.cached_session():
predictions = constant_op.constant([1.5, -1.4, -1.0, 0.0])
labels = constant_op.constant([1.0, -1.0, 0.0, 0.5])
loss = losses.huber_loss(labels, predictions)
@@ -859,7 +859,7 @@ class HuberLossTest(test.TestCase):
0.5 * (0.25 + 0.16 + 1.0 + 0.25) / 4., atol=1e-5)
def testAllLinear(self):
- with self.test_session():
+ with self.cached_session():
predictions = constant_op.constant([1.5, -1.4, -1.0, 0.0])
labels = constant_op.constant([0.0, 1.0, 0.0, 1.5])
loss = losses.huber_loss(labels, predictions)
@@ -867,7 +867,7 @@ class HuberLossTest(test.TestCase):
(1.5 + 2.4 + 1.0 + 1.5) / 4. - 0.5, atol=1e-5)
def testMixedQuadraticLinear(self):
- with self.test_session():
+ with self.cached_session():
predictions = constant_op.constant([[1.5, -1.4, -1.0, 0.0],
[1.5, -1.4, -1.0, 0.0]])
labels = constant_op.constant([[1.0, -1.0, 0.0, 0.5],
@@ -879,7 +879,7 @@ class HuberLossTest(test.TestCase):
self.assertAllClose(loss.eval(), expected_loss, atol=1e-5)
def testAllQuadraticDelta(self):
- with self.test_session():
+ with self.cached_session():
delta = 0.5
predictions = constant_op.constant([1.5, -1.4, -0.5, 0.0])
labels = constant_op.constant([1.0, -1.0, 0.0, 0.5])
@@ -894,7 +894,7 @@ class HuberLossTest(test.TestCase):
expected = delta * np.array([1.5, 2.4, 1.0, 1.5]).mean()
expected -= 0.5 * delta**2
loss = losses.huber_loss(labels, predictions, delta=delta)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(expected, loss.eval(), atol=1e-5)
@@ -906,13 +906,13 @@ class MeanSquaredErrorTest(test.TestCase):
self._labels = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
def testValueErrorThrownWhenWeightIsNone(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
losses.mean_squared_error(
self._predictions, self._predictions, weights=None)
def testScalar(self):
- with self.test_session():
+ with self.cached_session():
self.assertEqual(
0.0,
losses.mean_squared_error(predictions=constant_op.constant(0),
@@ -920,55 +920,55 @@ class MeanSquaredErrorTest(test.TestCase):
def testAllCorrectNoLossWeight(self):
loss = losses.mean_squared_error(self._predictions, self._predictions)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
def testNonZeroLoss(self):
loss = losses.mean_squared_error(self._labels, self._predictions)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(49.5, loss.eval(), 3)
def testNonZeroLossWithPythonScalarWeight(self):
weights = 2.3
loss = losses.mean_squared_error(self._labels, self._predictions, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(49.5 * weights, loss.eval(), 3)
def testNonZeroLossWithScalarTensorWeight(self):
weights = 2.3
loss = losses.mean_squared_error(self._labels, self._predictions,
constant_op.constant(weights))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(49.5 * weights, loss.eval(), 3)
def testNonZeroLossWithOneDimBatchSpecificWeights(self):
weights = constant_op.constant([1.2, 3.4], shape=(2, 1))
loss = losses.mean_squared_error(self._labels, self._predictions, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(767.8 / 6.0, loss.eval(), 3)
def testNonZeroLossWithTwoDimBatchSpecificWeights(self):
weights = constant_op.constant([1.2, 3.4], shape=[2, 1])
loss = losses.mean_squared_error(self._labels, self._predictions, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(767.8 / 6.0, loss.eval(), 3)
def testNonZeroLossWithSampleSpecificWeights(self):
weights = constant_op.constant([3, 6, 5, 0, 4, 2], shape=[2, 3])
loss = losses.mean_squared_error(self._labels, self._predictions, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(587 / 5.0, loss.eval(), 3)
def testNonZeroLossWithSampleSpecificWeightsMostZero(self):
weights = constant_op.constant([0, 0, 0, 0, 0, 2], shape=[2, 3])
loss = losses.mean_squared_error(self._labels, self._predictions, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(18.0, loss.eval(), 3)
def testLossWithSampleSpecificWeightsAllZero(self):
weights = array_ops.zeros((2, 3))
loss = losses.mean_squared_error(self._labels, self._predictions, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
@@ -994,7 +994,7 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
self._expected_losses = np.divide(total, 3.0)
def testValueErrorThrownWhenWeightIsNone(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
losses.mean_pairwise_squared_error(
predictions=constant_op.constant(self._labels),
@@ -1003,7 +1003,7 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
def _test_valid_weights(
self, labels, predictions, expected_loss, weights=1.0):
- with self.test_session():
+ with self.cached_session():
static_inputs_op = losses.mean_pairwise_squared_error(
predictions=predictions, labels=labels, weights=weights)
self.assertAlmostEqual(expected_loss, static_inputs_op.eval(), places=3)
@@ -1054,7 +1054,7 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
init_op = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for grad, _ in gradients_to_variables:
np_grad = sess.run(grad)
@@ -1073,7 +1073,7 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
predictions=constant_op.constant(self._predictions),
labels=constant_op.constant(self._labels),
weights=constant_op.constant(weights))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(weights * np.sum(self._expected_losses),
loss.eval(), 3)
@@ -1122,7 +1122,7 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
predictions=predictions_placeholder,
labels=labels_placeholder,
weights=weights_placeholder)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors_impl.OpError, expected_error_msg):
dynamic_inputs_op.eval(feed_dict={
predictions_placeholder: predictions,
@@ -1191,7 +1191,7 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
labels=array_ops.concat([labels0, labels1], 0),
predictions=array_ops.concat([predictions0, predictions1], 0))
- with self.test_session() as session:
+ with self.cached_session() as session:
loss0, loss1, loss0_1 = session.run([loss0, loss1, loss0_1])
self.assertTrue(loss0 > 0)
@@ -1216,7 +1216,7 @@ class CosineDistanceLossTest(test.TestCase):
[0, 0, 1], [0, 1, 0]]).reshape((3, 2, 3))
def testValueErrorThrownWhenWeightIsNone(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
losses.cosine_distance(
predictions=constant_op.constant(self._labels),
@@ -1229,7 +1229,7 @@ class CosineDistanceLossTest(test.TestCase):
predictions=constant_op.constant(self._labels),
labels=constant_op.constant(self._labels),
dim=2)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0, loss.eval(), 5)
def testPartiallyCorrectWithIntegerValues(self):
@@ -1237,7 +1237,7 @@ class CosineDistanceLossTest(test.TestCase):
predictions=constant_op.constant(self._predictions),
labels=constant_op.constant(self._labels),
dim=2)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(1, loss.eval(), 5)
def testPartiallyCorrectFloatingPointValues(self):
@@ -1255,7 +1255,7 @@ class CosineDistanceLossTest(test.TestCase):
labels, shape=(3, 1, 3), dtype=dtypes.float32)
loss = losses.cosine_distance(tf_labels, tf_preds, dim=2)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(1.0, loss.eval(), 5)
def testSampleSpecificWeights(self):
@@ -1264,7 +1264,7 @@ class CosineDistanceLossTest(test.TestCase):
labels=constant_op.constant(self._labels),
dim=2,
weights=np.asarray((1, 0, 0)).reshape((3, 1, 1)))
- with self.test_session():
+ with self.cached_session():
self.assertEqual(1.0, loss.eval())
def testMeasurementSpecificWeights(self):
@@ -1274,7 +1274,7 @@ class CosineDistanceLossTest(test.TestCase):
dim=2,
weights=constant_op.constant(
[1, 0, 0, 1, 1, 1], shape=(3, 2, 1)))
- with self.test_session():
+ with self.cached_session():
self.assertEqual(3.0 / 4.0, loss.eval())
def testMeasurementSpecificWeightsWithPlaceholderWithShape(self):
@@ -1286,7 +1286,7 @@ class CosineDistanceLossTest(test.TestCase):
dim=2,
weights=constant_op.constant(
[1, 0, 0, 1, 1, 1], shape=(3, 2, 1)))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, feed_dict={tf_predictions: self._predictions})
self.assertEqual(3.0 / 4.0, loss)
@@ -1296,7 +1296,7 @@ class CosineDistanceLossTest(test.TestCase):
labels=constant_op.constant(self._labels),
dim=2,
weights=array_ops.zeros((3, 1, 1)))
- with self.test_session():
+ with self.cached_session():
self.assertEqual(0, loss.eval())
def testZeroLossWhenAllMeasurementSpecificWeightsAreZero(self):
@@ -1305,7 +1305,7 @@ class CosineDistanceLossTest(test.TestCase):
labels=constant_op.constant(self._labels),
dim=2,
weights=array_ops.zeros((3, 2, 1)))
- with self.test_session():
+ with self.cached_session():
self.assertEqual(0, loss.eval())
@@ -1411,7 +1411,7 @@ class ComputeWeightedLossTest(test.TestCase):
weighted_loss = losses.compute_weighted_loss(
self._raw_losses, weights=weight)
self.assertEqual(1, len(util.get_losses()))
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(
np.mean(weight * self._raw_losses), weighted_loss.eval())
@@ -1429,7 +1429,7 @@ class ComputeWeightedLossTest(test.TestCase):
weighted_loss = losses.compute_weighted_loss(
self._raw_losses, weights=weights_placeholder)
self.assertEqual(1, len(util.get_losses()))
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors_impl.OpError, expected_error_msg):
weighted_loss.eval(feed_dict={weights_placeholder: weights})
@@ -1452,7 +1452,7 @@ class ComputeWeightedLossTest(test.TestCase):
weighted_loss = losses.compute_weighted_loss(
raw_losses, weights=weights_placeholder)
self.assertEqual(1, len(util.get_losses()))
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors_impl.OpError, expected_error_msg):
weighted_loss.eval(feed_dict={weights_placeholder: weights})
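losses_test.py gets the same substitution; where a test binds the session (as sess) to feed placeholders, that handle is used unchanged. A hypothetical sketch of that variant, with values borrowed from the hunks above and invented class/method names:

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import test


class CachedSessionFeedSketch(test.TestCase):

  def test_sparse_softmax_with_placeholder_weights(self):
    logits = array_ops.placeholder(dtypes.float32, shape=(None, 3))
    labels = array_ops.placeholder(dtypes.int32, shape=(None, 1))
    weights = array_ops.placeholder(dtypes.float32)
    # "as sess" still yields the (now cached) session for feed_dict runs.
    with self.cached_session() as sess:
      loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
      loss_val = sess.run(
          loss,
          feed_dict={
              logits: [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]],
              labels: [[2], [0], [1]],
              weights: ((1.2,), (3.4,), (5.6,)),
          })
      self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss_val, 3)


if __name__ == "__main__":
  test.main()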
diff --git a/tensorflow/python/kernel_tests/manip_ops_test.py b/tensorflow/python/kernel_tests/manip_ops_test.py
index dc3ea38671..f71857a3cb 100644
--- a/tensorflow/python/kernel_tests/manip_ops_test.py
+++ b/tensorflow/python/kernel_tests/manip_ops_test.py
@@ -42,12 +42,12 @@ class RollTest(test_util.TensorFlowTestCase):
def _testRoll(self, np_input, shift, axis):
expected_roll = np.roll(np_input, shift, axis)
- with self.test_session():
+ with self.cached_session():
roll = manip_ops.roll(np_input, shift, axis)
self.assertAllEqual(roll.eval(), expected_roll)
def _testGradient(self, np_input, shift, axis):
- with self.test_session():
+ with self.cached_session():
inx = constant_op.constant(np_input.tolist())
xs = list(np_input.shape)
y = manip_ops.roll(inx, shift, axis)
@@ -94,7 +94,7 @@ class RollTest(test_util.TensorFlowTestCase):
self._testAll(np.random.randint(-100, 100, (5)).astype(np.int32), 3, -1)
self._testAll(np.random.randint(-100, 100, (4, 4)).astype(np.int32), 3, -2)
# Make sure negative axis should be 0 <= axis + dims < dims
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"is out of range"):
manip_ops.roll(np.random.randint(-100, 100, (4, 4)).astype(np.int32),
@@ -111,7 +111,7 @@ class RollTest(test_util.TensorFlowTestCase):
tensor = array_ops.placeholder(dtype=dtypes.int32)
shift = 1
axis = 0
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"input must be 1-D or higher"):
manip_ops.roll(tensor, shift, axis).eval(feed_dict={tensor: 7})
@@ -127,7 +127,7 @@ class RollTest(test_util.TensorFlowTestCase):
tensor = [[1, 2], [3, 4]]
shift = 1
axis = array_ops.placeholder(dtype=dtypes.int32)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"axis must be a scalar or a 1-D vector"):
manip_ops.roll(tensor, shift, axis).eval(feed_dict={axis: [[0, 1]]})
@@ -143,7 +143,7 @@ class RollTest(test_util.TensorFlowTestCase):
tensor = [[1, 2], [3, 4]]
shift = array_ops.placeholder(dtype=dtypes.int32)
axis = 1
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"shift must be a scalar or a 1-D vector"):
manip_ops.roll(tensor, shift, axis).eval(feed_dict={shift: [[0, 1]]})
@@ -158,7 +158,7 @@ class RollTest(test_util.TensorFlowTestCase):
tensor = [[1, 2], [3, 4]]
shift = array_ops.placeholder(dtype=dtypes.int32)
axis = [0, 1]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"shift and axis must have the same size"):
manip_ops.roll(tensor, shift, axis).eval(feed_dict={shift: [1]})
@@ -167,7 +167,7 @@ class RollTest(test_util.TensorFlowTestCase):
tensor = [1, 2]
shift = 1
axis = 1
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"is out of range"):
manip_ops.roll(tensor, shift, axis).eval()
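manip_ops_test.py follows suit; the error-path tests simply wrap assertRaisesRegexp inside the new context manager. A hypothetical, self-contained instance (invented class and method names; the op, error type, and message are taken from the hunks above):

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import manip_ops
from tensorflow.python.platform import test


class CachedSessionRollSketch(test.TestCase):

  def test_scalar_input_raises(self):
    tensor = array_ops.placeholder(dtype=dtypes.int32)
    with self.cached_session():
      # roll() rejects rank-0 inputs at run time, so the error surfaces once
      # the fed scalar reaches the kernel inside the cached session.
      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                   "input must be 1-D or higher"):
        manip_ops.roll(tensor, shift=1, axis=0).eval(feed_dict={tensor: 7})


if __name__ == "__main__":
  test.main()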
diff --git a/tensorflow/python/kernel_tests/matmul_op_test.py b/tensorflow/python/kernel_tests/matmul_op_test.py
index b167278984..309da8f184 100644
--- a/tensorflow/python/kernel_tests/matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/matmul_op_test.py
@@ -206,7 +206,7 @@ class MatMulInfixOperatorTest(test_lib.TestCase):
b = ops.convert_to_tensor([[40.0, 50.0], [60.0, 70.0], [80.0, 90.0]])
c = infix_matmul(a, b)
d = math_ops.matmul(a, b)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(c.eval(), d.eval())
diff --git a/tensorflow/python/kernel_tests/matrix_inverse_op_test.py b/tensorflow/python/kernel_tests/matrix_inverse_op_test.py
index f41967ff98..720ba806e9 100644
--- a/tensorflow/python/kernel_tests/matrix_inverse_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_inverse_op_test.py
@@ -114,7 +114,7 @@ class InverseOpTest(test.TestCase):
def testNotInvertible(self):
# The input should be invertible.
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Input is not invertible."):
# All rows of the matrix below add to zero.
tensor3 = constant_op.constant([[1., 0., -1.], [-1., 1., 0.],
diff --git a/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py b/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
index 33288392c0..dd01ba11af 100644
--- a/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
@@ -143,7 +143,7 @@ class MatrixTriangularSolveOpTest(test.TestCase):
def testNonSquareMatrix(self):
# A non-square matrix should cause an error.
matrix = np.array([[1., 2., 3.], [3., 4., 5.]])
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
self._verifySolve(matrix, matrix)
with self.assertRaises(ValueError):
@@ -154,7 +154,7 @@ class MatrixTriangularSolveOpTest(test.TestCase):
# right-hand sides.
matrix = np.array([[1., 0.], [0., 1.]])
rhs = np.array([[1., 0.]])
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
self._verifySolve(matrix, rhs)
with self.assertRaises(ValueError):
@@ -164,7 +164,7 @@ class MatrixTriangularSolveOpTest(test.TestCase):
# The input should be invertible.
# The matrix is singular because it has a zero on the diagonal.
singular_matrix = np.array([[1., 0., -1.], [-1., 0., 1.], [0., -1., 1.]])
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Input matrix is not invertible."):
self._verifySolve(singular_matrix, singular_matrix)
with self.assertRaisesOpError("Input matrix is not invertible."):
diff --git a/tensorflow/python/kernel_tests/metrics_test.py b/tensorflow/python/kernel_tests/metrics_test.py
index 55653489af..5dcdb9e420 100644
--- a/tensorflow/python/kernel_tests/metrics_test.py
+++ b/tensorflow/python/kernel_tests/metrics_test.py
@@ -192,7 +192,7 @@ class MeanTest(test.TestCase):
self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
def testBasic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
_enqueue_vector(sess, values_queue, [0, 1])
@@ -209,7 +209,7 @@ class MeanTest(test.TestCase):
self.assertAlmostEqual(1.65, sess.run(mean), 5)
def testUpdateOpsReturnsCurrentValue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
_enqueue_vector(sess, values_queue, [0, 1])
@@ -253,7 +253,7 @@ class MeanTest(test.TestCase):
metrics.mean(values, weights=np.ones((3, 2, 4, 1))),
metrics.mean(values, weights=np.ones((3, 2, 4, 1, 1))),)
expected = np.mean(values)
- with self.test_session():
+ with self.cached_session():
variables.local_variables_initializer().run()
for mean_result in mean_results:
mean, update_op = mean_result
@@ -266,7 +266,7 @@ class MeanTest(test.TestCase):
np.sum(np.multiply(weights, np.ones_like(values)))
)
mean, update_op = metrics.mean(values, weights=weights)
- with self.test_session():
+ with self.cached_session():
variables.local_variables_initializer().run()
self.assertAlmostEqual(expected, update_op.eval(), places=5)
self.assertAlmostEqual(expected, mean.eval(), places=5)
@@ -330,7 +330,7 @@ class MeanTest(test.TestCase):
# Dynamic shapes.
with self.assertRaisesRegexp(errors_impl.OpError, expected_error_msg):
- with self.test_session():
+ with self.cached_session():
_, update_op = metrics.mean(values_placeholder, invalid_weight)
variables.local_variables_initializer().run()
update_op.eval(feed_dict={values_placeholder: values})
@@ -359,7 +359,7 @@ class MeanTensorTest(test.TestCase):
self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
def testBasic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
_enqueue_vector(sess, values_queue, [0, 1])
@@ -376,7 +376,7 @@ class MeanTensorTest(test.TestCase):
self.assertAllClose([[-0.9 / 4., 3.525]], sess.run(mean))
def testMultiDimensional(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
2, dtypes=dtypes_lib.float32, shapes=(2, 2, 2))
_enqueue_vector(
@@ -397,7 +397,7 @@ class MeanTensorTest(test.TestCase):
self.assertAllClose([[[1, 2], [1, 2]], [[2, 3], [5, 6]]], sess.run(mean))
def testUpdateOpsReturnsCurrentValue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
_enqueue_vector(sess, values_queue, [0, 1])
@@ -418,7 +418,7 @@ class MeanTensorTest(test.TestCase):
self.assertAllClose([[-0.9 / 4., 3.525]], sess.run(mean), 5)
def testBinaryWeighted1d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -445,7 +445,7 @@ class MeanTensorTest(test.TestCase):
self.assertAllClose([[3.25, 0.5]], sess.run(mean), 5)
def testWeighted1d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -472,7 +472,7 @@ class MeanTensorTest(test.TestCase):
self.assertAllClose([[0.8, 3.52]], sess.run(mean), 5)
def testWeighted2d_1(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -499,7 +499,7 @@ class MeanTensorTest(test.TestCase):
self.assertAllClose([[-2.1, 0.5]], sess.run(mean), 5)
def testWeighted2d_2(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -575,7 +575,7 @@ class AccuracyTest(test.TestCase):
(10, 3), maxval=3, dtype=dtypes_lib.int64, seed=1)
accuracy, update_op = metrics.accuracy(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -588,7 +588,7 @@ class AccuracyTest(test.TestCase):
self.assertEqual(initial_accuracy, accuracy.eval())
def testMultipleUpdates(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 1))
@@ -618,7 +618,7 @@ class AccuracyTest(test.TestCase):
def testEffectivelyEquivalentSizes(self):
predictions = array_ops.ones((40, 1))
labels = array_ops.ones((40,))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accuracy, update_op = metrics.accuracy(labels, predictions)
sess.run(variables.local_variables_initializer())
@@ -628,7 +628,7 @@ class AccuracyTest(test.TestCase):
def testEffectivelyEquivalentSizesWithScalarWeight(self):
predictions = array_ops.ones((40, 1))
labels = array_ops.ones((40,))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accuracy, update_op = metrics.accuracy(labels, predictions, weights=2.0)
sess.run(variables.local_variables_initializer())
@@ -642,7 +642,7 @@ class AccuracyTest(test.TestCase):
weights = array_ops.expand_dims(ops.convert_to_tensor([100, 1, 1]),
1) # shape 3, 1
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accuracy, update_op = metrics.accuracy(labels, predictions, weights)
sess.run(variables.local_variables_initializer())
@@ -662,7 +662,7 @@ class AccuracyTest(test.TestCase):
dtype=dtypes_lib.int32, name='weights')
feed_dict = {weights_placeholder: weights}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accuracy, update_op = metrics.accuracy(labels, predictions,
weights_placeholder)
@@ -674,7 +674,7 @@ class AccuracyTest(test.TestCase):
self.assertGreater(accuracy.eval(feed_dict=feed_dict), .95)
def testMultipleUpdatesWithWeightedValues(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 1))
@@ -746,7 +746,7 @@ class PrecisionTest(test.TestCase):
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
precision, update_op = metrics.precision(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -765,7 +765,7 @@ class PrecisionTest(test.TestCase):
labels = constant_op.constant(inputs)
precision, update_op = metrics.precision(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(1, sess.run(update_op))
self.assertAlmostEqual(1, precision.eval())
@@ -778,7 +778,7 @@ class PrecisionTest(test.TestCase):
constant_op.constant([0, 1, 1, 0], shape=(1, 4)), dtype=dtype)
precision, update_op = metrics.precision(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.5, update_op.eval())
self.assertAlmostEqual(0.5, precision.eval())
@@ -789,7 +789,7 @@ class PrecisionTest(test.TestCase):
precision, update_op = metrics.precision(
labels, predictions, weights=constant_op.constant([[2], [5]]))
- with self.test_session():
+ with self.cached_session():
variables.local_variables_initializer().run()
weighted_tp = 2.0 + 5.0
weighted_positives = (2.0 + 2.0) + (5.0 + 5.0)
@@ -806,7 +806,7 @@ class PrecisionTest(test.TestCase):
}
precision, update_op = metrics.precision(labels, predictions, weights=2)
- with self.test_session():
+ with self.cached_session():
variables.local_variables_initializer().run()
weighted_tp = 2.0 + 2.0
weighted_positives = (2.0 + 2.0) + (2.0 + 2.0)
@@ -826,7 +826,7 @@ class PrecisionTest(test.TestCase):
precision, update_op = metrics.precision(
labels, predictions, weights=constant_op.constant([[2], [5]]))
- with self.test_session():
+ with self.cached_session():
variables.local_variables_initializer().run()
weighted_tp = 2.0 + 5.0
weighted_positives = (2.0 + 2.0) + (5.0 + 5.0)
@@ -844,7 +844,7 @@ class PrecisionTest(test.TestCase):
predictions,
weights=constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]]))
- with self.test_session():
+ with self.cached_session():
variables.local_variables_initializer().run()
weighted_tp = 3.0 + 4.0
weighted_positives = (1.0 + 3.0) + (4.0 + 2.0)
@@ -864,7 +864,7 @@ class PrecisionTest(test.TestCase):
predictions,
weights=constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]]))
- with self.test_session():
+ with self.cached_session():
variables.local_variables_initializer().run()
weighted_tp = 3.0 + 4.0
weighted_positives = (1.0 + 3.0) + (4.0 + 2.0)
@@ -881,7 +881,7 @@ class PrecisionTest(test.TestCase):
labels = constant_op.constant(1 - inputs)
precision, update_op = metrics.precision(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertAlmostEqual(0, precision.eval())
@@ -891,7 +891,7 @@ class PrecisionTest(test.TestCase):
labels = constant_op.constant([0, 0, 0, 0])
precision, update_op = metrics.precision(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0.0, precision.eval())
@@ -933,7 +933,7 @@ class RecallTest(test.TestCase):
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
recall, update_op = metrics.recall(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -952,7 +952,7 @@ class RecallTest(test.TestCase):
labels = constant_op.constant(np_inputs)
recall, update_op = metrics.recall(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(1, recall.eval())
@@ -965,7 +965,7 @@ class RecallTest(test.TestCase):
constant_op.constant([0, 1, 1, 0], shape=(1, 4)), dtype=dtype)
recall, update_op = metrics.recall(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.5, update_op.eval())
self.assertAlmostEqual(0.5, recall.eval())
@@ -976,7 +976,7 @@ class RecallTest(test.TestCase):
weights = constant_op.constant([[2], [5]])
recall, update_op = metrics.recall(labels, predictions, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
weighted_tp = 2.0 + 5.0
weighted_t = (2.0 + 2.0) + (5.0 + 5.0)
@@ -990,7 +990,7 @@ class RecallTest(test.TestCase):
weights = constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]])
recall, update_op = metrics.recall(labels, predictions, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
weighted_tp = 3.0 + 1.0
weighted_t = (2.0 + 3.0) + (4.0 + 1.0)
@@ -1005,7 +1005,7 @@ class RecallTest(test.TestCase):
labels = constant_op.constant(1 - np_inputs)
recall, update_op = metrics.recall(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0, recall.eval())
@@ -1015,7 +1015,7 @@ class RecallTest(test.TestCase):
labels = array_ops.zeros((1, 4))
recall, update_op = metrics.recall(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0, recall.eval())
@@ -1055,7 +1055,7 @@ class AUCTest(test.TestCase):
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
auc, update_op = metrics.auc(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -1073,7 +1073,7 @@ class AUCTest(test.TestCase):
def allCorrectAsExpected(self, curve):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(inputs)
auc, update_op = metrics.auc(labels, predictions, curve=curve)
@@ -1084,7 +1084,7 @@ class AUCTest(test.TestCase):
self.assertEqual(1, auc.eval())
def testSomeCorrect_multipleLabelDtypes(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for label_dtype in (
dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
predictions = constant_op.constant(
@@ -1099,7 +1099,7 @@ class AUCTest(test.TestCase):
self.assertAlmostEqual(0.5, auc.eval())
def testWeighted1d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -1112,7 +1112,7 @@ class AUCTest(test.TestCase):
self.assertAlmostEqual(0.5, auc.eval(), 5)
def testWeighted2d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -1127,7 +1127,7 @@ class AUCTest(test.TestCase):
# Regarding the AUC-PR tests: note that the preferred method when
# calculating AUC-PR is summation_method='careful_interpolation'.
def testCorrectAUCPRSpecialCase(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.1, 0.4, 0.35, 0.8], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 0, 1, 1], shape=(1, 4))
@@ -1141,7 +1141,7 @@ class AUCTest(test.TestCase):
self.assertAlmostEqual(expected, auc.eval(), delta=1e-3)
def testCorrectAnotherAUCPRSpecialCase(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81],
shape=(1, 7),
@@ -1157,7 +1157,7 @@ class AUCTest(test.TestCase):
self.assertAlmostEqual(expected, auc.eval(), delta=1e-3)
def testThirdCorrectAUCPRSpecialCase(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5],
shape=(1, 7),
@@ -1173,7 +1173,7 @@ class AUCTest(test.TestCase):
self.assertAlmostEqual(expected, auc.eval(), delta=1e-3)
def testIncorrectAUCPRSpecialCase(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.1, 0.4, 0.35, 0.8], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 0, 1, 1], shape=(1, 4))
@@ -1186,7 +1186,7 @@ class AUCTest(test.TestCase):
self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-3)
def testAnotherIncorrectAUCPRSpecialCase(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81],
shape=(1, 7),
@@ -1201,7 +1201,7 @@ class AUCTest(test.TestCase):
self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-3)
def testThirdIncorrectAUCPRSpecialCase(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5],
shape=(1, 7),
@@ -1218,7 +1218,7 @@ class AUCTest(test.TestCase):
def testAllIncorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
auc, update_op = metrics.auc(labels, predictions)
@@ -1229,7 +1229,7 @@ class AUCTest(test.TestCase):
self.assertAlmostEqual(0, auc.eval())
def testZeroTruePositivesAndFalseNegativesGivesOneAUC(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
labels = array_ops.zeros([4])
auc, update_op = metrics.auc(labels, predictions)
@@ -1240,7 +1240,7 @@ class AUCTest(test.TestCase):
self.assertAlmostEqual(1, auc.eval(), 6)
def testRecallOneAndPrecisionOneGivesOnePRAUC(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = array_ops.ones([4], dtype=dtypes_lib.float32)
labels = array_ops.ones([4])
auc, update_op = metrics.auc(labels, predictions, curve='PR')
@@ -1301,7 +1301,7 @@ class AUCTest(test.TestCase):
scale=1.0, size=num_samples)):
expected_auc = self.np_auc(predictions, labels, weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
enqueue_ops = [[] for i in range(num_batches)]
tf_predictions = _enqueue_as_batches(predictions, enqueue_ops)
tf_labels = _enqueue_as_batches(labels, enqueue_ops)
@@ -1370,7 +1370,7 @@ class SpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.specificity_at_sensitivity(
labels, predictions, sensitivity=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -1390,7 +1390,7 @@ class SpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.specificity_at_sensitivity(
labels, predictions, sensitivity=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1, sess.run(update_op))
self.assertEqual(1, specificity.eval())
@@ -1405,7 +1405,7 @@ class SpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.specificity_at_sensitivity(
labels, predictions, sensitivity=0.8)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(1.0, sess.run(update_op))
self.assertAlmostEqual(1.0, specificity.eval())
@@ -1420,7 +1420,7 @@ class SpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.specificity_at_sensitivity(
labels, predictions, sensitivity=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.6, sess.run(update_op))
@@ -1439,7 +1439,7 @@ class SpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.specificity_at_sensitivity(
labels, predictions, weights=weights, sensitivity=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.6, sess.run(update_op))
@@ -1457,7 +1457,7 @@ class SpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.specificity_at_sensitivity(
labels, predictions, weights=weights, sensitivity=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(8.0 / 15.0, sess.run(update_op))
@@ -1507,7 +1507,7 @@ class SensitivityAtSpecificityTest(test.TestCase):
sensitivity, update_op = metrics.sensitivity_at_specificity(
labels, predictions, specificity=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -1527,7 +1527,7 @@ class SensitivityAtSpecificityTest(test.TestCase):
specificity, update_op = metrics.sensitivity_at_specificity(
labels, predictions, specificity=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1, sess.run(update_op))
self.assertEqual(1, specificity.eval())
@@ -1542,7 +1542,7 @@ class SensitivityAtSpecificityTest(test.TestCase):
specificity, update_op = metrics.sensitivity_at_specificity(
labels, predictions, specificity=0.8)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.8, sess.run(update_op))
self.assertAlmostEqual(0.8, specificity.eval())
@@ -1557,7 +1557,7 @@ class SensitivityAtSpecificityTest(test.TestCase):
specificity, update_op = metrics.sensitivity_at_specificity(
labels, predictions, specificity=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.6, sess.run(update_op))
self.assertAlmostEqual(0.6, specificity.eval())
@@ -1576,7 +1576,7 @@ class SensitivityAtSpecificityTest(test.TestCase):
specificity, update_op = metrics.sensitivity_at_specificity(
labels, predictions, weights=weights, specificity=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.675, sess.run(update_op))
self.assertAlmostEqual(0.675, specificity.eval())
@@ -1638,7 +1638,7 @@ class PrecisionRecallThresholdsTest(test.TestCase):
thresholds)
rec, rec_op = metrics.recall_at_thresholds(labels, predictions, thresholds)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates, then verify idempotency.
@@ -1654,7 +1654,7 @@ class PrecisionRecallThresholdsTest(test.TestCase):
def testAllCorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(inputs)
thresholds = [0.5]
@@ -1670,7 +1670,7 @@ class PrecisionRecallThresholdsTest(test.TestCase):
self.assertEqual(1, rec.eval())
def testSomeCorrect_multipleLabelDtypes(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for label_dtype in (
dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
predictions = constant_op.constant(
@@ -1692,7 +1692,7 @@ class PrecisionRecallThresholdsTest(test.TestCase):
def testAllIncorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
thresholds = [0.5]
@@ -1708,7 +1708,7 @@ class PrecisionRecallThresholdsTest(test.TestCase):
self.assertAlmostEqual(0, rec.eval())
def testWeights1d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -1738,7 +1738,7 @@ class PrecisionRecallThresholdsTest(test.TestCase):
self.assertAlmostEqual(0.0, rec_high.eval(), places=5)
def testWeights2d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -1768,7 +1768,7 @@ class PrecisionRecallThresholdsTest(test.TestCase):
self.assertAlmostEqual(0.0, rec_high.eval(), places=5)
def testExtremeThresholds(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4))
@@ -1792,7 +1792,7 @@ class PrecisionRecallThresholdsTest(test.TestCase):
self.assertAlmostEqual(0.0, rec_high.eval())
def testZeroLabelsPredictions(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
labels = array_ops.zeros([4])
thresholds = [0.5]
@@ -1842,7 +1842,7 @@ class PrecisionRecallThresholdsTest(test.TestCase):
labels = labels.astype(np.float32)
predictions = predictions.astype(np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
       # Reshape the data so it's easy to queue up:
predictions_batches = predictions.reshape((batch_size, num_batches))
labels_batches = labels.reshape((batch_size, num_batches))
@@ -2801,7 +2801,7 @@ class MeanAbsoluteErrorTest(test.TestCase):
labels = random_ops.random_normal((10, 3), seed=2)
error, update_op = metrics.mean_absolute_error(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -2822,7 +2822,7 @@ class MeanAbsoluteErrorTest(test.TestCase):
error, update_op = metrics.mean_absolute_error(labels, predictions, weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(3, sess.run(update_op))
self.assertEqual(3, error.eval())
@@ -2866,7 +2866,7 @@ class MeanRelativeErrorTest(test.TestCase):
error, update_op = metrics.mean_relative_error(labels, predictions,
normalizer)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -2891,7 +2891,7 @@ class MeanRelativeErrorTest(test.TestCase):
error, update_op = metrics.mean_relative_error(
labels, predictions, normalizer=labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(expected_error, sess.run(update_op))
self.assertEqual(expected_error, error.eval())
@@ -2907,7 +2907,7 @@ class MeanRelativeErrorTest(test.TestCase):
error, update_op = metrics.mean_relative_error(
labels, predictions, normalizer=array_ops.zeros_like(labels))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0.0, sess.run(update_op))
self.assertEqual(0.0, error.eval())
@@ -2945,7 +2945,7 @@ class MeanSquaredErrorTest(test.TestCase):
labels = random_ops.random_normal((10, 3), seed=2)
error, update_op = metrics.mean_squared_error(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -2963,7 +2963,7 @@ class MeanSquaredErrorTest(test.TestCase):
error, update_op = metrics.mean_squared_error(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, sess.run(update_op))
self.assertEqual(0, error.eval())
@@ -2976,7 +2976,7 @@ class MeanSquaredErrorTest(test.TestCase):
error, update_op = metrics.mean_squared_error(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(6, sess.run(update_op))
self.assertEqual(6, error.eval())
@@ -2990,13 +2990,13 @@ class MeanSquaredErrorTest(test.TestCase):
error, update_op = metrics.mean_squared_error(labels, predictions, weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(13, sess.run(update_op))
self.assertEqual(13, error.eval())
def testMultipleBatchesOfSizeOne(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
2, dtypes=dtypes_lib.float32, shapes=(1, 3))
@@ -3020,7 +3020,7 @@ class MeanSquaredErrorTest(test.TestCase):
self.assertAlmostEqual(208.0 / 6, error.eval(), 5)
def testMetricsComputedConcurrently(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates one set of predictions.
preds_queue0 = data_flow_ops.FIFOQueue(
2, dtypes=dtypes_lib.float32, shapes=(1, 3))
@@ -3063,7 +3063,7 @@ class MeanSquaredErrorTest(test.TestCase):
self.assertAlmostEqual(79.0 / 6, mse1, 5)
def testMultipleMetricsOnMultipleBatchesOfSizeOne(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
2, dtypes=dtypes_lib.float32, shapes=(1, 3))
@@ -3122,7 +3122,7 @@ class RootMeanSquaredErrorTest(test.TestCase):
labels = random_ops.random_normal((10, 3), seed=2)
error, update_op = metrics.root_mean_squared_error(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -3135,7 +3135,7 @@ class RootMeanSquaredErrorTest(test.TestCase):
self.assertEqual(initial_error, error.eval())
def testSingleUpdateZeroError(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
0.0, shape=(1, 3), dtype=dtypes_lib.float32)
labels = constant_op.constant(0.0, shape=(1, 3), dtype=dtypes_lib.float32)
@@ -3148,7 +3148,7 @@ class RootMeanSquaredErrorTest(test.TestCase):
self.assertEqual(0, rmse.eval())
def testSingleUpdateWithError(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[2, 4, 6], shape=(1, 3), dtype=dtypes_lib.float32)
labels = constant_op.constant(
@@ -3161,7 +3161,7 @@ class RootMeanSquaredErrorTest(test.TestCase):
self.assertAlmostEqual(math.sqrt(6), rmse.eval(), 5)
def testSingleUpdateWithErrorAndWeights(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant(
@@ -3220,7 +3220,7 @@ class MeanCosineDistanceTest(test.TestCase):
labels = random_ops.random_normal((10, 3), seed=2)
error, update_op = metrics.mean_cosine_distance(labels, predictions, dim=1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -3242,7 +3242,7 @@ class MeanCosineDistanceTest(test.TestCase):
error, update_op = metrics.mean_cosine_distance(labels, predictions, dim=2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, sess.run(update_op))
self.assertEqual(0, error.eval())
@@ -3258,7 +3258,7 @@ class MeanCosineDistanceTest(test.TestCase):
error, update_op = metrics.mean_cosine_distance(labels, predictions, dim=2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(1, sess.run(update_op), 5)
self.assertAlmostEqual(1, error.eval(), 5)
@@ -3279,7 +3279,7 @@ class MeanCosineDistanceTest(test.TestCase):
np_labels, shape=(3, 1, 3), dtype=dtypes_lib.float32)
error, update_op = metrics.mean_cosine_distance(labels, predictions, dim=2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(1.0, sess.run(update_op), 5)
self.assertAlmostEqual(1.0, error.eval(), 5)
@@ -3298,7 +3298,7 @@ class MeanCosineDistanceTest(test.TestCase):
error, update_op = metrics.mean_cosine_distance(
labels, predictions, dim=2, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, sess.run(update_op))
self.assertEqual(0, error.eval())
@@ -3317,7 +3317,7 @@ class MeanCosineDistanceTest(test.TestCase):
error, update_op = metrics.mean_cosine_distance(
labels, predictions, dim=2, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1.5, update_op.eval())
self.assertEqual(1.5, error.eval())
@@ -3352,7 +3352,7 @@ class PcntBelowThreshTest(test.TestCase):
self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
def testOneUpdate(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = constant_op.constant(
[2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
@@ -3369,7 +3369,7 @@ class PcntBelowThreshTest(test.TestCase):
self.assertAlmostEqual(0.0, pcnt2, 5)
def testSomePresentOneUpdate(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = constant_op.constant(
[2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
weights = constant_op.constant(
@@ -3445,7 +3445,7 @@ class MeanIOUTest(test.TestCase):
mean_iou, update_op = metrics.mean_iou(
labels, predictions, num_classes=num_classes)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -3459,7 +3459,7 @@ class MeanIOUTest(test.TestCase):
def testMultipleUpdates(self):
num_classes = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
5, dtypes=dtypes_lib.int32, shapes=(1, 1))
@@ -3490,7 +3490,7 @@ class MeanIOUTest(test.TestCase):
def testMultipleUpdatesWithWeights(self):
num_classes = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
6, dtypes=dtypes_lib.int32, shapes=(1, 1))
@@ -3538,7 +3538,7 @@ class MeanIOUTest(test.TestCase):
# one class, and thus there is one row and one column with
# zero entries in the confusion matrix.
num_classes = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
# There is no prediction for class 2.
preds_queue = data_flow_ops.FIFOQueue(
@@ -3585,7 +3585,7 @@ class MeanIOUTest(test.TestCase):
],
0)
num_classes = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
sess.run(variables.local_variables_initializer())
confusion_matrix = update_op.eval()
@@ -3597,7 +3597,7 @@ class MeanIOUTest(test.TestCase):
predictions = array_ops.zeros([40])
labels = array_ops.zeros([40])
num_classes = 1
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
sess.run(variables.local_variables_initializer())
self.assertEqual(40, update_op.eval()[0])
@@ -3607,7 +3607,7 @@ class MeanIOUTest(test.TestCase):
predictions = array_ops.zeros([40])
labels = array_ops.ones([40])
num_classes = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
sess.run(variables.local_variables_initializer())
self.assertAllEqual([[0, 0], [40, 0]], update_op.eval())
@@ -3637,7 +3637,7 @@ class MeanIOUTest(test.TestCase):
0, shape=[1])
],
0)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.mean_iou(
labels, predictions, num_classes, weights=weights)
sess.run(variables.local_variables_initializer())
@@ -3657,7 +3657,7 @@ class MeanIOUTest(test.TestCase):
[[0, 0, 2, 1, 1, 1],
[1, 1, 2, 0, 0, 0]]])
num_classes = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
sess.run(variables.local_variables_initializer())
self.assertAllEqual([[7, 4, 3], [3, 5, 2], [0, 0, 0]], update_op.eval())
@@ -3669,7 +3669,7 @@ class MeanIOUTest(test.TestCase):
labels = constant_op.constant([0])
predictions = constant_op.constant([0])
num_classes = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
sess.run(variables.local_variables_initializer())
self.assertAllEqual([[1, 0], [0, 0]], update_op.eval())
@@ -3687,7 +3687,7 @@ class MeanIOUTest(test.TestCase):
[[0, 0, 0, 1, 1, 1],
[1, 1, 1, 0, 0, 0]]])
num_classes = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
sess.run(variables.local_variables_initializer())
self.assertAllEqual([[9, 5, 0], [3, 7, 0], [0, 0, 0]], update_op.eval())
@@ -3751,7 +3751,7 @@ class MeanPerClassAccuracyTest(test.TestCase):
mean_accuracy, update_op = metrics.mean_per_class_accuracy(
labels, predictions, num_classes=num_classes)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -3764,7 +3764,7 @@ class MeanPerClassAccuracyTest(test.TestCase):
self.assertEqual(initial_mean_accuracy, mean_accuracy.eval())
num_classes = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
5, dtypes=dtypes_lib.int32, shapes=(1, 1))
@@ -3796,7 +3796,7 @@ class MeanPerClassAccuracyTest(test.TestCase):
def testMultipleUpdatesWithWeights(self):
num_classes = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
6, dtypes=dtypes_lib.int32, shapes=(1, 1))
@@ -3844,7 +3844,7 @@ class MeanPerClassAccuracyTest(test.TestCase):
# one class, and thus there is one row and one column with
# zero entries in the confusion matrix.
num_classes = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
# There is no prediction for class 2.
preds_queue = data_flow_ops.FIFOQueue(
@@ -3880,7 +3880,7 @@ class MeanPerClassAccuracyTest(test.TestCase):
predictions = array_ops.zeros([40])
labels = array_ops.zeros([40])
num_classes = 1
- with self.test_session() as sess:
+ with self.cached_session() as sess:
mean_accuracy, update_op = metrics.mean_per_class_accuracy(
labels, predictions, num_classes)
sess.run(variables.local_variables_initializer())
@@ -3891,7 +3891,7 @@ class MeanPerClassAccuracyTest(test.TestCase):
predictions = array_ops.zeros([40])
labels = array_ops.ones([40])
num_classes = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
mean_accuracy, update_op = metrics.mean_per_class_accuracy(
labels, predictions, num_classes)
sess.run(variables.local_variables_initializer())
@@ -3910,7 +3910,7 @@ class MeanPerClassAccuracyTest(test.TestCase):
constant_op.constant(0, shape=[1]), constant_op.constant(1, shape=[8]),
constant_op.constant(0, shape=[1])
], 0)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
mean_accuracy, update_op = metrics.mean_per_class_accuracy(
labels, predictions, num_classes, weights=weights)
sess.run(variables.local_variables_initializer())
@@ -3944,7 +3944,7 @@ class FalseNegativesTest(test.TestCase):
tn, tn_update_op = metrics.false_negatives(
labels=labels, predictions=predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllClose(0., tn.eval())
self.assertAllClose(3., tn_update_op.eval())
@@ -3963,7 +3963,7 @@ class FalseNegativesTest(test.TestCase):
tn, tn_update_op = metrics.false_negatives(
labels=labels, predictions=predictions, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllClose(0., tn.eval())
self.assertAllClose(5., tn_update_op.eval())
@@ -3993,7 +3993,7 @@ class FalseNegativesAtThresholdsTest(test.TestCase):
fn, fn_update_op = metrics.false_negatives_at_thresholds(
predictions=predictions, labels=labels, thresholds=[0.15, 0.5, 0.85])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0, 0, 0), fn.eval())
self.assertAllEqual((0, 2, 3), fn_update_op.eval())
@@ -4012,7 +4012,7 @@ class FalseNegativesAtThresholdsTest(test.TestCase):
weights=((3.0,), (5.0,), (7.0,)),
thresholds=[0.15, 0.5, 0.85])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0.0, 0.0, 0.0), fn.eval())
self.assertAllEqual((0.0, 8.0, 11.0), fn_update_op.eval())
@@ -4043,7 +4043,7 @@ class FalsePositivesTest(test.TestCase):
tn, tn_update_op = metrics.false_positives(
labels=labels, predictions=predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllClose(0., tn.eval())
self.assertAllClose(7., tn_update_op.eval())
@@ -4062,7 +4062,7 @@ class FalsePositivesTest(test.TestCase):
tn, tn_update_op = metrics.false_positives(
labels=labels, predictions=predictions, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllClose(0., tn.eval())
self.assertAllClose(14., tn_update_op.eval())
@@ -4092,7 +4092,7 @@ class FalsePositivesAtThresholdsTest(test.TestCase):
fp, fp_update_op = metrics.false_positives_at_thresholds(
predictions=predictions, labels=labels, thresholds=[0.15, 0.5, 0.85])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0, 0, 0), fp.eval())
self.assertAllEqual((7, 4, 2), fp_update_op.eval())
@@ -4113,7 +4113,7 @@ class FalsePositivesAtThresholdsTest(test.TestCase):
(19.0, 23.0, 29.0, 31.0)),
thresholds=[0.15, 0.5, 0.85])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0.0, 0.0, 0.0), fp.eval())
self.assertAllEqual((125.0, 42.0, 12.0), fp_update_op.eval())
@@ -4144,7 +4144,7 @@ class TrueNegativesTest(test.TestCase):
tn, tn_update_op = metrics.true_negatives(
labels=labels, predictions=predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllClose(0., tn.eval())
self.assertAllClose(3., tn_update_op.eval())
@@ -4163,7 +4163,7 @@ class TrueNegativesTest(test.TestCase):
tn, tn_update_op = metrics.true_negatives(
labels=labels, predictions=predictions, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllClose(0., tn.eval())
self.assertAllClose(4., tn_update_op.eval())
@@ -4193,7 +4193,7 @@ class TrueNegativesAtThresholdsTest(test.TestCase):
tn, tn_update_op = metrics.true_negatives_at_thresholds(
predictions=predictions, labels=labels, thresholds=[0.15, 0.5, 0.85])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0, 0, 0), tn.eval())
self.assertAllEqual((2, 5, 7), tn_update_op.eval())
@@ -4212,7 +4212,7 @@ class TrueNegativesAtThresholdsTest(test.TestCase):
weights=((0.0, 2.0, 3.0, 5.0),),
thresholds=[0.15, 0.5, 0.85])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0.0, 0.0, 0.0), tn.eval())
self.assertAllEqual((5.0, 15.0, 23.0), tn_update_op.eval())
@@ -4243,7 +4243,7 @@ class TruePositivesTest(test.TestCase):
tn, tn_update_op = metrics.true_positives(
labels=labels, predictions=predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllClose(0., tn.eval())
self.assertAllClose(7., tn_update_op.eval())
@@ -4262,7 +4262,7 @@ class TruePositivesTest(test.TestCase):
tn, tn_update_op = metrics.true_positives(
labels=labels, predictions=predictions, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllClose(0., tn.eval())
self.assertAllClose(12., tn_update_op.eval())
@@ -4292,7 +4292,7 @@ class TruePositivesAtThresholdsTest(test.TestCase):
tp, tp_update_op = metrics.true_positives_at_thresholds(
predictions=predictions, labels=labels, thresholds=[0.15, 0.5, 0.85])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0, 0, 0), tp.eval())
self.assertAllEqual((3, 1, 0), tp_update_op.eval())
@@ -4309,7 +4309,7 @@ class TruePositivesAtThresholdsTest(test.TestCase):
predictions=predictions, labels=labels, weights=37.0,
thresholds=[0.15, 0.5, 0.85])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0.0, 0.0, 0.0), tp.eval())
self.assertAllEqual((111.0, 37.0, 0.0), tp_update_op.eval())
diff --git a/tensorflow/python/kernel_tests/pad_op_test.py b/tensorflow/python/kernel_tests/pad_op_test.py
index 944de217a1..e415d7879e 100644
--- a/tensorflow/python/kernel_tests/pad_op_test.py
+++ b/tensorflow/python/kernel_tests/pad_op_test.py
@@ -188,7 +188,7 @@ class PadOpTest(test.TestCase):
mode="SYMMETRIC").eval()
def testInvalid(self):
- with self.test_session():
+ with self.cached_session():
x = [[1, 2, 3], [4, 5, 6]]
with self.assertRaisesRegexp(ValueError, "Unknown padding mode"):
array_ops.pad(x, [[1, 0], [2, 1]], mode="weird").eval()
diff --git a/tensorflow/python/kernel_tests/padding_fifo_queue_test.py b/tensorflow/python/kernel_tests/padding_fifo_queue_test.py
index d8c3f9823c..95f3dcceea 100644
--- a/tensorflow/python/kernel_tests/padding_fifo_queue_test.py
+++ b/tensorflow/python/kernel_tests/padding_fifo_queue_test.py
@@ -95,13 +95,13 @@ class PaddingFIFOQueueTest(test.TestCase):
""", q.queue_ref.op.node_def)
def testEnqueue(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
enqueue_op = q.enqueue((10.0,))
enqueue_op.run()
def testEnqueueWithShape(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(
10, dtypes_lib.float32, shapes=((3, 2),))
enqueue_correct_op = q.enqueue(([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],))
@@ -111,14 +111,14 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual(1, q.size().eval())
def testEnqueueManyWithShape(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(
10, [dtypes_lib.int32, dtypes_lib.int32], shapes=[(), (2,)])
q.enqueue_many([[1, 2, 3, 4], [[1, 1], [2, 2], [3, 3], [4, 4]]]).run()
self.assertEqual(4, q.size().eval())
def testParallelEnqueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -144,7 +144,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertItemsEqual(elems, results)
def testParallelDequeue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -168,7 +168,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertItemsEqual(elems, results)
def testDequeue(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -182,7 +182,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual([elems[i]], vals)
def testEnqueueAndBlockingDequeue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(3, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -212,7 +212,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual([elem], result)
def testMultiEnqueueAndDequeue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(10,
(dtypes_lib.int32, dtypes_lib.float32),
((), ()))
@@ -230,12 +230,12 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual([y], y_val)
def testQueueSizeEmpty(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
self.assertEqual([0], q.size().eval())
def testQueueSizeAfterEnqueueAndDequeue(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
enqueue_op = q.enqueue((10.0,))
dequeued_t = q.dequeue()
@@ -248,7 +248,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual(0, size.eval())
def testEnqueueMany(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -261,7 +261,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual([elems[i % 4]], vals)
def testEmptyEnqueueMany(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, (
(None, None),))
empty_t = constant_op.constant(
@@ -274,7 +274,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual([0], size_t.eval())
def testEmptyDequeueMany(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, shapes=((),))
enqueue_op = q.enqueue((10.0,))
dequeued_t = q.dequeue_many(0)
@@ -284,7 +284,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual([], dequeued_t.eval().tolist())
def testEmptyDequeueManyWithDynamicShape(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(
10, dtypes_lib.float32, shapes=((None,),))
enqueue_op = q.enqueue(([10.0],))
@@ -295,7 +295,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual([], dequeued_t.eval().tolist())
def testEmptyDequeueUpToWithDynamicShape(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(
10, dtypes_lib.float32, shapes=((None,),))
enqueue_op = q.enqueue(([10.0],))
@@ -306,7 +306,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual([], dequeued_t.eval().tolist())
def testConstructPaddingFIFOQueueWithNoShape(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
ValueError,
r"When providing partial shapes, a list of shapes must be provided."):
@@ -314,7 +314,7 @@ class PaddingFIFOQueueTest(test.TestCase):
None).queue_ref.eval()
def testMultiEnqueueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(10,
(dtypes_lib.float32, dtypes_lib.int32),
((), (2,)))
@@ -332,7 +332,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertAllEqual(int_elems[i % 4], int_val)
def testMultiEnqueueManyWithPartiallyKnownShapes(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(
10, (dtypes_lib.float32, dtypes_lib.int32), shapes=((), (None,)))
float_elems = [10.0, 20.0, 30.0, 40.0]
@@ -349,7 +349,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertAllEqual(int_elems[i % 4], int_val)
def testDequeueMany(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
enqueue_op = q.enqueue_many((elems,))
@@ -361,7 +361,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertAllEqual(elems[4:8], dequeued_t.eval())
def testDequeueUpToNoBlocking(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
enqueue_op = q.enqueue_many((elems,))
@@ -373,7 +373,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertAllEqual(elems[4:8], dequeued_t.eval())
def testMultiDequeueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(
10, (dtypes_lib.float32, dtypes_lib.int32), shapes=((), (2,)))
float_elems = [
@@ -404,7 +404,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual(int_val.shape, dequeued_single_t[1].get_shape())
def testMultiDequeueManyWithPartiallyKnownShapes(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(
10, (dtypes_lib.float32, dtypes_lib.int32), shapes=((), (None,)))
float_elems = [
@@ -443,7 +443,7 @@ class PaddingFIFOQueueTest(test.TestCase):
dequeued_single_t[1].get_shape()))
def testMultiDequeueManyWithPartiallyKnownShapesAndVariableSizeInput(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(
10, (dtypes_lib.string, dtypes_lib.int32),
shapes=((None,), (1, None)))
@@ -484,7 +484,7 @@ class PaddingFIFOQueueTest(test.TestCase):
dequeued_single_t[1].get_shape()))
def testMultiDequeueUpToPartiallyKnownShapesAndVariableInputNoBlocking(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(
10, (dtypes_lib.string, dtypes_lib.int32),
shapes=((None,), (1, None)))
@@ -525,7 +525,7 @@ class PaddingFIFOQueueTest(test.TestCase):
dequeued_single_t[1].get_shape()))
def testHighDimension(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.int32, ((4, 4, 4, 4),))
elems = np.array([[[[[x] * 4] * 4] * 4] * 4 for x in range(10)], np.int32)
enqueue_op = q.enqueue_many((elems,))
@@ -535,7 +535,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertAllEqual(dequeued_t.eval(), elems)
def testPartiallyKnownHighDimension(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.int32, (
(4, None, 4, None),))
elems = np.array([[[[[x] * 4] * 4] * 4] * 4 for x in range(10)], np.int32)
@@ -592,7 +592,7 @@ class PaddingFIFOQueueTest(test.TestCase):
array_ops.placeholder(dtypes_lib.int32)))
def testEnqueueWrongPartiallyKnownShapeAtRuntime(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# First dimension of second component is unknown, second
# dimension must be 3.
q = data_flow_ops.PaddingFIFOQueue(10,
@@ -607,7 +607,7 @@ class PaddingFIFOQueueTest(test.TestCase):
feed_dict={elems_bad: np.array([1] * 12).reshape((3, 4))})
def testEnqueueDequeueManyWrongPartiallyKnownShape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# First dimension of second component is unknown, second
# dimension must be 3.
q = data_flow_ops.PaddingFIFOQueue(10,
@@ -625,7 +625,7 @@ class PaddingFIFOQueueTest(test.TestCase):
dequeued_t.eval()
def testParallelEnqueueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(1000, dtypes_lib.float32, shapes=((),))
elems = [10.0 * x for x in range(100)]
enqueue_op = q.enqueue_many((elems,))
@@ -644,7 +644,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertItemsEqual(dequeued_t.eval(), elems * 10)
def testParallelDequeueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(1000, dtypes_lib.float32, shapes=((),))
elems = [10.0 * x for x in range(1000)]
enqueue_op = q.enqueue_many((elems,))
@@ -666,7 +666,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertItemsEqual(elems, dequeued_elems)
def testParallelDequeueUpTo(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(1000, dtypes_lib.float32, shapes=((),))
elems = [10.0 * x for x in range(1000)]
enqueue_op = q.enqueue_many((elems,))
@@ -690,7 +690,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertItemsEqual(elems, dequeued_elems)
def testParallelEnqueueAndDequeue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(50, dtypes_lib.float32, shapes=((),))
initial_elements = [10.0] * 49
q.enqueue_many((initial_elements,)).run()
@@ -723,7 +723,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertTrue(elem in (10.0, 20.0))
def testMixtureOfEnqueueAndEnqueueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.int32, shapes=((),))
enqueue_placeholder = array_ops.placeholder(dtypes_lib.int32, shape=())
enqueue_op = q.enqueue((enqueue_placeholder,))
@@ -759,7 +759,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual(0, q.size().eval())
def testMixtureOfDequeueAndDequeueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.int32, shapes=((),))
enqueue_op = q.enqueue_many((np.arange(250, dtype=np.int32),))
dequeued_t = q.dequeue()
@@ -793,7 +793,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual(0, q.size().eval())
def testBlockingDequeueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -820,7 +820,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertAllEqual(elems, dequeued_elems)
def testBlockingDequeueUpTo(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -847,7 +847,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertAllEqual(elems, dequeued_elems)
def testDequeueManyWithTensorParameter(self):
- with self.test_session():
+ with self.cached_session():
# Define a first queue that contains integer counts.
dequeue_counts = [random.randint(1, 10) for _ in range(100)]
count_q = data_flow_ops.PaddingFIFOQueue(100, dtypes_lib.int32, ((),))
@@ -872,7 +872,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual(elems, dequeued_elems)
def testDequeueFromClosedQueue(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -890,7 +890,7 @@ class PaddingFIFOQueueTest(test.TestCase):
dequeued_t.eval()
def testBlockingDequeueFromClosedQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -916,7 +916,7 @@ class PaddingFIFOQueueTest(test.TestCase):
dequeue_thread.join()
def testDequeueUpToFromClosedQueueReturnsRemainder(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -938,7 +938,7 @@ class PaddingFIFOQueueTest(test.TestCase):
dequeue_thread.join()
def testBlockingDequeueFromClosedEmptyQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
close_op = q.close()
dequeued_t = q.dequeue()
@@ -958,7 +958,7 @@ class PaddingFIFOQueueTest(test.TestCase):
dequeue_thread.join()
def testBlockingDequeueManyFromClosedQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -983,7 +983,7 @@ class PaddingFIFOQueueTest(test.TestCase):
dequeue_thread.join()
def testBlockingDequeueManyButNotAllFromClosedQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -1008,7 +1008,7 @@ class PaddingFIFOQueueTest(test.TestCase):
dequeue_thread.join()
def testEnqueueManyLargerThanCapacityWithConcurrentDequeueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -1045,7 +1045,7 @@ class PaddingFIFOQueueTest(test.TestCase):
close_thread.join()
def testClosedBlockingDequeueManyRestoresPartialBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(4, (dtypes_lib.float32,
dtypes_lib.float32), ((), ()))
elems_a = [1.0, 2.0, 3.0]
@@ -1078,7 +1078,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual(0, q.size().eval())
def testBlockingDequeueManyFromClosedEmptyQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
close_op = q.close()
dequeued_t = q.dequeue_many(4)
@@ -1098,7 +1098,7 @@ class PaddingFIFOQueueTest(test.TestCase):
dequeue_thread.join()
def testBlockingDequeueUpToFromClosedEmptyQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
close_op = q.close()
dequeued_t = q.dequeue_up_to(4)
@@ -1118,7 +1118,7 @@ class PaddingFIFOQueueTest(test.TestCase):
dequeue_thread.join()
def testEnqueueToClosedQueue(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
enqueue_op = q.enqueue((10.0,))
close_op = q.close()
@@ -1131,7 +1131,7 @@ class PaddingFIFOQueueTest(test.TestCase):
enqueue_op.run()
def testEnqueueManyToClosedQueue(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -1145,7 +1145,7 @@ class PaddingFIFOQueueTest(test.TestCase):
enqueue_op.run()
def testBlockingEnqueueToFullQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -1168,7 +1168,7 @@ class PaddingFIFOQueueTest(test.TestCase):
thread.join()
def testBlockingEnqueueManyToFullQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -1195,7 +1195,7 @@ class PaddingFIFOQueueTest(test.TestCase):
thread.join()
def testBlockingEnqueueBeforeClose(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -1232,7 +1232,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual(0, q.size().eval())
def testBlockingEnqueueManyBeforeClose(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0]
enqueue_op = q.enqueue_many((elems,))
@@ -1265,7 +1265,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual(elem, dequeued_t.eval())
def testDoesNotLoseValue(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(1, dtypes_lib.float32, ((),))
enqueue_op = q.enqueue((10.0,))
size_t = q.size()
@@ -1275,7 +1275,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual(size_t.eval(), [1])
def testSharedQueueSameSession(self):
- with self.test_session():
+ with self.cached_session():
q1 = data_flow_ops.PaddingFIFOQueue(
1, dtypes_lib.float32, ((),), shared_name="shared_queue")
q1.enqueue((10.0,)).run()
@@ -1305,7 +1305,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual(q2_size_t.eval(), [0])
def testIncompatibleSharedQueueErrors(self):
- with self.test_session():
+ with self.cached_session():
q_a_1 = data_flow_ops.PaddingFIFOQueue(
10, dtypes_lib.float32, ((),), shared_name="q_a")
q_a_2 = data_flow_ops.PaddingFIFOQueue(
@@ -1356,7 +1356,7 @@ class PaddingFIFOQueueTest(test.TestCase):
q_f_2.queue_ref.op.run()
def testSelectQueue(self):
- with self.test_session():
+ with self.cached_session():
num_queues = 10
qlist = list()
for _ in xrange(num_queues):
@@ -1370,7 +1370,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual(q.dequeue().eval(), 10.0)
def testSelectQueueOutOfRange(self):
- with self.test_session():
+ with self.cached_session():
q1 = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
q2 = data_flow_ops.PaddingFIFOQueue(15, dtypes_lib.float32, ((),))
enq_q = data_flow_ops.PaddingFIFOQueue.from_list(3, [q1, q2])
@@ -1394,7 +1394,7 @@ class PaddingFIFOQueueTest(test.TestCase):
sess.run(enqueue_many_op)
def testResetOfBlockingOperation(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q_empty = data_flow_ops.PaddingFIFOQueue(5, dtypes_lib.float32, ((),))
dequeue_op = q_empty.dequeue()
dequeue_many_op = q_empty.dequeue_many(1)
@@ -1422,7 +1422,7 @@ class PaddingFIFOQueueTest(test.TestCase):
t.join()
def testBigEnqueueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(5, dtypes_lib.int32, ((),))
elem = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
enq = q.enqueue_many((elem,))
@@ -1467,7 +1467,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertAllEqual(elem, results)
def testBigDequeueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(2, dtypes_lib.int32, ((),))
elem = np.arange(4, dtype=np.int32)
enq_list = [q.enqueue((e,)) for e in elem]
@@ -1493,7 +1493,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertAllEqual(elem, results)
def testDtypes(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dtypes = [
dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int32,
dtypes_lib.uint8, dtypes_lib.int16, dtypes_lib.int8, dtypes_lib.int64,
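# --- Illustrative sketch, not part of this patch: the conversion applied
# across the queue tests above, written against the TF 1.x graph-mode
# tf.test.TestCase API. cached_session() returns a session that is cached and
# reused within a test, replacing the deprecated test_session() calls.
import tensorflow as tf


class PaddingQueueSketchTest(tf.test.TestCase):

  def testEnqueueDequeueMany(self):
    with self.cached_session() as sess:
      q = tf.PaddingFIFOQueue(10, tf.float32, ((),))
      enqueue_op = q.enqueue_many(([10.0, 20.0, 30.0, 40.0],))
      dequeued_t = q.dequeue_many(4)
      enqueue_op.run()  # runs in the cached default session
      self.assertAllEqual([10.0, 20.0, 30.0, 40.0], sess.run(dequeued_t))


if __name__ == "__main__":
  tf.test.main()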
diff --git a/tensorflow/python/kernel_tests/parse_single_example_op_test.py b/tensorflow/python/kernel_tests/parse_single_example_op_test.py
index bf4c89b368..a84895a287 100644
--- a/tensorflow/python/kernel_tests/parse_single_example_op_test.py
+++ b/tensorflow/python/kernel_tests/parse_single_example_op_test.py
@@ -89,7 +89,7 @@ def _compare_output_to_expected(tester, dict_tensors, expected_tensors,
class ParseExampleTest(test.TestCase):
def _test(self, kwargs, expected_values=None, expected_err=None):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if expected_err:
with self.assertRaisesWithPredicateMatch(expected_err[0],
expected_err[1]):
@@ -844,7 +844,7 @@ class ParseExampleTest(test.TestCase):
class ParseSingleExampleTest(test.TestCase):
def _test(self, kwargs, expected_values=None, expected_err=None):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if expected_err:
with self.assertRaisesWithPredicateMatch(expected_err[0],
expected_err[1]):
diff --git a/tensorflow/python/kernel_tests/parsing_ops_test.py b/tensorflow/python/kernel_tests/parsing_ops_test.py
index 7dff4501cc..71d8b60d3c 100644
--- a/tensorflow/python/kernel_tests/parsing_ops_test.py
+++ b/tensorflow/python/kernel_tests/parsing_ops_test.py
@@ -89,7 +89,7 @@ def _compare_output_to_expected(tester, dict_tensors, expected_tensors,
class ParseExampleTest(test.TestCase):
def _test(self, kwargs, expected_values=None, expected_err=None):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if expected_err:
with self.assertRaisesWithPredicateMatch(expected_err[0],
expected_err[1]):
@@ -937,7 +937,7 @@ class ParseExampleTest(test.TestCase):
class ParseSingleExampleTest(test.TestCase):
def _test(self, kwargs, expected_values=None, expected_err=None):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if expected_err:
with self.assertRaisesWithPredicateMatch(expected_err[0],
expected_err[1]):
@@ -1054,7 +1054,7 @@ class ParseSequenceExampleTest(test.TestCase):
expected_feat_list_values = expected_feat_list_values or {}
expected_length_values = expected_length_values or {}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if expected_err:
with self.assertRaisesWithPredicateMatch(expected_err[0],
expected_err[1]):
@@ -1606,7 +1606,7 @@ class ParseSequenceExampleTest(test.TestCase):
class DecodeJSONExampleTest(test.TestCase):
def _testRoundTrip(self, examples):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
examples = np.array(examples, dtype=np.object)
json_tensor = constant_op.constant(
@@ -1696,7 +1696,7 @@ class DecodeJSONExampleTest(test.TestCase):
])
def testInvalidSyntax(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
json_tensor = constant_op.constant(["{]"])
binary_tensor = parsing_ops.decode_json_example(json_tensor)
with self.assertRaisesOpError("Error while parsing JSON"):
@@ -1706,7 +1706,7 @@ class DecodeJSONExampleTest(test.TestCase):
class ParseTensorOpTest(test.TestCase):
def testToFloat32(self):
- with self.test_session():
+ with self.cached_session():
expected = np.random.rand(3, 4, 5).astype(np.float32)
tensor_proto = tensor_util.make_tensor_proto(expected)
@@ -1719,7 +1719,7 @@ class ParseTensorOpTest(test.TestCase):
self.assertAllEqual(expected, result)
def testToUint8(self):
- with self.test_session():
+ with self.cached_session():
expected = np.random.rand(3, 4, 5).astype(np.uint8)
tensor_proto = tensor_util.make_tensor_proto(expected)
@@ -1732,7 +1732,7 @@ class ParseTensorOpTest(test.TestCase):
self.assertAllEqual(expected, result)
def testTypeMismatch(self):
- with self.test_session():
+ with self.cached_session():
expected = np.random.rand(3, 4, 5).astype(np.uint8)
tensor_proto = tensor_util.make_tensor_proto(expected)
@@ -1745,7 +1745,7 @@ class ParseTensorOpTest(test.TestCase):
tensor.eval(feed_dict={serialized: tensor_proto.SerializeToString()})
def testInvalidInput(self):
- with self.test_session():
+ with self.cached_session():
serialized = array_ops.placeholder(dtypes.string)
tensor = parsing_ops.parse_tensor(serialized, dtypes.uint16)
diff --git a/tensorflow/python/kernel_tests/partitioned_variables_test.py b/tensorflow/python/kernel_tests/partitioned_variables_test.py
index 15d5702252..b34d30f5c0 100644
--- a/tensorflow/python/kernel_tests/partitioned_variables_test.py
+++ b/tensorflow/python/kernel_tests/partitioned_variables_test.py
@@ -39,7 +39,7 @@ from tensorflow.python.training import saver as saver_lib
class PartitionerCreatorsTest(test.TestCase):
def testFixedSizePartitioner(self):
- with self.test_session():
+ with self.cached_session():
partitioner = partitioned_variables.fixed_size_partitioner(5, axis=0)
with variable_scope.variable_scope("root", partitioner=partitioner):
v0 = variable_scope.get_variable(
@@ -50,7 +50,7 @@ class PartitionerCreatorsTest(test.TestCase):
self.assertAllEqual(v0_part, (5, 1))
def testFixedSizePartitionerInt64(self):
- with self.test_session():
+ with self.cached_session():
partitioner = partitioned_variables.fixed_size_partitioner(4, axis=0)
with variable_scope.variable_scope("root", partitioner=partitioner):
v0 = variable_scope.get_variable("v0", dtype=dtypes.int64, shape=[20])
@@ -58,7 +58,7 @@ class PartitionerCreatorsTest(test.TestCase):
self.assertEqual(len(v0_list), 4)
def testResourceFixedSizePartitioner(self):
- with self.test_session():
+ with self.cached_session():
partitioner = partitioned_variables.fixed_size_partitioner(5, axis=0)
with variable_scope.variable_scope(
"root", partitioner=partitioner, use_resource=True):
@@ -88,7 +88,7 @@ class PartitionerCreatorsTest(test.TestCase):
self.assertAllEqual(v0_part, expected_partitions)
def testVariableAxisSizePartitioner(self):
- with self.test_session():
+ with self.cached_session():
# Create a partitioned variable of shape (4, 8, 16, 32) type float32
# Bytes per slice along the given axes:
@@ -210,7 +210,7 @@ class PartitionerCreatorsTest(test.TestCase):
self.assertAllEqual(v0_part, expected_partitions)
def testMinMaxVariablePartitioner(self):
- with self.test_session():
+ with self.cached_session():
# Partitioning a variable of shape=[2048] with a minimum of 2K per slice.
self._testMinMaxVariablePartitioner(
max_partitions=100,
@@ -323,7 +323,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self.assertEquals(expected_specs[i], slices[i]._save_slice_info.spec)
def testVecConstantInit(self):
- with self.test_session():
+ with self.cached_session():
rnd_par = constant_op.constant([1, 2, 3, 4])
vs = partitioned_variables.create_partitioned_variables([4], [4], rnd_par)
variables.global_variables_initializer().run()
@@ -334,7 +334,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self._TestSaveSpec(vs, ["4 0,1", "4 1,1", "4 2,1", "4 3,1"])
def testConstantInit(self):
- with self.test_session():
+ with self.cached_session():
rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
vs = partitioned_variables.create_partitioned_variables([2, 4], [1, 2],
rnd_par)
@@ -346,7 +346,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self._TestSaveSpec(vs, ["2 4 0,2:0,2", "2 4 0,2:2,2"])
def _testNameHelper(self, use_resource=False):
- with self.test_session():
+ with self.cached_session():
rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
with variable_scope.variable_scope("hi", use_resource=use_resource):
vs1 = partitioned_variables.create_partitioned_variables([2, 4], [1, 2],
@@ -363,7 +363,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self.assertEqual(var2_name + "/part_0:0", vs2[0].name)
self.assertEqual(var2_name + "/part_1:0", vs2[1].name)
# Test same variable.
- with self.test_session():
+ with self.cached_session():
rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
with variable_scope.variable_scope(
"hola", use_resource=use_resource) as vs:
@@ -383,7 +383,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self.assertEqual(var2_name + "/part_0:0", vs2[0].name)
self.assertEqual(var2_name + "/part_1:0", vs2[1].name)
# Test name_scope
- with self.test_session():
+ with self.cached_session():
rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
with ops.name_scope("ola"):
vs1 = partitioned_variables.create_partitioned_variables([2, 4], [1, 2],
@@ -408,7 +408,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self._testNameHelper(use_resource=True)
def testRandomInitValue(self):
- with self.test_session():
+ with self.cached_session():
rnd = variables.Variable(random_ops.random_uniform([200, 40]))
vs = partitioned_variables.create_partitioned_variables(
rnd.get_shape(), [1, 10], rnd.initialized_value())
@@ -425,7 +425,7 @@ class PartitionedVariablesTestCase(test.TestCase):
])
def testRandomInitUnevenPartitions(self):
- with self.test_session():
+ with self.cached_session():
rnd = variables.Variable(
random_ops.random_uniform([20, 43], dtype=dtypes.float64))
var_lists = [
@@ -463,7 +463,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self._TestSaveSpec(vs, save_specs[i])
def testDegenerate(self):
- with self.test_session():
+ with self.cached_session():
rnd = variables.Variable(random_ops.random_uniform([10, 43]))
vs = partitioned_variables.create_partitioned_variables(
rnd.get_shape(), [1, 1], rnd.initialized_value())
@@ -474,7 +474,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self._TestSaveSpec(vs, ["10 43 0,10:0,43"])
def testSliceSizeOne(self):
- with self.test_session():
+ with self.cached_session():
rnd = variables.Variable(random_ops.random_uniform([10, 43]))
vs = partitioned_variables.create_partitioned_variables(
rnd.get_shape(), [10, 1], rnd.initialized_value())
@@ -492,7 +492,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self.assertAllClose([0., 1., 2., 3.], _IotaInitializer([4]))
self.assertAllClose([[0., 1.], [0., 10.], [0., 100.], [0., 1000.]],
_IotaInitializer([4, 2]))
- with self.test_session():
+ with self.cached_session():
vs = partitioned_variables.create_partitioned_variables([13, 5], [3, 1],
_IotaInitializer)
variables.global_variables_initializer().run()
@@ -506,7 +506,7 @@ class PartitionedVariablesTestCase(test.TestCase):
def testRandomInitializer(self):
# Sanity check that the slices uses a different seed when using a random
# initializer function.
- with self.test_session():
+ with self.cached_session():
var0, var1 = partitioned_variables.create_partitioned_variables(
[20, 12], [1, 2], init_ops.random_uniform_initializer())
variables.global_variables_initializer().run()
@@ -514,7 +514,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self.assertTrue(np.linalg.norm(val0 - val1) > 1e-6)
# Negative test that proves that slices have the same values if
# the random initializer uses a seed.
- with self.test_session():
+ with self.cached_session():
var0, var1 = partitioned_variables.create_partitioned_variables(
[20, 12], [1, 2], init_ops.random_uniform_initializer(seed=201))
variables.global_variables_initializer().run()
@@ -522,7 +522,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self.assertAllClose(val0, val1)
def testSomeErrors(self):
- with self.test_session():
+ with self.cached_session():
rnd = variables.Variable(random_ops.random_uniform([10, 43]))
with self.assertRaises(ValueError):
partitioned_variables.create_partitioned_variables(
@@ -547,7 +547,7 @@ class PartitionedVariablesTestCase(test.TestCase):
[10, 43], [1, 50], rnd.initialized_value())
def testControlDepsNone(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
c = constant_op.constant(1.0)
with ops.control_dependencies([c]):
# d get the control dependency.
@@ -573,7 +573,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self.assertEqual([], op.control_inputs)
def testConcat(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
var_x = variable_scope.get_variable(
"x",
initializer=constant_op.constant([1., 2.]),
diff --git a/tensorflow/python/kernel_tests/priority_queue_test.py b/tensorflow/python/kernel_tests/priority_queue_test.py
index 3fb9c9c468..73a9c81638 100644
--- a/tensorflow/python/kernel_tests/priority_queue_test.py
+++ b/tensorflow/python/kernel_tests/priority_queue_test.py
@@ -36,7 +36,7 @@ from tensorflow.python.platform import test
class PriorityQueueTest(test.TestCase):
def testRoundTripInsertReadOnceSorts(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PriorityQueue(2000, (dtypes.string, dtypes.string), (
(), ()))
elem = np.random.randint(-5, 5, size=100).astype(np.int64)
@@ -67,7 +67,7 @@ class PriorityQueueTest(test.TestCase):
self.assertEqual(missed, set())
def testRoundTripInsertMultiThreadedReadOnceSorts(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PriorityQueue(2000, (dtypes.string, dtypes.string), (
(), ()))
elem = np.random.randint(-5, 5, size=100).astype(np.int64)
@@ -113,7 +113,7 @@ class PriorityQueueTest(test.TestCase):
self.assertEqual(missed, set())
def testRoundTripFillsCapacityMultiThreadedEnqueueAndDequeue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PriorityQueue(10, (dtypes.int64), (()))
num_threads = 40
@@ -163,7 +163,7 @@ class PriorityQueueTest(test.TestCase):
self.assertAllEqual(sorted(dequeued), sorted(all_enqueued_values))
def testRoundTripInsertManyMultiThreadedReadManyMultithreadedSorts(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PriorityQueue(2000, (dtypes.int64), (()))
num_threads = 40
@@ -219,7 +219,7 @@ class PriorityQueueTest(test.TestCase):
self.assertAllEqual(set(dequeued), set(all_enqueued_values))
def testRoundTripInsertManyMultiThreadedReadOnceSorts(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PriorityQueue(2000, (dtypes.string, dtypes.string), (
(), ()))
elem = np.random.randint(-5, 5, size=100).astype(np.int64)
@@ -268,7 +268,7 @@ class PriorityQueueTest(test.TestCase):
self.assertEqual(missed, set())
def testRoundTripInsertOnceReadOnceSorts(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PriorityQueue(2000, (dtypes.string, dtypes.string), (
(), ()))
elem = np.random.randint(-100, 100, size=1000).astype(np.int64)
@@ -289,7 +289,7 @@ class PriorityQueueTest(test.TestCase):
self.assertTrue((dv0, dv1) in allowed[e])
def testRoundTripInsertOnceReadManySorts(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PriorityQueue(2000, (dtypes.int64), (()))
elem = np.random.randint(-100, 100, size=1000).astype(np.int64)
q.enqueue_many((elem, elem)).run()
@@ -297,7 +297,7 @@ class PriorityQueueTest(test.TestCase):
self.assertAllEqual(deq_values, sorted(elem))
def testRoundTripInsertOnceReadOnceLotsSorts(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PriorityQueue(2000, (dtypes.int64), (()))
elem = np.random.randint(-100, 100, size=1000).astype(np.int64)
q.enqueue_many((elem, elem)).run()
@@ -306,13 +306,13 @@ class PriorityQueueTest(test.TestCase):
self.assertAllEqual(deq_values, sorted(elem))
def testInsertingNonInt64Fails(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PriorityQueue(2000, (dtypes.string), (()))
with self.assertRaises(TypeError):
q.enqueue_many((["a", "b", "c"], ["a", "b", "c"])).run()
def testInsertingNonScalarFails(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
input_priority = array_ops.placeholder(dtypes.int64)
input_other = array_ops.placeholder(dtypes.string)
q = data_flow_ops.PriorityQueue(2000, (dtypes.string,), (()))
diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py
index 8e06e1abfb..8c84b2a49f 100644
--- a/tensorflow/python/kernel_tests/reader_ops_test.py
+++ b/tensorflow/python/kernel_tests/reader_ops_test.py
@@ -146,7 +146,7 @@ class IdentityReaderTest(test.TestCase):
self.assertAllEqual(expected, v)
def testOneEpoch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.IdentityReader("test_reader")
work_completed = reader.num_work_units_completed()
produced = reader.num_records_produced()
@@ -180,7 +180,7 @@ class IdentityReaderTest(test.TestCase):
self.assertAllEqual(0, queued_length.eval())
def testMultipleEpochs(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.IdentityReader("test_reader")
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
enqueue = queue.enqueue_many([["DD", "EE"]])
@@ -201,7 +201,7 @@ class IdentityReaderTest(test.TestCase):
sess.run([key, value])
def testSerializeRestore(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.IdentityReader("test_reader")
produced = reader.num_records_produced()
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
@@ -256,7 +256,7 @@ class IdentityReaderTest(test.TestCase):
reader.restore_state(b"BOGUS" + state[5:]).run()
def testReset(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.IdentityReader("test_reader")
work_completed = reader.num_work_units_completed()
produced = reader.num_records_produced()
@@ -307,7 +307,7 @@ class WholeFileReaderTest(test.TestCase):
self.assertAllEqual(self._content[index], v)
def testOneEpoch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.WholeFileReader("test_reader")
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
queue.enqueue_many([self._filenames]).run()
@@ -323,7 +323,7 @@ class WholeFileReaderTest(test.TestCase):
sess.run([key, value])
def testInfiniteEpochs(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.WholeFileReader("test_reader")
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
enqueue = queue.enqueue_many([self._filenames])
@@ -366,7 +366,7 @@ class TextLineReaderTest(test.TestCase):
return filenames
def _testOneEpoch(self, files):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.TextLineReader(name="test_reader")
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
key, value = reader.read(queue)
@@ -391,7 +391,7 @@ class TextLineReaderTest(test.TestCase):
def testSkipHeaderLines(self):
files = self._CreateFiles()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.TextLineReader(skip_header_lines=1, name="test_reader")
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
key, value = reader.read(queue)
@@ -522,7 +522,7 @@ class FixedLengthRecordReaderTest(TFCompressionTestCase):
# gap_bytes=hop_bytes-record_bytes
def _TestOneEpoch(self, files, num_records, gap_bytes, encoding=None):
hop_bytes = 0 if gap_bytes == 0 else self._record_bytes + gap_bytes
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.FixedLengthRecordReader(
header_bytes=self._header_bytes,
record_bytes=self._record_bytes,
@@ -549,7 +549,7 @@ class FixedLengthRecordReaderTest(TFCompressionTestCase):
files,
num_overlapped_records,
encoding=None):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.FixedLengthRecordReader(
header_bytes=self._header_bytes,
record_bytes=self._record_bytes,
@@ -621,7 +621,7 @@ class TFRecordReaderTest(TFCompressionTestCase):
def testOneEpoch(self):
files = self._CreateFiles()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.TFRecordReader(name="test_reader")
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
key, value = reader.read(queue)
@@ -640,7 +640,7 @@ class TFRecordReaderTest(TFCompressionTestCase):
def testReadUpTo(self):
files = self._CreateFiles()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.TFRecordReader(name="test_reader")
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
batch_size = 3
@@ -670,7 +670,7 @@ class TFRecordReaderTest(TFCompressionTestCase):
options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)
files = self._CreateFiles(options)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.TFRecordReader(name="test_reader", options=options)
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
key, value = reader.read(queue)
@@ -687,7 +687,7 @@ class TFRecordReaderTest(TFCompressionTestCase):
options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP)
files = self._CreateFiles(options)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.TFRecordReader(name="test_reader", options=options)
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
key, value = reader.read(queue)
@@ -752,7 +752,7 @@ class LMDBReaderTest(test.TestCase):
shutil.copy(path, self.db_path)
def testReadFromFile(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.LMDBReader(name="test_read_from_file")
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
key, value = reader.read(queue)
@@ -770,7 +770,7 @@ class LMDBReaderTest(test.TestCase):
k, v = sess.run([key, value])
def testReadFromSameFile(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader1 = io_ops.LMDBReader(name="test_read_from_same_file1")
reader2 = io_ops.LMDBReader(name="test_read_from_same_file2")
filename_queue = input_lib.string_input_producer(
@@ -789,7 +789,7 @@ class LMDBReaderTest(test.TestCase):
coord.join(threads)
def testReadFromFolder(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.LMDBReader(name="test_read_from_folder")
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
key, value = reader.read(queue)
@@ -807,7 +807,7 @@ class LMDBReaderTest(test.TestCase):
k, v = sess.run([key, value])
def testReadFromFileRepeatedly(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.LMDBReader(name="test_read_from_file_repeated")
filename_queue = input_lib.string_input_producer(
[self.db_path], num_epochs=None)
diff --git a/tensorflow/python/kernel_tests/record_input_test.py b/tensorflow/python/kernel_tests/record_input_test.py
index 068860d5d4..ebb9872f22 100644
--- a/tensorflow/python/kernel_tests/record_input_test.py
+++ b/tensorflow/python/kernel_tests/record_input_test.py
@@ -44,7 +44,7 @@ class RecordInputOpTest(test.TestCase):
w.close()
def testRecordInputSimple(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.generateTestData("basic", 1, 1)
yield_op = data_flow_ops.RecordInput(
@@ -57,7 +57,7 @@ class RecordInputOpTest(test.TestCase):
self.assertEqual(sess.run(yield_op), b"0000000000")
def testRecordInputSimpleGzip(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.generateTestData(
"basic",
1,
@@ -76,7 +76,7 @@ class RecordInputOpTest(test.TestCase):
self.assertEqual(sess.run(yield_op), b"0000000000")
def testRecordInputSimpleZlib(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.generateTestData(
"basic",
1,
@@ -98,7 +98,7 @@ class RecordInputOpTest(test.TestCase):
files = 100
records_per_file = 100
batches = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.generateTestData("basic", files, records_per_file)
records = data_flow_ops.RecordInput(
@@ -126,7 +126,7 @@ class RecordInputOpTest(test.TestCase):
def testDoesNotDeadlock(self):
# Iterate multiple times to cause deadlock if there is a chance it can occur
for _ in range(30):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.generateTestData("basic", 1, 1)
records = data_flow_ops.RecordInput(
@@ -141,7 +141,7 @@ class RecordInputOpTest(test.TestCase):
sess.run(yield_op)
def testEmptyGlob(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
record_input = data_flow_ops.RecordInput(file_pattern="foo")
yield_op = record_input.get_yield_op()
sess.run(variables.global_variables_initializer())
@@ -152,7 +152,7 @@ class RecordInputOpTest(test.TestCase):
files = 10
records_per_file = 10
batches = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.generateTestData("basic", files, records_per_file)
records = data_flow_ops.RecordInput(
diff --git a/tensorflow/python/kernel_tests/reduce_join_op_test.py b/tensorflow/python/kernel_tests/reduce_join_op_test.py
index 663561ced7..3bb4986313 100644
--- a/tensorflow/python/kernel_tests/reduce_join_op_test.py
+++ b/tensorflow/python/kernel_tests/reduce_join_op_test.py
@@ -113,7 +113,7 @@ class ReduceJoinTest(UnicodeTestCase):
keep_dims: Whether or not to retain reduced dimensions.
separator: The separator to use for joining.
"""
- with self.test_session():
+ with self.cached_session():
output = string_ops.reduce_join(
inputs=input_array,
axis=axis,
@@ -136,7 +136,7 @@ class ReduceJoinTest(UnicodeTestCase):
axis: The indices to reduce.
separator: The separator to use when joining.
"""
- with self.test_session():
+ with self.cached_session():
output = string_ops.reduce_join(
inputs=input_array, axis=axis, keep_dims=False, separator=separator)
output_keep_dims = string_ops.reduce_join(
@@ -234,7 +234,7 @@ class ReduceJoinTest(UnicodeTestCase):
input_array = [["a"], ["b"]]
truth = ["ab"]
truth_shape = None
- with self.test_session():
+ with self.cached_session():
placeholder = array_ops.placeholder(dtypes.string, name="placeholder")
reduced = string_ops.reduce_join(placeholder, axis=0)
output_array = reduced.eval(feed_dict={placeholder.name: input_array})
@@ -247,7 +247,7 @@ class ReduceJoinTest(UnicodeTestCase):
truth_dim_zero = ["thisplease", "isdo", "anot", "testpanic"]
truth_dim_one = ["thisisatest", "pleasedonotpanic"]
truth_shape = None
- with self.test_session():
+ with self.cached_session():
placeholder = array_ops.placeholder(dtypes.int32, name="placeholder")
reduced = string_ops.reduce_join(input_array, axis=placeholder)
output_array_dim_zero = reduced.eval(feed_dict={placeholder.name: [0]})
@@ -298,7 +298,7 @@ class ReduceJoinTest(UnicodeTestCase):
self._testMultipleReduceJoin(input_array, axis=permutation)
def testInvalidReductionIndices(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError, "Invalid reduction dim"):
string_ops.reduce_join(inputs="", axis=0)
with self.assertRaisesRegexp(ValueError,
@@ -313,7 +313,7 @@ class ReduceJoinTest(UnicodeTestCase):
string_ops.reduce_join(inputs=[[""]], axis=[0, 2])
def testZeroDims(self):
- with self.test_session():
+ with self.cached_session():
inputs = np.zeros([0, 1], dtype=str)
# Reduction that drops the dim of size 0.
@@ -326,7 +326,7 @@ class ReduceJoinTest(UnicodeTestCase):
self.assertAllEqual([0], output_shape)
def testInvalidArgsUnknownShape(self):
- with self.test_session():
+ with self.cached_session():
placeholder = array_ops.placeholder(dtypes.string, name="placeholder")
index_too_high = string_ops.reduce_join(placeholder, axis=1)
duplicate_index = string_ops.reduce_join(placeholder, axis=[-1, 1])
@@ -336,7 +336,7 @@ class ReduceJoinTest(UnicodeTestCase):
duplicate_index.eval(feed_dict={placeholder.name: [[""]]})
def testInvalidArgsUnknownIndices(self):
- with self.test_session():
+ with self.cached_session():
placeholder = array_ops.placeholder(dtypes.int32, name="placeholder")
reduced = string_ops.reduce_join(["test", "test2"], axis=placeholder)
diff --git a/tensorflow/python/kernel_tests/reduction_ops_test.py b/tensorflow/python/kernel_tests/reduction_ops_test.py
index ea78b58d88..496a452a03 100644
--- a/tensorflow/python/kernel_tests/reduction_ops_test.py
+++ b/tensorflow/python/kernel_tests/reduction_ops_test.py
@@ -61,7 +61,7 @@ class ReducedShapeTest(test.TestCase):
self.assertAllEqual(output.eval(), result)
def testSimple(self):
- with self.test_session():
+ with self.cached_session():
self._check([3], [], [3])
self._check([3], [0], [1])
self._check([5, 3], [], [5, 3])
@@ -71,7 +71,7 @@ class ReducedShapeTest(test.TestCase):
def testZeros(self):
"""Check that reduced_shape does the right thing with zero dimensions."""
- with self.test_session():
+ with self.cached_session():
self._check([0], [], [0])
self._check([0], [0], [1])
self._check([0, 3], [], [0, 3])
@@ -84,7 +84,7 @@ class ReducedShapeTest(test.TestCase):
self._check([3, 0], [0, 1], [1, 1])
def testNegAxes(self):
- with self.test_session():
+ with self.cached_session():
self._check([10, 10, 10], [-1], [10, 10, 1])
self._check([10, 10, 10], [-1, 2], [10, 10, 1])
self._check([10, 10, 10], [-1, -1], [10, 10, 1])
@@ -95,7 +95,7 @@ class ReducedShapeTest(test.TestCase):
class ReductionUnknownShape(test.TestCase):
def testBasic(self):
- with self.test_session():
+ with self.cached_session():
for dtype, reductions in [(dtypes.float32,
(math_ops.reduce_sum, math_ops.reduce_mean,
math_ops.reduce_prod, math_ops.reduce_max,
@@ -617,7 +617,7 @@ class MinReductionTest(test.TestCase):
def testGradient(self):
s = [2, 3, 4, 2]
x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
- with self.test_session():
+ with self.cached_session():
t = ops.convert_to_tensor(x)
su = math_ops.reduce_min(t, [1, 2])
jacob_t, jacob_n = gradient_checker.compute_gradient(
@@ -627,7 +627,7 @@ class MinReductionTest(test.TestCase):
def testGradient2(self):
s = [2, 3, 4, 2]
x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
- with self.test_session():
+ with self.cached_session():
t = ops.convert_to_tensor(x)
su = math_ops.reduce_min(t, [1])
jacob_t, jacob_n = gradient_checker.compute_gradient(
@@ -637,7 +637,7 @@ class MinReductionTest(test.TestCase):
def testGradient3(self):
s = [2, 3, 4, 2]
x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
- with self.test_session():
+ with self.cached_session():
t = ops.convert_to_tensor(x)
su = math_ops.reduce_min(t, [2])
jacob_t, jacob_n = gradient_checker.compute_gradient(
@@ -647,7 +647,7 @@ class MinReductionTest(test.TestCase):
def testGradient4(self):
s = [2, 3, 4, 2]
x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
- with self.test_session():
+ with self.cached_session():
t = ops.convert_to_tensor(x)
su = math_ops.reduce_min(t)
jacob_t, jacob_n = gradient_checker.compute_gradient(
@@ -655,7 +655,7 @@ class MinReductionTest(test.TestCase):
self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)
def testEmptyGradients(self):
- with self.test_session():
+ with self.cached_session():
x = array_ops.zeros([0, 3])
y = math_ops.reduce_min(x, [1])
error = gradient_checker.compute_gradient_error(x, [0, 3], y, [0])
@@ -744,7 +744,7 @@ class MaxReductionTest(test.TestCase):
def testGradient(self):
s = [2, 3, 4, 2]
x = np.arange(-49.0, -1.0).reshape(s).astype(np.float64)
- with self.test_session():
+ with self.cached_session():
t = ops.convert_to_tensor(x)
su = math_ops.reduce_max(t, [1, 2])
jacob_t, jacob_n = gradient_checker.compute_gradient(
@@ -754,7 +754,7 @@ class MaxReductionTest(test.TestCase):
def testGradient2(self):
s = [2, 3, 4, 2]
x = np.arange(-49.0, -1.0).reshape(s).astype(np.float64)
- with self.test_session():
+ with self.cached_session():
t = ops.convert_to_tensor(x)
su = math_ops.reduce_max(t, [1])
jacob_t, jacob_n = gradient_checker.compute_gradient(
@@ -764,7 +764,7 @@ class MaxReductionTest(test.TestCase):
def testGradient3(self):
s = [2, 3, 4, 2]
x = np.arange(-49.0, -1.0).reshape(s).astype(np.float64)
- with self.test_session():
+ with self.cached_session():
t = ops.convert_to_tensor(x)
su = math_ops.reduce_max(t, [2])
jacob_t, jacob_n = gradient_checker.compute_gradient(
@@ -774,7 +774,7 @@ class MaxReductionTest(test.TestCase):
def testGradient4(self):
s = [2, 3, 4, 2]
x = np.arange(-49.0, -1.0).reshape(s).astype(np.float64)
- with self.test_session():
+ with self.cached_session():
t = ops.convert_to_tensor(x)
su = math_ops.reduce_max(t)
jacob_t, jacob_n = gradient_checker.compute_gradient(
@@ -782,7 +782,7 @@ class MaxReductionTest(test.TestCase):
self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)
def testEmptyGradients(self):
- with self.test_session():
+ with self.cached_session():
x = array_ops.zeros([0, 3])
y = math_ops.reduce_max(x, [1])
error = gradient_checker.compute_gradient_error(x, [0, 3], y, [0])
@@ -960,7 +960,7 @@ class CountNonzeroReductionTest(test.TestCase):
def testStringReduce(self):
# Test case for GitHub issue 18712
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v = math_ops.count_nonzero(constant_op.constant(["test"]))
self.assertAllClose(sess.run(v), 1)
diff --git a/tensorflow/python/kernel_tests/regex_full_match_op_test.py b/tensorflow/python/kernel_tests/regex_full_match_op_test.py
index 7bd8c3ca27..e81f562a2a 100644
--- a/tensorflow/python/kernel_tests/regex_full_match_op_test.py
+++ b/tensorflow/python/kernel_tests/regex_full_match_op_test.py
@@ -35,7 +35,7 @@ class RegexFullMatchOpVariantsTest(test.TestCase, parameterized.TestCase):
def testRegexFullMatch(self, op):
values = ["abaaba", "abcdabcde"]
- with self.test_session():
+ with self.cached_session():
input_tensor = constant_op.constant(values, dtypes.string)
matched = op(input_tensor, "a.*a").eval()
self.assertAllEqual([True, False], matched)
@@ -49,14 +49,14 @@ class RegexFullMatchOpVariantsTest(test.TestCase, parameterized.TestCase):
def testEmptyMatch(self, op):
values = ["abc", "1"]
- with self.test_session():
+ with self.cached_session():
input_tensor = constant_op.constant(values, dtypes.string)
matched = op(input_tensor, "").eval()
self.assertAllEqual([False, False], matched)
def testInvalidPattern(self, op):
values = ["abc", "1"]
- with self.test_session():
+ with self.cached_session():
input_tensor = constant_op.constant(values, dtypes.string)
invalid_pattern = "A["
matched = op(input_tensor, invalid_pattern)
diff --git a/tensorflow/python/kernel_tests/regex_replace_op_test.py b/tensorflow/python/kernel_tests/regex_replace_op_test.py
index f0e84b8fca..feac3a8b08 100644
--- a/tensorflow/python/kernel_tests/regex_replace_op_test.py
+++ b/tensorflow/python/kernel_tests/regex_replace_op_test.py
@@ -20,7 +20,6 @@ from __future__ import print_function
from absl.testing import parameterized
-from tensorflow.python.compat import compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import gen_string_ops
@@ -100,22 +99,20 @@ class RegexReplaceTest(test.TestCase, parameterized.TestCase):
(as_tensor, as_string),
(as_tensor, as_tensor))
def testRegexReplaceDelegation(self, pattern_fn, rewrite_fn):
- with compat.forward_compatibility_horizon(2018, 10, 11):
- with self.test_session():
- input_vector = constant_op.constant("foo", dtypes.string)
- pattern = pattern_fn("[a-z]")
- replace = rewrite_fn(".")
- op = string_ops.regex_replace(input_vector, pattern, replace)
- self.assertTrue(op.name.startswith("RegexReplace"))
+ with self.test_session():
+ input_vector = constant_op.constant("foo", dtypes.string)
+ pattern = pattern_fn("[a-z]")
+ replace = rewrite_fn(".")
+ op = string_ops.regex_replace(input_vector, pattern, replace)
+ self.assertTrue(op.name.startswith("RegexReplace"))
def testStaticRegexReplaceDelegation(self):
- with compat.forward_compatibility_horizon(2018, 10, 11):
- with self.test_session():
- input_vector = constant_op.constant("foo", dtypes.string)
- pattern = "[a-z]"
- replace = "."
- op = string_ops.regex_replace(input_vector, pattern, replace)
- self.assertTrue(op.name.startswith("StaticRegexReplace"))
+ with self.test_session():
+ input_vector = constant_op.constant("foo", dtypes.string)
+ pattern = "[a-z]"
+ replace = "."
+ op = string_ops.regex_replace(input_vector, pattern, replace)
+ self.assertTrue(op.name.startswith("StaticRegexReplace"))
if __name__ == "__main__":
test.main()
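# --- Illustrative sketch, not part of this patch: how the removed
# compat.forward_compatibility_horizon guard is normally used. The context
# manager forces a later forward-compatibility horizon so code gated on
# compat.forward_compatible(...) takes its new path during tests; the hunk
# above drops the guard now that the 2018-10-11 horizon has passed.
from tensorflow.python.compat import compat


def gated_behavior():
  # True once the compatibility window for the given date has elapsed, or
  # whenever a later horizon is forced via forward_compatibility_horizon().
  return "new" if compat.forward_compatible(2018, 10, 11) else "old"


with compat.forward_compatibility_horizon(2018, 10, 12):
  print(gated_behavior())  # "new": the forced horizon is past the gate date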
diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py
index 657d92fa23..a45a325b47 100644
--- a/tensorflow/python/kernel_tests/relu_op_test.py
+++ b/tensorflow/python/kernel_tests/relu_op_test.py
@@ -104,7 +104,7 @@ class ReluTest(test.TestCase):
# The gradient test for ReLU is a bit tricky as the derivative is not well
# defined at around zero and we want to avoid that in terms of input values.
def testGradientFloat32(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
shape=[2, 5],
@@ -149,7 +149,7 @@ class ReluTest(test.TestCase):
self.assertAllClose(dx_f32_v, dx_f16_v, atol=3e-4)
def testGradientFloat64(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
shape=[2, 5],
@@ -166,7 +166,7 @@ class ReluTest(test.TestCase):
self.assertLess(err, 1e-10)
def testGradGradFloat32(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
shape=[2, 5],
@@ -183,7 +183,7 @@ class ReluTest(test.TestCase):
self.assertLess(err, 1e-4)
def testGradGradFloat64(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
shape=[2, 5],
@@ -201,7 +201,7 @@ class ReluTest(test.TestCase):
self.assertLess(err, 1e-10)
def testGradientScalar(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = variables.Variable(100.)
y = nn_ops.relu(x)
loss = y**2
@@ -249,7 +249,7 @@ class Relu6Test(test.TestCase):
# not well defined at around zero and six and we want to avoid that
# in terms of input values.
def testGradientFloat32(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 6.1, 6.3, 6.5, 6.7, 6.9],
shape=[2, 5],
@@ -265,7 +265,7 @@ class Relu6Test(test.TestCase):
self.assertLess(err, 1e-4)
def testGradientFloat64(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 6.1, 6.3, 6.5, 6.7, 6.9],
shape=[2, 5],
@@ -313,7 +313,7 @@ class EluTest(test.TestCase):
use_gpu=True)
def testGradientFloat32(self):
- with self.test_session():
+ with self.cached_session():
x_val = [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]]
x = constant_op.constant(x_val, name="x")
y = nn_ops.elu(x, name="elu")
@@ -324,7 +324,7 @@ class EluTest(test.TestCase):
self.assertLess(err, 1e-4)
def testGradientFloat64(self):
- with self.test_session():
+ with self.cached_session():
x_val = [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]]
x = constant_op.constant(x_val, dtype=dtypes.float64, name="x")
y = nn_ops.elu(x, name="elu")
@@ -335,7 +335,7 @@ class EluTest(test.TestCase):
self.assertLess(err, 1e-6)
def testGradGrad(self):
- with self.test_session():
+ with self.cached_session():
x = array_ops.placeholder(dtype=dtypes.float32)
elu = nn_ops.elu(x)
g, = gradients_impl.gradients(elu, x)
@@ -346,7 +346,7 @@ class EluTest(test.TestCase):
self.assertLess(err, 1e-4)
def testGradGradFloat32(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
shape=[2, 5],
@@ -363,7 +363,7 @@ class EluTest(test.TestCase):
self.assertLess(err, 1e-4)
def testGradGradFloat64(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
shape=[2, 5],
@@ -415,7 +415,7 @@ class SeluTest(test.TestCase):
use_gpu=True)
def testGradientFloat32(self):
- with self.test_session():
+ with self.cached_session():
x_val = [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]]
x = constant_op.constant(x_val, name="x")
y = nn_ops.selu(x, name="selu")
@@ -426,7 +426,7 @@ class SeluTest(test.TestCase):
self.assertLess(err, 1e-4)
def testGradientFloat64(self):
- with self.test_session():
+ with self.cached_session():
x_val = [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]]
x = constant_op.constant(x_val, dtype=dtypes.float64, name="x")
y = nn_ops.selu(x, name="selu")
@@ -437,7 +437,7 @@ class SeluTest(test.TestCase):
self.assertLess(err, 1e-6)
def testGradGradFloat32(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
shape=[2, 5],
@@ -454,7 +454,7 @@ class SeluTest(test.TestCase):
self.assertLess(err, 1e-4)
def testGradGradFloat64(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
shape=[2, 5],
@@ -503,7 +503,7 @@ class CreluTest(test.TestCase):
use_gpu=True)
def testNumbersWithAxis0(self):
- with self.test_session():
+ with self.cached_session():
crelu = nn_ops.crelu(
np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]), axis=0)
tf_relu = crelu.eval()
@@ -512,7 +512,7 @@ class CreluTest(test.TestCase):
self.assertAllEqual(np_crelu, tf_relu)
def testNumbersWithAxis1(self):
- with self.test_session():
+ with self.cached_session():
crelu = nn_ops.crelu(
np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]), axis=1)
tf_relu = crelu.eval()
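# --- Illustrative sketch, not part of this patch: the gradient-check idiom
# that recurs in the relu/elu hunks above, written against the public TF 1.x
# API (tf.test.compute_gradient_error is assumed here as the public mirror of
# the internal gradient_checker.compute_gradient_error these tests import).
import tensorflow as tf


class ReluGradientSketchTest(tf.test.TestCase):

  def testGradientFloat32(self):
    with self.cached_session():
      x = tf.constant(
          [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
          shape=[2, 5], name="x")
      y = tf.nn.relu(x, name="relu")
      # Compares analytic gradients against finite differences numerically.
      err = tf.test.compute_gradient_error(x, [2, 5], y, [2, 5])
      self.assertLess(err, 1e-4)


if __name__ == "__main__":
  tf.test.main()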
diff --git a/tensorflow/python/kernel_tests/reshape_op_test.py b/tensorflow/python/kernel_tests/reshape_op_test.py
index ef9b439230..ca3ff1d1df 100644
--- a/tensorflow/python/kernel_tests/reshape_op_test.py
+++ b/tensorflow/python/kernel_tests/reshape_op_test.py
@@ -94,7 +94,7 @@ class ReshapeTest(test.TestCase):
def testFloatReshapeGradThreeDimensions(self):
x = np.arange(1., 25.).reshape([2, 3, 4]).astype(np.float32)
s = list(np.shape(x))
- with self.test_session():
+ with self.cached_session():
input_tensor = constant_op.constant(x)
reshape_out = array_ops.reshape(input_tensor, [1, 8, 3])
err = gradient_checker.compute_gradient_error(
diff --git a/tensorflow/python/kernel_tests/reverse_sequence_op_test.py b/tensorflow/python/kernel_tests/reverse_sequence_op_test.py
index 9beb615b2c..8fc71e0c57 100644
--- a/tensorflow/python/kernel_tests/reverse_sequence_op_test.py
+++ b/tensorflow/python/kernel_tests/reverse_sequence_op_test.py
@@ -120,7 +120,7 @@ class ReverseSequenceTest(test.TestCase):
batch_axis = 2
seq_lengths = np.asarray([3, 0, 4], dtype=np.int64)
- with self.test_session():
+ with self.cached_session():
input_t = constant_op.constant(x, shape=x.shape)
seq_lengths_t = constant_op.constant(seq_lengths, shape=seq_lengths.shape)
reverse_sequence_out = array_ops.reverse_sequence(
@@ -171,7 +171,7 @@ class ReverseSequenceTest(test.TestCase):
seq_axis=0,
batch_axis=3)
- with self.test_session():
+ with self.cached_session():
inputs = array_ops.placeholder(dtypes.float32, shape=(32, 2, 3))
seq_lengths = array_ops.placeholder(dtypes.int64, shape=(32,))
output = array_ops.reverse_sequence(
diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
index f2f3023469..86e063cb36 100644
--- a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
+++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
@@ -294,7 +294,7 @@ class StatefulScatterNdTest(test.TestCase):
self.assertAllEqual(scatter_update.get_shape().as_list(), shape)
expected_result = np.zeros([2, 2], dtype=np.int32)
- with self.test_session():
+ with self.cached_session():
ref.initializer.run()
self.assertAllEqual(expected_result, scatter_update.eval())
@@ -409,7 +409,7 @@ class ScatterNdTest(test.TestCase):
expected = np.array([b"", b"one", b"", b"three", b"four",
b"", b"", b"seven"])
scatter = self.scatter_nd(indices, updates, shape=(8,))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result = sess.run(scatter)
self.assertAllEqual(expected, result)
@@ -420,7 +420,7 @@ class ScatterNdTest(test.TestCase):
dtype=dtypes.string)
expected = np.array([b"", b"", b"", b"bb", b"a", b"", b"", b"c"])
scatter = self.scatter_nd(indices, updates, shape=(8,))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result = sess.run(scatter)
self.assertAllEqual(expected, result)
@@ -432,7 +432,7 @@ class ScatterNdTest(test.TestCase):
expected = [np.array([b"", b"", b"", b"bc", b"a", b"", b"", b"d"]),
np.array([b"", b"", b"", b"cb", b"a", b"", b"", b"d"])]
scatter = self.scatter_nd(indices, updates, shape=(8,))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result = sess.run(scatter)
self.assertTrue(np.array_equal(result, expected[0]) or
np.array_equal(result, expected[1]))
@@ -451,7 +451,7 @@ class ScatterNdTest(test.TestCase):
scatter = self.scatter_nd(indices, updates, shape)
self.assertAllEqual(scatter.get_shape().as_list(), shape)
expected_result = np.zeros([2, 2], dtype=np.int32)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_result, scatter.eval())
def testUndefinedIndicesShape(self):
@@ -486,7 +486,7 @@ class ScatterNdTest(test.TestCase):
updates = array_ops.placeholder(dtypes.int32, shape=None)
shape = constant_op.constant([0, 3, 2], dtypes.int32)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError(
"Indices and updates specified for empty output"):
self.scatter_nd(indices, updates, shape).eval(feed_dict={
@@ -500,7 +500,7 @@ class ScatterNdTest(test.TestCase):
shape = constant_op.constant([0], dtypes.int32)
scatter = self.scatter_nd(indices, updates, shape)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(scatter.eval().size, 0)
def testRank3InvalidShape1(self):
@@ -531,7 +531,7 @@ class ScatterNdTest(test.TestCase):
[outputs], [updates, input_], [grad_vals])
expected_updates_grad = np.array([1, 4], dtype=np.float64)
expected_input_grad = np.array([[1, 2], [3, 4]], dtype=np.float64)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_updates_grad, updates_grad.eval())
if self.non_aliasing_add_test:
self.assertAllEqual(expected_input_grad, input_grad.eval())
@@ -548,7 +548,7 @@ class ScatterNdTest(test.TestCase):
[outputs], [updates, input_], [grad_vals])
expected_updates_grad = np.array([[1, 2], [3, 4]], dtype=np.float64)
expected_input_grad = np.array([[3, 4], [1, 2]], dtype=np.float64)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_updates_grad, updates_grad.eval())
if self.non_aliasing_add_test:
self.assertAllEqual(expected_input_grad, input_grad.eval())
@@ -570,7 +570,7 @@ class ScatterNdTest(test.TestCase):
[[[3, 4], [5, 6]], [[1, 2], [7, 8]]], dtype=np.float64)
expected_input_grad = np.array(
[[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.float64)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_updates_grad, updates_grad.eval())
if self.non_aliasing_add_test:
self.assertAllEqual(expected_input_grad, input_grad.eval())
@@ -607,7 +607,7 @@ class ScatterNdTest(test.TestCase):
[[[[1, 2], [3, 4]]]],
[[[[5, 6], [7, 8]]]]
]]], dtype=np.float64)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_updates_grad, updates_grad.eval())
if self.non_aliasing_add_test:
self.assertAllEqual(expected_input_grad, input_grad.eval())
@@ -616,33 +616,33 @@ class ScatterNdTest(test.TestCase):
indices = array_ops.zeros([100000, 1], dtypes.int32)
values = np.random.randn(100000)
shape = [1]
- with self.test_session():
+ with self.cached_session():
val = self.scatter_nd(indices, values, shape).eval()
self.assertAllClose([np.sum(values)], val)
def testSmokeScatterNdBatch2DSliceDim2(self):
- with self.test_session():
+ with self.cached_session():
indices = array_ops.zeros([3, 5, 2], dtype=dtypes.int32)
values = array_ops.zeros([3, 5, 7])
shape = [4, 6, 7]
self.scatter_nd(indices, values, shape).eval()
def testSmokeScatterNdBatch1DSliceDim2(self):
- with self.test_session():
+ with self.cached_session():
indices = array_ops.zeros([0, 2], dtype=dtypes.int32)
values = array_ops.zeros([0, 7])
shape = [4, 6, 7]
self.scatter_nd(indices, values, shape).eval()
def testSmokeScatterNdBatch1DSliceDim3ShapeRank7(self):
- with self.test_session():
+ with self.cached_session():
indices = array_ops.zeros([1, 3], dtype=dtypes.int32)
values = array_ops.zeros([1, 6, 7, 8, 9])
shape = [3, 4, 5, 6, 7, 8, 9]
self.scatter_nd(indices, values, shape).eval()
def testSmokeScatterNdBatch2DSliceDim3ShapeRank7(self):
- with self.test_session():
+ with self.cached_session():
indices = array_ops.zeros([1, 2, 3], dtype=dtypes.int32)
values = array_ops.zeros([1, 2, 6, 7, 8, 9])
shape = [3, 4, 5, 6, 7, 8, 9]
diff --git a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
index a82855dfeb..ce507e4ad7 100644
--- a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
+++ b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
@@ -177,7 +177,7 @@ class SegmentReductionOpTest(SegmentReductionHelper):
def testSegmentIdsInvalid1(self):
shape = [4, 4]
- with self.test_session():
+ with self.cached_session():
tf_x, _ = self._input(shape)
indices = [-1, -1, 0, 0]
s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
@@ -188,7 +188,7 @@ class SegmentReductionOpTest(SegmentReductionHelper):
def testSegmentIdsInvalid2(self):
shape = [4, 4]
- with self.test_session():
+ with self.cached_session():
tf_x, _ = self._input(shape)
indices = [0, 1, 0, 1]
s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
@@ -197,7 +197,7 @@ class SegmentReductionOpTest(SegmentReductionHelper):
def testSegmentIdsInvalid3(self):
shape = [4, 4]
- with self.test_session():
+ with self.cached_session():
tf_x, _ = self._input(shape)
indices = [0, 1, 2, 0]
s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
@@ -233,7 +233,7 @@ class SegmentReductionOpTest(SegmentReductionHelper):
math_ops.segment_sum, math_ops.segment_mean, math_ops.segment_min,
math_ops.segment_max
]:
- with self.test_session():
+ with self.cached_session():
tf_x, np_x = self._input(shape, dtype=dtypes_lib.float64)
s = tf_op(data=tf_x, segment_ids=indices)
jacob_t, jacob_n = gradient_checker.compute_gradient(
@@ -736,7 +736,7 @@ class SparseSegmentReductionOpTest(SparseSegmentReductionHelper):
segment_indices = [0, 1, 2, 2]
num_indices = len(segment_indices)
for tf_op in [math_ops.sparse_segment_sum, math_ops.sparse_segment_mean]:
- with self.test_session():
+ with self.cached_session():
tf_indices, _, tf_x, np_x = self._sparse_input(
shape, num_indices, dtype=dtypes_lib.float64)
s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices)
@@ -758,7 +758,7 @@ class SparseSegmentReductionOpTest(SparseSegmentReductionHelper):
math_ops.sparse_segment_sum_with_num_segments,
math_ops.sparse_segment_mean_with_num_segments,
]:
- with self.test_session():
+ with self.cached_session():
tf_indices, _, tf_x, np_x = self._sparse_input(
shape, num_indices, dtype=dtypes_lib.float64)
s = tf_op(
diff --git a/tensorflow/python/kernel_tests/session_ops_test.py b/tensorflow/python/kernel_tests/session_ops_test.py
index 678016b13d..03e1ae852f 100644
--- a/tensorflow/python/kernel_tests/session_ops_test.py
+++ b/tensorflow/python/kernel_tests/session_ops_test.py
@@ -31,7 +31,7 @@ from tensorflow.python.platform import test
class SessionOpsTest(test.TestCase):
def testHandleBasic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Return a handle.
a = constant_op.constant(10)
b = constant_op.constant(5)
@@ -45,7 +45,7 @@ class SessionOpsTest(test.TestCase):
self.assertEqual(500, sess.run(y, feed_dict={f: h.handle}))
def testHandleEval(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Return a handle.
a = constant_op.constant(10)
b = constant_op.constant(5)
@@ -57,7 +57,7 @@ class SessionOpsTest(test.TestCase):
self.assertEqual(50, h.eval())
def testHandleAndValue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Return a handle and a value.
a = constant_op.constant(10)
b = constant_op.constant(5)
@@ -70,7 +70,7 @@ class SessionOpsTest(test.TestCase):
self.assertEqual(500, v)
def testHandleCond(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Return a handle and a value
a = constant_op.constant(10)
b = constant_op.constant(5)
@@ -90,7 +90,7 @@ class SessionOpsTest(test.TestCase):
self.assertEqual(5000, result)
def testHandleForLoop(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Initialize a handle.
a = constant_op.constant(0)
h = session_ops.get_session_handle(a)
@@ -107,7 +107,7 @@ class SessionOpsTest(test.TestCase):
self.assertEqual(100, h.eval())
def testHandleWhileLoop(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Initialize a handle.
a = constant_op.constant(0)
h = session_ops.get_session_handle(a)
@@ -127,7 +127,7 @@ class SessionOpsTest(test.TestCase):
self.assertEqual(101, h.eval())
def testHandleMover(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Return a handle.
a = constant_op.constant(10)
b = constant_op.constant(5)
@@ -148,7 +148,7 @@ class SessionOpsTest(test.TestCase):
self.assertEqual(100, sess.run(y, feed_dict={f: h.handle}))
def testHandleDelete(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Return a handle.
a = constant_op.constant(10)
b = constant_op.constant(5)
@@ -157,7 +157,7 @@ class SessionOpsTest(test.TestCase):
sess.run(h).delete()
def testHandleDeleteRaw(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Return a handle.
a = constant_op.constant(10)
b = constant_op.constant(5)
@@ -171,7 +171,7 @@ class SessionOpsTest(test.TestCase):
sess.run(x, feed_dict={f: raw_h})
def testMultiDevices(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with ops.device(test.gpu_device_name()):
a = constant_op.constant(1.0)
a_handle = sess.run(session_ops.get_session_handle(a))
@@ -189,7 +189,7 @@ class SessionOpsTest(test.TestCase):
self.assertEqual(3.0, c_handle.eval())
def testHandleGC(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# initial values live on CPU
with ops.device("/cpu:0"):
one = constant_op.constant(1, dtype=dtypes.float32)
@@ -213,7 +213,7 @@ class SessionOpsTest(test.TestCase):
add_h2: x_handle.handle})
def testHandlePlacement(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a = constant_op.constant(1.0)
a_handle_op = session_ops.get_session_handle(a)
b = constant_op.constant(2.0)
@@ -233,7 +233,7 @@ class SessionOpsTest(test.TestCase):
self.assertEqual(3.0, c_handle.eval())
def testFeedOneHandleDirectly(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a = constant_op.constant(10.0)
b = constant_op.constant(5.0)
c = math_ops.multiply(a, b)
@@ -244,7 +244,7 @@ class SessionOpsTest(test.TestCase):
self.assertAllClose(2500.0, sess.run(d, feed_dict={c: h_c}))
def testDirectHandleFeedOverlappingWithFetches(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a = constant_op.constant(10.0)
b = constant_op.constant(5.0)
c = math_ops.multiply(a, b)
@@ -270,7 +270,7 @@ class SessionOpsTest(test.TestCase):
self.assertAllClose(50.0, d_val)
def testFeedTwoHandlesDirectly(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a = constant_op.constant(10.0)
b = constant_op.constant(5.0)
c = math_ops.multiply(a, b)
@@ -284,7 +284,7 @@ class SessionOpsTest(test.TestCase):
self.assertAllClose(-48.0, sess.run(e, feed_dict={c: h_d, d: h_c}))
def testFeedHandleToVariableDirectly(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a = variables.Variable(12.0)
inc_a = state_ops.assign_add(a, 2.0)
b = math_ops.add(a, 5.0)
diff --git a/tensorflow/python/kernel_tests/sets_test.py b/tensorflow/python/kernel_tests/sets_test.py
index 52b723802f..8335e9c139 100644
--- a/tensorflow/python/kernel_tests/sets_test.py
+++ b/tensorflow/python/kernel_tests/sets_test.py
@@ -158,7 +158,7 @@ class SetOpsTest(test_util.TensorFlowTestCase):
for op in ops:
self.assertEqual(None, op.get_shape().dims)
self.assertEqual(dtypes.int32, op.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
results = sess.run(ops)
self.assertAllEqual(results[0], results[1])
return results[0]
@@ -477,7 +477,7 @@ class SetOpsTest(test_util.TensorFlowTestCase):
dynamic_values_shape_ops = []
static_indices_shape = None
static_values_shape = None
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for op in ops:
if static_indices_shape is None:
static_indices_shape = op.indices.get_shape()
@@ -533,7 +533,7 @@ class SetOpsTest(test_util.TensorFlowTestCase):
def _set_intersection_count(self, a, b):
op = sets.set_size(sets.set_intersection(a, b))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
return sess.run(op)
def test_set_difference_multirow_2d(self):
@@ -971,7 +971,7 @@ class SetOpsTest(test_util.TensorFlowTestCase):
def _set_difference_count(self, a, b, aminusb=True):
op = sets.set_size(sets.set_difference(a, b, aminusb))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
return sess.run(op)
def test_set_union_multirow_2d(self):
@@ -1220,7 +1220,7 @@ class SetOpsTest(test_util.TensorFlowTestCase):
def _set_union_count(self, a, b):
op = sets.set_size(sets.set_union(a, b))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
return sess.run(op)
def _assert_set_operation(self, expected_indices, expected_values,
diff --git a/tensorflow/python/kernel_tests/shape_ops_test.py b/tensorflow/python/kernel_tests/shape_ops_test.py
index 34e34d9d1b..0304dc3875 100644
--- a/tensorflow/python/kernel_tests/shape_ops_test.py
+++ b/tensorflow/python/kernel_tests/shape_ops_test.py
@@ -158,7 +158,7 @@ class ShapeOpsTest(test.TestCase):
# Disabled because it takes too long to run, but manually verified
# as passing at time of writing.
def _test64BitOutput(self):
- with self.test_session():
+ with self.cached_session():
inp = array_ops.zeros([2**31])
num_elements = array_ops.size_internal(
inp, optimize=False, out_type=dtypes.int64)
@@ -166,7 +166,7 @@ class ShapeOpsTest(test.TestCase):
# Too large for tf.int32 output.
with self.assertRaises(errors_impl.InvalidArgumentError):
- with self.test_session():
+ with self.cached_session():
inp = array_ops.zeros([2**31])
num_elements = array_ops.size_internal(
inp, optimize=False, out_type=dtypes.int32)
@@ -228,7 +228,7 @@ class ShapeOpsTest(test.TestCase):
self._compareExpandDimsAll(choice([2, 3, 5]), -4)
def testExpandDimsErrors(self):
- with self.test_session():
+ with self.cached_session():
self.assertRaises(ValueError, array_ops.expand_dims,
np.zeros([2, 3, 5]), -5)
self.assertRaises(ValueError, array_ops.expand_dims,
@@ -239,7 +239,7 @@ class ShapeOpsTest(test.TestCase):
[False, True, True], 4)
def testExpandDimsGradient(self):
- with self.test_session():
+ with self.cached_session():
inp = constant_op.constant(
np.random.rand(4, 2).astype("f"), dtype=dtypes.float32)
squeezed = array_ops.expand_dims(inp, 1)
@@ -249,7 +249,7 @@ class ShapeOpsTest(test.TestCase):
self.assertLess(err, 1e-3)
def testExpandDimsScalar(self):
- with self.test_session():
+ with self.cached_session():
inp = constant_op.constant(7)
self.assertAllEqual([7], array_ops.expand_dims(inp, 0).eval())
self.assertAllEqual([7], array_ops.expand_dims(inp, -1).eval())
@@ -375,7 +375,7 @@ class ShapeOpsTest(test.TestCase):
np.zeros([1, 2, 1]), [2, 3])
def testSqueezeGradient(self):
- with self.test_session():
+ with self.cached_session():
inp = np.random.rand(4, 2).astype("f")
a = array_ops.reshape(inp, [4, 1, 2])
squeezed = array_ops.squeeze(a, [])
@@ -385,7 +385,7 @@ class ShapeOpsTest(test.TestCase):
self.assertLess(err, 1e-3)
def testSqueezeGradientWithSqueezeDims(self):
- with self.test_session():
+ with self.cached_session():
inp = np.random.rand(4, 2).astype("f")
a = array_ops.reshape(inp, [4, 1, 2, 1])
squeezed = array_ops.squeeze(a, [1])
@@ -395,7 +395,7 @@ class ShapeOpsTest(test.TestCase):
self.assertLess(err, 1e-3)
def testSqueezeWithUnknownShape(self):
- with self.test_session():
+ with self.cached_session():
a = array_ops.placeholder(dtypes.float32, shape=[2, None])
squeezed = array_ops.squeeze(a, [1])
@@ -433,7 +433,7 @@ class TileTest(test.TestCase):
self.assertTrue((result == np.tile(inp, (1, 4))).all())
def testIdentityTileAndGrad(self):
- with self.test_session():
+ with self.cached_session():
inp = np.random.rand(4, 1).astype(np.float32)
a = constant_op.constant(inp)
tiled = array_ops.tile(a, [1, 1])
@@ -443,7 +443,7 @@ class TileTest(test.TestCase):
self.assertTrue((result == np.tile(inp, (1, 1))).all())
def testEmpty(self):
- with self.test_session():
+ with self.cached_session():
inp = np.random.rand(2, 3).astype(np.float32)
a = constant_op.constant(inp)
tiled = array_ops.tile(a, [5, 0])
@@ -453,7 +453,7 @@ class TileTest(test.TestCase):
def testUnknownInputShape(self):
"""Importing can call _TileShape without shape of <multiples> known."""
- with self.test_session():
+ with self.cached_session():
inp = array_ops.placeholder(dtypes.float32) # unknown shape
multiples = constant_op.constant([1, 2, 3, 4], dtype=np.int32)
tiled = array_ops.tile(inp, multiples)
@@ -503,7 +503,7 @@ class TileTest(test.TestCase):
self.assertAllEqual(result, np.tile(inp, (1, 4)))
def testInvalidDim(self):
- with self.test_session():
+ with self.cached_session():
inp = np.random.rand(4, 1).astype("f")
a = constant_op.constant(
[float(x) for x in inp.ravel(order="C")],
@@ -546,7 +546,7 @@ class TileTest(test.TestCase):
self._RunAndVerifyResult(10, use_gpu=True)
def testGradientSimpleReduction(self):
- with self.test_session():
+ with self.cached_session():
inp = np.random.rand(4, 1).astype("f")
a = constant_op.constant(
[float(x) for x in inp.flatten()], shape=[4, 1], dtype=dtypes.float32)
@@ -561,7 +561,7 @@ class TileTest(test.TestCase):
self.assertAllClose(np.sum(grad_inp, axis=1).reshape(4, 1), result, 1e-3)
def testGradientStridedReduction(self):
- with self.test_session():
+ with self.cached_session():
inp = np.random.rand(4, 2).astype("f")
a = constant_op.constant(
[float(x) for x in inp.flatten()], shape=[4, 2], dtype=dtypes.float32)
@@ -634,7 +634,7 @@ class TileTest(test.TestCase):
self._RunAndVerifyGradientResult([2, 1, 3, 3, 2], [1, 3, 3, 1, 2])
def testGradientStridedReductionGC(self):
- with self.test_session():
+ with self.cached_session():
inp = np.random.rand(4, 2).astype("f")
a = constant_op.constant(
[float(x) for x in inp.flatten()], shape=[4, 2], dtype=dtypes.float32)
@@ -647,7 +647,7 @@ class TileTest(test.TestCase):
dtype=dtypes.float32)
outputs = array_ops.gather(array_ops.tile(inputs, [3]),
[1, 5, 9, 3, 7, 2, 2, 2])
- with self.test_session():
+ with self.cached_session():
error = gradient_checker.compute_gradient_error(
inputs, inputs.get_shape().as_list(),
outputs, outputs.get_shape().as_list())
@@ -659,7 +659,7 @@ class TileTest(test.TestCase):
inputs = array_ops.reshape(inputs, [-1, 1, 1])
outputs = array_ops.gather(array_ops.tile(inputs, [3, 4, 2]),
[1, 5, 9, 3, 7, 2, 2, 2])
- with self.test_session():
+ with self.cached_session():
error = gradient_checker.compute_gradient_error(
inputs, inputs.get_shape().as_list(),
outputs, outputs.get_shape().as_list())
diff --git a/tensorflow/python/kernel_tests/slice_op_test.py b/tensorflow/python/kernel_tests/slice_op_test.py
index 4a1fc1d9a9..c08d3222b3 100644
--- a/tensorflow/python/kernel_tests/slice_op_test.py
+++ b/tensorflow/python/kernel_tests/slice_op_test.py
@@ -26,6 +26,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@@ -106,7 +107,7 @@ class SliceTest(test.TestCase):
def testScalarInput(self):
input_val = 0
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Test with constant input; shape inference fails.
with self.assertRaisesWithPredicateMatch(ValueError, "out of range"):
constant_op.constant(input_val)[:].get_shape()
@@ -120,7 +121,7 @@ class SliceTest(test.TestCase):
def testInvalidIndex(self):
input_val = [1, 2]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Test with constant input; shape inference fails.
with self.assertRaisesWithPredicateMatch(ValueError, "out of range"):
constant_op.constant(input_val)[1:, 1:].get_shape()
@@ -260,6 +261,21 @@ class SliceTest(test.TestCase):
grad_actual = gradients_impl.gradients(out, inp)[0].eval()
self.assertAllClose([0., 1., 1.], grad_actual)
+ def _testGradientVariableSize2D(self):
+ # Regression test for bug in slice. A low-level bug in Eigen was causing
+ # incorrect results for negative indices in multi-dimensional tensors.
+ # See b/114318298.
+ with self.test_session(use_gpu=True) as sess:
+ x = constant_op.constant([[1., 2., 3.], [4., 5., 6.], [7., 8., 7]])
+ loss1 = math_ops.reduce_sum(x[:-1, :-1] * 1.0)
+ loss2 = math_ops.reduce_sum(x[:-1][:, :-1])
+
+ g1 = gradients_impl.gradients(loss1, x)[0]
+ g2 = gradients_impl.gradients(loss2, x)[0]
+
+ g1_val, g2_val = sess.run([g1, g2])
+ self.assertAllEqual(g1_val, g2_val)
+
def testGradientsAll(self):
# Slice the middle square out of a 4x4 input
self._testGradientSlice([4, 4], [1, 1], [2, 2])
@@ -276,6 +292,9 @@ class SliceTest(test.TestCase):
# Use -1 as a slice dimension.
self._testGradientVariableSize()
+ # Use -1 as a slice dimension on a 2D tensor.
+ self._testGradientVariableSize2D()
+
def testNotIterable(self):
# NOTE(mrry): If we register __getitem__ as an overloaded
# operator, Python will valiantly attempt to iterate over the
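For reference, the new _testGradientVariableSize2D regression test above asserts that two equivalent ways of dropping the last row and column of a 3x3 tensor, x[:-1, :-1] and x[:-1][:, :-1], produce identical gradients (the Eigen bug referenced in b/114318298 broke negative indices on multi-dimensional slices). A quick NumPy sketch of the gradient both expressions should yield for a sum-of-elements loss:

    import numpy as np

    # d/dx reduce_sum(x[:-1, :-1]) for a 3x3 input: ones where the slice reads,
    # zeros over the dropped last row and column.
    expected_grad = np.zeros((3, 3))
    expected_grad[:-1, :-1] = 1.0
    # expected_grad ==
    # [[1., 1., 0.],
    #  [1., 1., 0.],
    #  [0., 0., 0.]]
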
diff --git a/tensorflow/python/kernel_tests/softmax_op_test.py b/tensorflow/python/kernel_tests/softmax_op_test.py
index fbf1adba9b..e53347c4bc 100644
--- a/tensorflow/python/kernel_tests/softmax_op_test.py
+++ b/tensorflow/python/kernel_tests/softmax_op_test.py
@@ -210,7 +210,7 @@ class SoftmaxTest(test.TestCase):
self.assertEqual([3, 2, 4], op.get_shape())
def testEmptyInput(self):
- with self.test_session():
+ with self.cached_session():
x = array_ops.placeholder(dtypes.float32, shape=[0, 3])
self.assertEqual(0, array_ops.size(x).eval())
# reshape would raise if logits is empty
@@ -218,7 +218,7 @@ class SoftmaxTest(test.TestCase):
nn_ops.softmax(x, axis=0).eval()
def testDimTooLarge(self):
- with self.test_session():
+ with self.cached_session():
# Use placeholder to make sure we get runtime error instead of shape
# inference error.
dim = array_ops.placeholder_with_default(100, shape=[])
diff --git a/tensorflow/python/kernel_tests/softplus_op_test.py b/tensorflow/python/kernel_tests/softplus_op_test.py
index c0269db9ae..afe3df6178 100644
--- a/tensorflow/python/kernel_tests/softplus_op_test.py
+++ b/tensorflow/python/kernel_tests/softplus_op_test.py
@@ -72,7 +72,7 @@ class SoftplusTest(test.TestCase):
use_gpu=True)
def testGradient(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
shape=[2, 5],
@@ -88,7 +88,7 @@ class SoftplusTest(test.TestCase):
self.assertLess(err, 1e-4)
def testGradGrad(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
shape=[2, 5],
@@ -105,7 +105,7 @@ class SoftplusTest(test.TestCase):
self.assertLess(err, 5e-5)
def testGradGradGrad(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
shape=[2, 5],
@@ -123,7 +123,7 @@ class SoftplusTest(test.TestCase):
self.assertLess(err, 5e-5)
def testNoInts(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
"No OpKernel was registered to support Op 'Softplus'"):
diff --git a/tensorflow/python/kernel_tests/softsign_op_test.py b/tensorflow/python/kernel_tests/softsign_op_test.py
index a5247ce08d..05a7c53dee 100644
--- a/tensorflow/python/kernel_tests/softsign_op_test.py
+++ b/tensorflow/python/kernel_tests/softsign_op_test.py
@@ -51,7 +51,7 @@ class SoftsignTest(test.TestCase):
use_gpu=True)
def testGradient(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
shape=[2, 5],
@@ -67,7 +67,7 @@ class SoftsignTest(test.TestCase):
self.assertLess(err, 1e-4)
def testNoInts(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
"No OpKernel was registered to support Op 'Softsign'"):
diff --git a/tensorflow/python/kernel_tests/spacetobatch_op_test.py b/tensorflow/python/kernel_tests/spacetobatch_op_test.py
index 2a9232b6ae..e267c05915 100644
--- a/tensorflow/python/kernel_tests/spacetobatch_op_test.py
+++ b/tensorflow/python/kernel_tests/spacetobatch_op_test.py
@@ -551,7 +551,7 @@ class SpaceToBatchNDGradientTest(test.TestCase):
def _checkGrad(self, x, block_shape, paddings):
block_shape = np.array(block_shape)
paddings = np.array(paddings).reshape((len(block_shape), 2))
- with self.test_session():
+ with self.cached_session():
tf_x = ops.convert_to_tensor(x)
tf_y = array_ops.space_to_batch_nd(tf_x, block_shape, paddings)
epsilon = 1e-5
@@ -638,7 +638,7 @@ class RequiredSpaceToBatchPaddingsTest(test.TestCase):
t_paddings, t_crops = array_ops.required_space_to_batch_paddings(
input_shape_placeholder, block_shape_placeholder,
base_paddings_placeholder)
- with self.test_session():
+ with self.cached_session():
paddings_result = t_paddings.eval(assignments)
crops_result = t_crops.eval(assignments)
self.assertAllEqual(paddings_result, paddings_const)
diff --git a/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py b/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py
index 3bb5e899fe..477720302d 100644
--- a/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py
+++ b/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py
@@ -99,20 +99,20 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
""", q.accumulator_ref.op.node_def)
def testAccumulatorSizeEmpty(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q")
self.assertEqual(q.num_accumulated().eval(), 0)
def testAccumulatorSetGlobalStep(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
set_global_step_op = q.set_global_step(1)
set_global_step_op.run()
def testAccumulatorApplyGradFloat32(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([3, 3]))
accum_op = q.apply_indexed_slices_grad(
@@ -123,7 +123,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
self.assertEqual(q.num_accumulated().eval(), 1)
def testDtypes(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dtypes = [dtypes_lib.float16, dtypes_lib.float32, dtypes_lib.float64]
for i in range(len(dtypes)):
@@ -145,7 +145,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
self._assertEqual_nparray(sum_elems / len(elems), result, sess)
def testAccumulatorMultipleAccumulators(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q_f32_0 = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2]))
q_f32_1 = data_flow_ops.SparseConditionalAccumulator(
@@ -175,7 +175,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
self._assertEqual_indexedslices(expected_tensors[i], result)
def testAccumulatorTakeGradMean(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=())
@@ -220,7 +220,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
dtypes_lib.float32, name="Q", shape=(), reduction_type="Invalid")
def testAccumulatorRepeatedTakeGrad(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=())
@@ -258,7 +258,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
self.assertAllEqual(val.dense_shape, [-1, 2])
def testParallelApplyGradMean(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2]))
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
@@ -323,7 +323,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
val, sess)
def testParallelTakeGrad(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2]))
elems = [e + 1 for e in range(10)]
@@ -362,7 +362,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
np.array([[0, 0], [elems[i], 0]]), results[i], sess)
def testAccumulatorApplyAndBlockingTake(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2]))
@@ -397,7 +397,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
sess.run(takeg_op)
def testAccumulatorCancel(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32,
name="Q",
@@ -416,7 +416,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
takeg_thread.join()
def testNonVectorIndices(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([3, 3]))
@@ -428,7 +428,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
grad_values=np.array([1, 2]).astype(np.float32)).run()
def testZeroDimensionValues(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([3, 3]))
@@ -438,7 +438,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
grad_indices=[0], grad_values=np.array(1).astype(np.float32)).run()
def testWrongNonEmptyInputValues(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([3, 3]))
@@ -449,7 +449,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
grad_values=np.array([[0, 1, 1]]).astype(np.float32)).run()
def testDynamicNonVectorIndices(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([3, 3]))
@@ -468,7 +468,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
})
def testDynamicWrongNonEmptyInputValues(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([3, 3]))
@@ -486,7 +486,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
})
def testEmptyShapeApply(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([]))
@@ -511,7 +511,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
q.apply_grad(grad_indices=[0], grad_values=[1.0]).run()
def testValidateShape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=[2, 2, None])
@@ -606,7 +606,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
local_step=1).run()
def testReturnShape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=[2, None])
@@ -631,7 +631,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
self.assertAllEqual(val.dense_shape, [-1, 2, 2, 3])
def testApplyGradtInt32IndicesAndShape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([3, 3]))
accum_op = q.apply_grad(
diff --git a/tensorflow/python/kernel_tests/sparse_cross_op_test.py b/tensorflow/python/kernel_tests/sparse_cross_op_test.py
index ca7898d466..6e0714da70 100644
--- a/tensorflow/python/kernel_tests/sparse_cross_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_cross_op_test.py
@@ -42,7 +42,7 @@ class SparseCrossOpTest(test.TestCase):
'batch2-FC1-F1_X_batch2-FC2-F1', 'batch2-FC1-F1_X_batch2-FC2-F2',
'batch2-FC1-F2_X_batch2-FC2-F1', 'batch2-FC1-F2_X_batch2-FC2-F2'
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_dense(self):
@@ -62,7 +62,7 @@ class SparseCrossOpTest(test.TestCase):
'batch2-FC1-F1_X_batch2-FC2-F1', 'batch2-FC1-F1_X_batch2-FC2-F2',
'batch2-FC1-F2_X_batch2-FC2-F1', 'batch2-FC1-F2_X_batch2-FC2-F2'
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_integer_mixed_string_sparse(self):
@@ -76,7 +76,7 @@ class SparseCrossOpTest(test.TestCase):
'333_X_batch2-FC2-F1', '333_X_batch2-FC2-F2', '55555_X_batch2-FC2-F1',
'55555_X_batch2-FC2-F2'
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_integer_mixed_string_dense(self):
@@ -94,7 +94,7 @@ class SparseCrossOpTest(test.TestCase):
'55555_X_batch2-FC2-F1', '55555_X_batch2-FC2-F2',
'999999_X_batch2-FC2-F1', '999999_X_batch2-FC2-F2'
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_sparse_cross_dense(self):
@@ -111,7 +111,7 @@ class SparseCrossOpTest(test.TestCase):
'batch2-FC1-F1_X_batch2-FC2-F1', 'batch2-FC1-F1_X_batch2-FC2-F2',
'batch2-FC1-F2_X_batch2-FC2-F1', 'batch2-FC1-F2_X_batch2-FC2-F2'
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_integer_sparse_input(self):
@@ -127,7 +127,7 @@ class SparseCrossOpTest(test.TestCase):
'333_X_batch2-FC2-F1', '333_X_batch2-FC2-F2',
'5555_X_batch2-FC2-F1', '5555_X_batch2-FC2-F2'
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_permutation_3x3x3(self):
@@ -169,7 +169,7 @@ class SparseCrossOpTest(test.TestCase):
'batch1-FC1-F3_X_batch1-FC2-F3_X_batch1-FC3-F2',
'batch1-FC1-F3_X_batch1-FC2-F3_X_batch1-FC3-F3'
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_permutation_3x1x2(self):
@@ -188,7 +188,7 @@ class SparseCrossOpTest(test.TestCase):
'batch1-FC1-F3_X_batch1-FC2-F1_X_batch1-FC3-F1',
'batch1-FC1-F3_X_batch1-FC2-F1_X_batch1-FC3-F2'
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_large_batch(self):
@@ -221,7 +221,7 @@ class SparseCrossOpTest(test.TestCase):
])
expected_out = self._sparse_tensor(col_out)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_one_column_empty(self):
@@ -234,7 +234,7 @@ class SparseCrossOpTest(test.TestCase):
self._sparse_tensor([], 1),
self._sparse_tensor([['batch1-FC3-F1', 'batch1-FC3-F2']])
])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_empty(sess.run(op))
def test_some_columns_empty(self):
@@ -253,7 +253,7 @@ class SparseCrossOpTest(test.TestCase):
'batch1-FC1-F2_X_batch1-FC2-F1_X_batch1-FC3-F1',
'batch1-FC1-F2_X_batch1-FC2-F1_X_batch1-FC3-F2'
]], 2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_all_columns_empty(self):
@@ -266,7 +266,7 @@ class SparseCrossOpTest(test.TestCase):
self._sparse_tensor([]),
self._sparse_tensor([])
])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_empty(sess.run(op))
def test_hashed_zero_bucket_no_hash_key(self):
@@ -277,7 +277,7 @@ class SparseCrossOpTest(test.TestCase):
])
# Check actual hashed output to prevent unintentional hashing changes.
expected_out = self._sparse_tensor([[1971693436396284976]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_hashed_zero_bucket(self):
@@ -290,7 +290,7 @@ class SparseCrossOpTest(test.TestCase):
hash_key=sparse_ops._DEFAULT_HASH_KEY + 1)
# Check actual hashed output to prevent unintentional hashing changes.
expected_out = self._sparse_tensor([[4847552627144134031]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
# TODO(sibyl-Aix6ihai): Add benchmark to compare Hashed vs Non-hashed.
@@ -304,7 +304,7 @@ class SparseCrossOpTest(test.TestCase):
num_buckets=100)
# Check actual hashed output to prevent unintentional hashing changes.
expected_out = self._sparse_tensor([[83]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_hashed_output(self):
@@ -318,7 +318,7 @@ class SparseCrossOpTest(test.TestCase):
hash_key=sparse_ops._DEFAULT_HASH_KEY + 1)
# Check actual hashed output to prevent unintentional hashing changes.
expected_out = self._sparse_tensor([[31]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_hashed__has_no_collision(self):
@@ -344,7 +344,7 @@ class SparseCrossOpTest(test.TestCase):
self._sparse_tensor([['batch1-FC3-F1', 'batch1-FC3-F2']])
],
num_buckets=1000)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
out = sess.run(op)
self.assertEqual(6, len(out.values))
self.assertAllEqual([[0, i] for i in range(6)], out.indices)
diff --git a/tensorflow/python/kernel_tests/sparse_matmul_op_test.py b/tensorflow/python/kernel_tests/sparse_matmul_op_test.py
index f50e39d6d5..90009fc33e 100644
--- a/tensorflow/python/kernel_tests/sparse_matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_matmul_op_test.py
@@ -130,7 +130,7 @@ class MatMulGradientTest(test.TestCase):
def _testGradients(self, tr_a, tr_b, sp_a, sp_b, a_dtype, b_dtype, delta,
name):
- with self.test_session():
+ with self.cached_session():
a = constant_op.constant(
RandMatrix(
3, 2, tr_a, round_bfloat=True), dtype=dtypes.float32)
diff --git a/tensorflow/python/kernel_tests/sparse_ops_test.py b/tensorflow/python/kernel_tests/sparse_ops_test.py
index fc39de150e..79efee3f5b 100644
--- a/tensorflow/python/kernel_tests/sparse_ops_test.py
+++ b/tensorflow/python/kernel_tests/sparse_ops_test.py
@@ -628,7 +628,7 @@ class SparseReduceTest(test_util.TensorFlowTestCase):
else:
np_ans = np.max(np_ans, axis=ra, keepdims=keep_dims)
- with self.test_session():
+ with self.cached_session():
if do_sum:
tf_dense_ans = sparse_ops.sparse_reduce_sum(sp_t, reduction_axes,
keep_dims)
diff --git a/tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py b/tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py
index 87a4eb9c7b..c71746cc99 100644
--- a/tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py
+++ b/tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py
@@ -81,7 +81,7 @@ class SparseToDenseTest(test.TestCase):
self.assertAllClose(np_ans, tf_ans)
def testZeroDefault(self):
- with self.test_session():
+ with self.cached_session():
x = sparse_ops.sparse_to_dense(2, [4], 7).eval()
self.assertAllEqual(x, [0, 0, 7, 0])
@@ -94,12 +94,12 @@ class SparseToDenseTest(test.TestCase):
self.assertAllClose(np_ans, tf_ans)
def testBadShape(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesWithPredicateMatch(ValueError, "must be rank 1"):
_SparseToDense([1, 3], [[5], [3]], 1, -1)
def testBadValue(self):
- with self.test_session():
+ with self.cached_session():
dense = _SparseToDense([1, 3], [5], [[5], [3]], -1)
with self.assertRaisesOpError(
r"sparse_values has incorrect shape \[2,1\], "
@@ -107,20 +107,20 @@ class SparseToDenseTest(test.TestCase):
dense.eval()
def testBadNumValues(self):
- with self.test_session():
+ with self.cached_session():
dense = _SparseToDense([1, 3], [5], [1, 2, 3], -1)
with self.assertRaisesOpError(
r"sparse_values has incorrect shape \[3\], should be \[\] or \[2\]"):
dense.eval()
def testBadDefault(self):
- with self.test_session():
+ with self.cached_session():
dense = _SparseToDense([1, 3], [5], [1, 2], [0])
with self.assertRaisesOpError("default_value should be a scalar"):
dense.eval()
def testOutOfBoundsIndicesWithWithoutValidation(self):
- with self.test_session():
+ with self.cached_session():
dense = _SparseToDense(
sparse_indices=[[1], [10]],
output_size=[5],
@@ -140,7 +140,7 @@ class SparseToDenseTest(test.TestCase):
dense_without_validation.eval()
def testRepeatingIndicesWithWithoutValidation(self):
- with self.test_session():
+ with self.cached_session():
dense = _SparseToDense(
sparse_indices=[[1], [1]],
output_size=[5],
@@ -158,7 +158,7 @@ class SparseToDenseTest(test.TestCase):
dense_without_validation.eval()
def testUnsortedIndicesWithWithoutValidation(self):
- with self.test_session():
+ with self.cached_session():
dense = _SparseToDense(
sparse_indices=[[2], [1]],
output_size=[5],
diff --git a/tensorflow/python/kernel_tests/sparsemask_op_test.py b/tensorflow/python/kernel_tests/sparsemask_op_test.py
index cf6c9494ae..6f5dd45b61 100644
--- a/tensorflow/python/kernel_tests/sparsemask_op_test.py
+++ b/tensorflow/python/kernel_tests/sparsemask_op_test.py
@@ -34,7 +34,7 @@ class SparseMaskTest(test.TestCase):
out_values = values[1:, :]
out_indices = np.array([2, 3, 4], dtype=np.int32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_tensor = ops.convert_to_tensor(values)
indices_tensor = ops.convert_to_tensor(indices)
mask_indices_tensor = ops.convert_to_tensor(mask_indices)
diff --git a/tensorflow/python/kernel_tests/string_join_op_test.py b/tensorflow/python/kernel_tests/string_join_op_test.py
index ce19333654..e4371ab5b9 100644
--- a/tensorflow/python/kernel_tests/string_join_op_test.py
+++ b/tensorflow/python/kernel_tests/string_join_op_test.py
@@ -28,7 +28,7 @@ class StringJoinOpTest(test.TestCase):
input1 = "a"
input2 = [["b"], ["c"]]
- with self.test_session():
+ with self.cached_session():
output = string_ops.string_join([input0, input1])
self.assertAllEqual(output.eval(), [b"aa", b"ba"])
diff --git a/tensorflow/python/kernel_tests/string_length_op_test.py b/tensorflow/python/kernel_tests/string_length_op_test.py
index 075a3204ad..9f013c2c7e 100644
--- a/tensorflow/python/kernel_tests/string_length_op_test.py
+++ b/tensorflow/python/kernel_tests/string_length_op_test.py
@@ -27,7 +27,7 @@ class StringLengthOpTest(test.TestCase):
def testStringLength(self):
strings = [[["1", "12"], ["123", "1234"], ["12345", "123456"]]]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
lengths = string_ops.string_length(strings)
values = sess.run(lengths)
self.assertAllEqual(values, [[[1, 2], [3, 4], [5, 6]]])
diff --git a/tensorflow/python/kernel_tests/string_split_op_test.py b/tensorflow/python/kernel_tests/string_split_op_test.py
index b6a0f45adc..b968e885ed 100644
--- a/tensorflow/python/kernel_tests/string_split_op_test.py
+++ b/tensorflow/python/kernel_tests/string_split_op_test.py
@@ -32,7 +32,7 @@ class StringSplitOpTest(test.TestCase):
def testStringSplit(self):
strings = ["pigs on the wing", "animals"]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tokens = string_ops.string_split(strings)
indices, values, shape = sess.run(tokens)
self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2], [0, 3], [1, 0]])
@@ -42,7 +42,7 @@ class StringSplitOpTest(test.TestCase):
def testStringSplitEmptyDelimiter(self):
strings = ["hello", "hola", b"\xF0\x9F\x98\x8E"] # Last string is U+1F60E
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tokens = string_ops.string_split(strings, delimiter="")
indices, values, shape = sess.run(tokens)
self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2], [0, 3], [0, 4],
@@ -60,7 +60,7 @@ class StringSplitOpTest(test.TestCase):
def testStringSplitEmptyToken(self):
strings = ["", " a", "b ", " c", " ", " d ", " e", "f ", " g ", " "]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tokens = string_ops.string_split(strings)
indices, values, shape = sess.run(tokens)
self.assertAllEqual(
@@ -72,7 +72,7 @@ class StringSplitOpTest(test.TestCase):
def testStringSplitOnSetEmptyToken(self):
strings = ["", " a", "b ", " c", " ", " d ", ". e", "f .", " .g. ", " ."]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tokens = string_ops.string_split(strings, delimiter=" .")
indices, values, shape = sess.run(tokens)
self.assertAllEqual(
@@ -84,7 +84,7 @@ class StringSplitOpTest(test.TestCase):
def testStringSplitWithDelimiter(self):
strings = ["hello|world", "hello world"]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertRaises(
ValueError, string_ops.string_split, strings, delimiter=["|", ""])
@@ -106,7 +106,7 @@ class StringSplitOpTest(test.TestCase):
def testStringSplitWithDelimiterTensor(self):
strings = ["hello|world", "hello world"]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
delimiter = array_ops.placeholder(dtypes.string)
tokens = string_ops.string_split(strings, delimiter=delimiter)
@@ -124,7 +124,7 @@ class StringSplitOpTest(test.TestCase):
def testStringSplitWithDelimitersTensor(self):
strings = ["hello.cruel,world", "hello cruel world"]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
delimiter = array_ops.placeholder(dtypes.string)
tokens = string_ops.string_split(strings, delimiter=delimiter)
@@ -143,7 +143,7 @@ class StringSplitOpTest(test.TestCase):
def testStringSplitWithNoSkipEmpty(self):
strings = ["#a", "b#", "#c#"]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tokens = string_ops.string_split(strings, "#", skip_empty=False)
indices, values, shape = sess.run(tokens)
self.assertAllEqual(indices, [[0, 0], [0, 1],
@@ -152,7 +152,7 @@ class StringSplitOpTest(test.TestCase):
self.assertAllEqual(values, [b"", b"a", b"b", b"", b"", b"c", b""])
self.assertAllEqual(shape, [3, 3])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tokens = string_ops.string_split(strings, "#")
indices, values, shape = sess.run(tokens)
self.assertAllEqual(values, [b"a", b"b", b"c"])
@@ -165,7 +165,7 @@ class StringSplitV2OpTest(test.TestCase):
def testSplitV2(self):
strings = ["pigs on the wing", "animals"]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tokens = string_ops.string_split_v2(strings)
indices, values, shape = sess.run(tokens)
self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2], [0, 3], [1, 0]])
@@ -180,7 +180,7 @@ class StringSplitV2OpTest(test.TestCase):
# ['', '', '4', '5', '', '6', '']
strings = ["1<>2<>3", "<><>4<>5<><>6<>"]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tokens = string_ops.string_split_v2(strings, sep="<>")
indices, values, shape = sess.run(tokens)
self.assertAllEqual(
@@ -198,7 +198,7 @@ class StringSplitV2OpTest(test.TestCase):
# ['1', '2', '', '3', '']
strings = ["1,2,3", "4,5,,6,"]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tokens = string_ops.string_split_v2(strings, sep=',')
indices, values, shape = sess.run(tokens)
self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2],
@@ -215,7 +215,7 @@ class StringSplitV2OpTest(test.TestCase):
#['1', '2', '3']
strings = ["1 2 3", " 4 5 6 "]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tokens = string_ops.string_split_v2(strings)
indices, values, shape = sess.run(tokens)
self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2],
@@ -231,7 +231,7 @@ class StringSplitV2OpTest(test.TestCase):
# ['4', '5,,6,']
strings = ["1,2,3", "4,5,,6,"]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tokens = string_ops.string_split_v2(strings, sep=',', maxsplit=1)
indices, values, shape = sess.run(tokens)
self.assertAllEqual(indices, [[0, 0], [0, 1],
@@ -247,7 +247,7 @@ class StringSplitV2OpTest(test.TestCase):
# ['4', '5 6 ']
strings = ["1 2 3", " 4 5 6 "]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tokens = string_ops.string_split_v2(strings, maxsplit=1)
indices, values, shape = sess.run(tokens)
self.assertAllEqual(indices, [[0, 0], [0, 1],
diff --git a/tensorflow/python/kernel_tests/string_strip_op_test.py b/tensorflow/python/kernel_tests/string_strip_op_test.py
index 30fd477ff4..a96b71490e 100644
--- a/tensorflow/python/kernel_tests/string_strip_op_test.py
+++ b/tensorflow/python/kernel_tests/string_strip_op_test.py
@@ -28,7 +28,7 @@ class StringStripOpTest(test.TestCase):
def test_string_strip(self):
strings = ["pigs on the wing", "animals"]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
output = string_ops.string_strip(strings)
output = sess.run(output)
self.assertAllEqual(output, [b"pigs on the wing", b"animals"])
@@ -37,7 +37,7 @@ class StringStripOpTest(test.TestCase):
strings = [["pigs on the wing", "animals"],
[" hello ", "\n\tworld \r \n"]]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
output = string_ops.string_strip(strings)
output = sess.run(output)
self.assertAllEqual(output, [[b"pigs on the wing", b"animals"],
@@ -46,7 +46,7 @@ class StringStripOpTest(test.TestCase):
def test_string_strip_with_empty_strings(self):
strings = [" hello ", "", "world ", " \t \r \n "]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
output = string_ops.string_strip(strings)
output = sess.run(output)
self.assertAllEqual(output, [b"hello", b"", b"world", b""])
diff --git a/tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py b/tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py
index 2c6064e64b..9cb0c9d18f 100644
--- a/tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py
+++ b/tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py
@@ -27,7 +27,7 @@ from tensorflow.python.platform import test
class StringToHashBucketOpTest(test.TestCase):
def testStringToOneHashBucketFast(self):
- with self.test_session():
+ with self.cached_session():
input_string = array_ops.placeholder(dtypes.string)
output = string_ops.string_to_hash_bucket_fast(input_string, 1)
result = output.eval(feed_dict={input_string: ['a', 'b', 'c']})
@@ -35,7 +35,7 @@ class StringToHashBucketOpTest(test.TestCase):
self.assertAllEqual([0, 0, 0], result)
def testStringToHashBucketsFast(self):
- with self.test_session():
+ with self.cached_session():
input_string = array_ops.placeholder(dtypes.string)
output = string_ops.string_to_hash_bucket_fast(input_string, 10)
result = output.eval(feed_dict={input_string: ['a', 'b', 'c', 'd']})
@@ -47,7 +47,7 @@ class StringToHashBucketOpTest(test.TestCase):
self.assertAllEqual([9, 2, 2, 5], result)
def testStringToOneHashBucketLegacyHash(self):
- with self.test_session():
+ with self.cached_session():
input_string = array_ops.placeholder(dtypes.string)
output = string_ops.string_to_hash_bucket(input_string, 1)
result = output.eval(feed_dict={input_string: ['a', 'b', 'c']})
@@ -55,7 +55,7 @@ class StringToHashBucketOpTest(test.TestCase):
self.assertAllEqual([0, 0, 0], result)
def testStringToHashBucketsLegacyHash(self):
- with self.test_session():
+ with self.cached_session():
input_string = array_ops.placeholder(dtypes.string)
output = string_ops.string_to_hash_bucket(input_string, 10)
result = output.eval(feed_dict={input_string: ['a', 'b', 'c']})
@@ -66,14 +66,14 @@ class StringToHashBucketOpTest(test.TestCase):
self.assertAllEqual([8, 0, 7], result)
def testStringToOneHashBucketStrongOneHashBucket(self):
- with self.test_session():
+ with self.cached_session():
input_string = constant_op.constant(['a', 'b', 'c'])
output = string_ops.string_to_hash_bucket_strong(
input_string, 1, key=[123, 345])
self.assertAllEqual([0, 0, 0], output.eval())
def testStringToHashBucketsStrong(self):
- with self.test_session():
+ with self.cached_session():
input_string = constant_op.constant(['a', 'b', 'c'])
output = string_ops.string_to_hash_bucket_strong(
input_string, 10, key=[98765, 132])
@@ -84,7 +84,7 @@ class StringToHashBucketOpTest(test.TestCase):
self.assertAllEqual([4, 2, 8], output.eval())
def testStringToHashBucketsStrongInvalidKey(self):
- with self.test_session():
+ with self.cached_session():
input_string = constant_op.constant(['a', 'b', 'c'])
with self.assertRaisesOpError('Key must have 2 elements'):
string_ops.string_to_hash_bucket_strong(
diff --git a/tensorflow/python/kernel_tests/string_to_number_op_test.py b/tensorflow/python/kernel_tests/string_to_number_op_test.py
index cc4c21b66c..99ee25e125 100644
--- a/tensorflow/python/kernel_tests/string_to_number_op_test.py
+++ b/tensorflow/python/kernel_tests/string_to_number_op_test.py
@@ -29,7 +29,7 @@ _ERROR_MESSAGE = "StringToNumberOp could not correctly convert string: "
class StringToNumberOpTest(test.TestCase):
def _test(self, tf_type, good_pairs, bad_pairs):
- with self.test_session():
+ with self.cached_session():
# Build a small testing graph.
input_string = array_ops.placeholder(dtypes.string)
output = parsing_ops.string_to_number(
diff --git a/tensorflow/python/kernel_tests/substr_op_test.py b/tensorflow/python/kernel_tests/substr_op_test.py
index 73ac71e1f5..4d163a0f6f 100644
--- a/tensorflow/python/kernel_tests/substr_op_test.py
+++ b/tensorflow/python/kernel_tests/substr_op_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
from tensorflow.python.framework import errors_impl
@@ -25,7 +26,7 @@ from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
-class SubstrOpTest(test.TestCase):
+class SubstrOpTest(test.TestCase, parameterized.TestCase):
def _testScalarString(self, dtype):
test_string = b"Hello"
@@ -34,11 +35,22 @@ class SubstrOpTest(test.TestCase):
expected_value = b"ell"
substr_op = string_ops.substr(test_string, position, length)
+ with self.cached_session():
+ substr = substr_op.eval()
+ self.assertAllEqual(substr, expected_value)
+
+ # Negative position.
+ test_string = b"Hello"
+ position = np.array(-4, dtype)
+ length = np.array(3, dtype)
+ expected_value = b"ell"
+
+ substr_op = string_ops.substr(test_string, position, length)
with self.test_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
- # position is equal to the length of string.
+ # Position is equal to the length of string.
test_string = b""
position = np.array(0, dtype)
length = np.array(2, dtype)
@@ -49,6 +61,17 @@ class SubstrOpTest(test.TestCase):
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
+ # Negative position magnitude is equal to the length of string.
+ test_string = b"yo"
+ position = np.array(-2, dtype)
+ length = np.array(1, dtype)
+ expected_value = b"y"
+
+ substr_op = string_ops.substr(test_string, position, length)
+ with self.cached_session():
+ substr = substr_op.eval()
+ self.assertAllEqual(substr, expected_value)
+
def _testVectorStrings(self, dtype):
test_string = [b"Hello", b"World"]
position = np.array(1, dtype)
@@ -60,6 +83,17 @@ class SubstrOpTest(test.TestCase):
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
+ # Negative position.
+ test_string = [b"Hello", b"World"]
+ position = np.array(-4, dtype)
+ length = np.array(3, dtype)
+ expected_value = [b"ell", b"orl"]
+
+ substr_op = string_ops.substr(test_string, position, length)
+ with self.cached_session():
+ substr = substr_op.eval()
+ self.assertAllEqual(substr, expected_value)
+
def _testMatrixStrings(self, dtype):
test_string = [[b"ten", b"eleven", b"twelve"],
[b"thirteen", b"fourteen", b"fifteen"],
@@ -74,17 +108,31 @@ class SubstrOpTest(test.TestCase):
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
+ # Negative position
+ test_string = [[b"ten", b"eleven", b"twelve"],
+ [b"thirteen", b"fourteen", b"fifteen"],
+ [b"sixteen", b"seventeen", b"eighteen"]]
+ position = np.array(-2, dtype)
+ length = np.array(2, dtype)
+ expected_value = [[b"en", b"en", b"ve"], [b"en", b"en", b"en"],
+ [b"en", b"en", b"en"]]
+
+ substr_op = string_ops.substr(test_string, position, length)
+ with self.cached_session():
+ substr = substr_op.eval()
+ self.assertAllEqual(substr, expected_value)
+
def _testElementWisePosLen(self, dtype):
test_string = [[b"ten", b"eleven", b"twelve"],
[b"thirteen", b"fourteen", b"fifteen"],
[b"sixteen", b"seventeen", b"eighteen"]]
- position = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]], dtype)
- length = np.array([[2, 3, 4], [4, 3, 2], [5, 5, 5]], dtype)
- expected_value = [[b"en", b"eve", b"lve"], [b"hirt", b"urt", b"te"],
- [b"ixtee", b"vente", b"hteen"]]
+ position = np.array([[1, -4, 3], [1, 2, -4], [-5, 2, 3]], dtype)
+ length = np.array([[2, 2, 4], [4, 3, 2], [5, 5, 5]], dtype)
+ expected_value = [[b"en", b"ev", b"lve"], [b"hirt", b"urt", b"te"],
+ [b"xteen", b"vente", b"hteen"]]
substr_op = string_ops.substr(test_string, position, length)
- with self.test_session():
+ with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
@@ -94,33 +142,33 @@ class SubstrOpTest(test.TestCase):
[b"thirteen", b"fourteen", b"fifteen"],
[b"sixteen", b"seventeen", b"eighteen"],
[b"nineteen", b"twenty", b"twentyone"]]
- position = np.array([1, 2, 3], dtype)
+ position = np.array([1, -4, 3], dtype)
length = np.array([1, 2, 3], dtype)
- expected_value = [[b"e", b"ev", b"lve"], [b"h", b"ur", b"tee"],
- [b"i", b"ve", b"hte"], [b"i", b"en", b"nty"]]
+ expected_value = [[b"e", b"ev", b"lve"], [b"h", b"te", b"tee"],
+ [b"i", b"te", b"hte"], [b"i", b"en", b"nty"]]
substr_op = string_ops.substr(test_string, position, length)
- with self.test_session():
+ with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
# Broadcast input string onto pos/len
test_string = [b"thirteen", b"fourteen", b"fifteen"]
- position = np.array([[1, 2, 3], [3, 2, 1], [5, 5, 5]], dtype)
+ position = np.array([[1, -2, 3], [-3, 2, 1], [5, 5, -5]], dtype)
length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype)
- expected_value = [[b"hir", b"ur", b"t"], [b"r", b"ur", b"ift"],
- [b"ee", b"ee", b"en"]]
+ expected_value = [[b"hir", b"en", b"t"], [b"e", b"ur", b"ift"],
+ [b"ee", b"ee", b"ft"]]
substr_op = string_ops.substr(test_string, position, length)
- with self.test_session():
+ with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
# Test 1D broadcast
test_string = b"thirteen"
- position = np.array([1, 5, 7], dtype)
+ position = np.array([1, -5, 7], dtype)
length = np.array([3, 2, 1], dtype)
- expected_value = [b"hir", b"ee", b"n"]
+ expected_value = [b"hir", b"rt", b"n"]
substr_op = string_ops.substr(test_string, position, length)
- with self.test_session():
+ with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
@@ -128,10 +176,8 @@ class SubstrOpTest(test.TestCase):
test_string = [[b"ten", b"eleven", b"twelve"],
[b"thirteen", b"fourteen", b"fifteen"],
[b"sixteen", b"seventeen", b"eighteen"]]
- position = np.array([1, 2, 3, 4], dtype)
+ position = np.array([1, 2, -3, 4], dtype)
length = np.array([1, 2, 3, 4], dtype)
- expected_value = [[b"e", b"ev", b"lve"], [b"h", b"ur", b"tee"],
- [b"i", b"ve", b"hte"]]
with self.assertRaises(ValueError):
substr_op = string_ops.substr(test_string, position, length)
@@ -141,6 +187,15 @@ class SubstrOpTest(test.TestCase):
position = np.array(7, dtype)
length = np.array(3, dtype)
substr_op = string_ops.substr(test_string, position, length)
+ with self.cached_session():
+ with self.assertRaises(errors_impl.InvalidArgumentError):
+ substr = substr_op.eval()
+
+ # Scalar/Scalar (with negative)
+ test_string = b"Hello"
+ position = np.array(-7, dtype)
+ length = np.array(3, dtype)
+ substr_op = string_ops.substr(test_string, position, length)
with self.test_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
substr = substr_op.eval()
@@ -150,16 +205,16 @@ class SubstrOpTest(test.TestCase):
position = np.array(4, dtype)
length = np.array(1, dtype)
substr_op = string_ops.substr(test_string, position, length)
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
substr = substr_op.eval()
- # Negative pos
- test_string = b"Hello"
- position = np.array(-1, dtype)
- length = np.array(3, dtype)
+ # Vector/Scalar (with negative)
+ test_string = [b"good", b"good", b"bad", b"good"]
+ position = np.array(-4, dtype)
+ length = np.array(1, dtype)
substr_op = string_ops.substr(test_string, position, length)
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
substr = substr_op.eval()
@@ -169,6 +224,16 @@ class SubstrOpTest(test.TestCase):
position = np.array([[1, 2, 3], [1, 2, 4], [1, 2, 3]], dtype)
length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype)
substr_op = string_ops.substr(test_string, position, length)
+ with self.cached_session():
+ with self.assertRaises(errors_impl.InvalidArgumentError):
+ substr = substr_op.eval()
+
+ # Matrix/Matrix (with negative)
+ test_string = [[b"good", b"good", b"good"], [b"good", b"good", b"bad"],
+ [b"good", b"good", b"good"]]
+ position = np.array([[1, 2, -3], [1, 2, -4], [1, 2, -3]], dtype)
+ length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype)
+ substr_op = string_ops.substr(test_string, position, length)
with self.test_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
substr = substr_op.eval()
@@ -178,6 +243,15 @@ class SubstrOpTest(test.TestCase):
position = np.array([1, 2, 4], dtype)
length = np.array([1, 2, 3], dtype)
substr_op = string_ops.substr(test_string, position, length)
+ with self.cached_session():
+ with self.assertRaises(errors_impl.InvalidArgumentError):
+ substr = substr_op.eval()
+
+ # Broadcast (with negative)
+ test_string = [[b"good", b"good", b"good"], [b"good", b"good", b"bad"]]
+ position = np.array([-1, -2, -4], dtype)
+ length = np.array([1, 2, 3], dtype)
+ substr_op = string_ops.substr(test_string, position, length)
with self.test_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
substr = substr_op.eval()
@@ -198,7 +272,18 @@ class SubstrOpTest(test.TestCase):
with self.assertRaises(ValueError):
substr_op = string_ops.substr(test_string, position, length)
- def _testAll(self, dtype):
+ # Negative position.
+ test_string = [[b"ten", b"eleven", b"twelve"],
+ [b"thirteen", b"fourteen", b"fifteen"],
+ [b"sixteen", b"seventeen", b"eighteen"]]
+ position = np.array([[-1, -2, -3]], dtype)
+ length = np.array([1, 2, 3], dtype)
+ # Should fail: position/length have different rank
+ with self.assertRaises(ValueError):
+ substr_op = string_ops.substr(test_string, position, length)
+
+ @parameterized.parameters(np.int32, np.int64)
+ def testAll(self, dtype):
self._testScalarString(dtype)
self._testVectorStrings(dtype)
self._testMatrixStrings(dtype)
@@ -208,14 +293,8 @@ class SubstrOpTest(test.TestCase):
self._testOutOfRangeError(dtype)
self._testMismatchPosLenShapes(dtype)
- def testInt32(self):
- self._testAll(np.int32)
-
- def testInt64(self):
- self._testAll(np.int64)
-
def testWrongDtype(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(TypeError):
string_ops.substr(b"test", 3.0, 1)
with self.assertRaises(TypeError):
diff --git a/tensorflow/python/kernel_tests/summary_ops_test.py b/tensorflow/python/kernel_tests/summary_ops_test.py
index 2da7107f61..0c500120b0 100644
--- a/tensorflow/python/kernel_tests/summary_ops_test.py
+++ b/tensorflow/python/kernel_tests/summary_ops_test.py
@@ -34,7 +34,7 @@ class SummaryOpsTest(test.TestCase):
return summ
def testScalarSummary(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
const = constant_op.constant([10.0, 20.0])
summ = logging_ops.scalar_summary(["c1", "c2"], const, name="mysumm")
value = sess.run(summ)
@@ -45,7 +45,7 @@ class SummaryOpsTest(test.TestCase):
""", self._AsSummary(value))
def testScalarSummaryDefaultName(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
const = constant_op.constant([10.0, 20.0])
summ = logging_ops.scalar_summary(["c1", "c2"], const)
value = sess.run(summ)
@@ -56,7 +56,7 @@ class SummaryOpsTest(test.TestCase):
""", self._AsSummary(value))
def testMergeSummary(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
const = constant_op.constant(10.0)
summ1 = summary.histogram("h", const)
summ2 = logging_ops.scalar_summary("c", const)
diff --git a/tensorflow/python/kernel_tests/summary_tensor_op_test.py b/tensorflow/python/kernel_tests/summary_tensor_op_test.py
index d534aadb79..0f4643393a 100644
--- a/tensorflow/python/kernel_tests/summary_tensor_op_test.py
+++ b/tensorflow/python/kernel_tests/summary_tensor_op_test.py
@@ -42,7 +42,7 @@ class SummaryOpsTest(test.TestCase):
self.assertTrue(np.array_equal(actual, expected))
def testTags(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
c = constant_op.constant(1)
s1 = summary_ops.tensor_summary("s1", c)
with ops.name_scope("foo"):
@@ -65,7 +65,7 @@ class SummaryOpsTest(test.TestCase):
self.assertEqual(v4.tag, "foo/zod/TensorSummary")
def testScalarSummary(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
const = constant_op.constant(10.0)
summ = summary_ops.tensor_summary("foo", const)
result = sess.run(summ)
@@ -76,7 +76,7 @@ class SummaryOpsTest(test.TestCase):
def testStringSummary(self):
s = six.b("foobar")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
const = constant_op.constant(s)
summ = summary_ops.tensor_summary("foo", const)
result = sess.run(summ)
@@ -86,7 +86,7 @@ class SummaryOpsTest(test.TestCase):
self._AssertNumpyEq(n, s)
def testManyScalarSummary(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
const = array_ops.ones([5, 5, 5])
summ = summary_ops.tensor_summary("foo", const)
result = sess.run(summ)
@@ -96,7 +96,7 @@ class SummaryOpsTest(test.TestCase):
def testManyStringSummary(self):
strings = [[six.b("foo bar"), six.b("baz")], [six.b("zoink"), six.b("zod")]]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
const = constant_op.constant(strings)
summ = summary_ops.tensor_summary("foo", const)
result = sess.run(summ)
@@ -106,7 +106,7 @@ class SummaryOpsTest(test.TestCase):
def testManyBools(self):
bools = [True, True, True, False, False, False]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
const = constant_op.constant(bools)
summ = summary_ops.tensor_summary("foo", const)
result = sess.run(summ)
@@ -116,7 +116,7 @@ class SummaryOpsTest(test.TestCase):
self._AssertNumpyEq(n, bools)
def testSummaryDescriptionAndDisplayName(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
def get_description(summary_op):
summ_str = sess.run(summary_op)
diff --git a/tensorflow/python/kernel_tests/tensordot_op_test.py b/tensorflow/python/kernel_tests/tensordot_op_test.py
index 8ad29afd0a..d8d76440f1 100644
--- a/tensorflow/python/kernel_tests/tensordot_op_test.py
+++ b/tensorflow/python/kernel_tests/tensordot_op_test.py
@@ -48,7 +48,7 @@ class TensordotTest(test_lib.TestCase):
with self.assertRaises(ValueError):
math_ops.tensordot(a, b, (a_axes, b_axes))
# Invalid dynamic shapes.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"Matrix size-incompatible"):
a_ph = array_ops.placeholder(dtypes.float32)
@@ -80,7 +80,7 @@ class TensordotTest(test_lib.TestCase):
output = math_ops.tensordot(a_ph, b_ph, axes_ph)
# Note: We don't support scalar Tensor values for axes.
for axes_value in 1, [1], [0, 1], [[1]], [[0, 1]], [[0], [7]]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors_impl.InvalidArgumentError):
_ = sess.run(
[output], feed_dict={
@@ -92,7 +92,7 @@ class TensordotTest(test_lib.TestCase):
# Test case for 11950
def test_valid_axis(self):
for axes_value in [1, 2], [[1], [2]], [[], []], 0:
- with self.test_session() as sess:
+ with self.cached_session():
np_a = np.ones((3, 3))
np_b = np.array([2, 3, 1])[None, None]
np_ans = np.tensordot(np_a, np_b, axes_value)
diff --git a/tensorflow/python/kernel_tests/transpose_op_test.py b/tensorflow/python/kernel_tests/transpose_op_test.py
index 290200ce45..f42800226e 100644
--- a/tensorflow/python/kernel_tests/transpose_op_test.py
+++ b/tensorflow/python/kernel_tests/transpose_op_test.py
@@ -451,13 +451,13 @@ class TransposeTest(test.TestCase):
array_ops.transpose(array_ops.placeholder(dtypes.int32)).get_shape())
def testNullTensor(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant([], dtype=dtypes.float32, shape=[1, 4, 0])
xt = array_ops.transpose(x, [0, 2, 1]).eval()
self.assertAllEqual(xt.shape, (1, 0, 4))
def _testError(self, x, p, err):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError(err):
array_ops.transpose(x, p).eval()
diff --git a/tensorflow/python/kernel_tests/unique_op_test.py b/tensorflow/python/kernel_tests/unique_op_test.py
index bbc040dc13..316570e13e 100644
--- a/tensorflow/python/kernel_tests/unique_op_test.py
+++ b/tensorflow/python/kernel_tests/unique_op_test.py
@@ -30,7 +30,7 @@ class UniqueTest(test.TestCase):
def testInt32(self):
x = np.random.randint(2, high=10, size=7000)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y, idx = array_ops.unique(x)
tf_y, tf_idx = sess.run([y, idx])
@@ -41,7 +41,7 @@ class UniqueTest(test.TestCase):
def testInt32OutIdxInt64(self):
x = np.random.randint(2, high=10, size=7000)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y, idx = array_ops.unique(x, out_idx=dtypes.int64)
tf_y, tf_idx = sess.run([y, idx])
@@ -53,7 +53,7 @@ class UniqueTest(test.TestCase):
def testString(self):
indx = np.random.randint(65, high=122, size=7000)
x = [chr(i) for i in indx]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y, idx = array_ops.unique(x)
tf_y, tf_idx = sess.run([y, idx])
@@ -65,7 +65,7 @@ class UniqueTest(test.TestCase):
def testInt32Axis(self):
for dtype in [np.int32, np.int64]:
x = np.array([[1, 0, 0], [1, 0, 0], [2, 0, 0]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y0, idx0 = gen_array_ops.unique_v2(x, axis=np.array([0], dtype))
tf_y0, tf_idx0 = sess.run([y0, idx0])
y1, idx1 = gen_array_ops.unique_v2(x, axis=np.array([1], dtype))
@@ -79,7 +79,7 @@ class UniqueTest(test.TestCase):
# This test is only temporary, once V2 is used
# by default, the axis will be wrapped to allow `axis=None`.
x = np.random.randint(2, high=10, size=7000)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y, idx = gen_array_ops.unique_v2(x, axis=np.array([], np.int32))
tf_y, tf_idx = sess.run([y, idx])
@@ -93,7 +93,7 @@ class UniqueWithCountsTest(test.TestCase):
def testInt32(self):
x = np.random.randint(2, high=10, size=7000)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y, idx, count = array_ops.unique_with_counts(x)
tf_y, tf_idx, tf_count = sess.run([y, idx, count])
@@ -106,7 +106,7 @@ class UniqueWithCountsTest(test.TestCase):
def testInt32OutIdxInt64(self):
x = np.random.randint(2, high=10, size=7000)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y, idx, count = array_ops.unique_with_counts(x, out_idx=dtypes.int64)
tf_y, tf_idx, tf_count = sess.run([y, idx, count])
@@ -121,7 +121,7 @@ class UniqueWithCountsTest(test.TestCase):
indx = np.random.randint(65, high=122, size=7000)
x = [chr(i) for i in indx]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y, idx, count = array_ops.unique_with_counts(x)
tf_y, tf_idx, tf_count = sess.run([y, idx, count])
@@ -136,7 +136,7 @@ class UniqueWithCountsTest(test.TestCase):
def testInt32Axis(self):
for dtype in [np.int32, np.int64]:
x = np.array([[1, 0, 0], [1, 0, 0], [2, 0, 0]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y0, idx0, count0 = gen_array_ops.unique_with_counts_v2(
x, axis=np.array([0], dtype))
tf_y0, tf_idx0, tf_count0 = sess.run([y0, idx0, count0])
@@ -154,7 +154,7 @@ class UniqueWithCountsTest(test.TestCase):
# This test is only temporary, once V2 is used
# by default, the axis will be wrapped to allow `axis=None`.
x = np.random.randint(2, high=10, size=7000)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y, idx, count = gen_array_ops.unique_with_counts_v2(
x, axis=np.array([], np.int32))
tf_y, tf_idx, tf_count = sess.run([y, idx, count])
diff --git a/tensorflow/python/kernel_tests/unstack_op_test.py b/tensorflow/python/kernel_tests/unstack_op_test.py
index 1ee6e0866a..b373c419b6 100644
--- a/tensorflow/python/kernel_tests/unstack_op_test.py
+++ b/tensorflow/python/kernel_tests/unstack_op_test.py
@@ -99,7 +99,7 @@ class UnstackOpTest(test.TestCase):
self.assertLess(err, 1e-6)
def testInferNum(self):
- with self.test_session():
+ with self.cached_session():
for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
x = array_ops.placeholder(np.float32, shape=shape)
cs = array_ops.unstack(x)
@@ -131,13 +131,13 @@ class UnstackOpTest(test.TestCase):
for j in range(-i, i):
expected = np_split_squeeze(a, j)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
actual_unstack = sess.run(array_ops.unstack(a, axis=j))
self.assertAllEqual(expected, actual_unstack)
def testAxis0Default(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a = constant_op.constant([[1, 2, 3], [4, 5, 6]], name='a')
unstacked = sess.run(array_ops.unstack(a))
@@ -156,7 +156,7 @@ class UnstackOpTest(test.TestCase):
array_ops.unstack(a, axis=-3)
def testZeroLengthDim(self):
- with self.test_session():
+ with self.cached_session():
x = array_ops.zeros(shape=(0, 1, 2))
y = array_ops.unstack(x, axis=1)[0].eval()
self.assertEqual(y.shape, (0, 2))
diff --git a/tensorflow/python/kernel_tests/variable_ops_test.py b/tensorflow/python/kernel_tests/variable_ops_test.py
index cf369c0718..3d2f8b6155 100644
--- a/tensorflow/python/kernel_tests/variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/variable_ops_test.py
@@ -118,7 +118,7 @@ class VariableOpTest(test.TestCase):
self.assertEqual(tensor_shape.unknown_shape(), assigned.get_shape())
def testAssignNoShape(self):
- with self.test_session():
+ with self.cached_session():
value = self._NewShapelessTensor()
var = state_ops.variable_op([1, 2], dtypes.float32, set_shape=False)
self.assertEqual(tensor_shape.unknown_shape(), var.get_shape())
@@ -126,7 +126,7 @@ class VariableOpTest(test.TestCase):
state_ops.assign(var, value).get_shape())
def testAssignNoShapeNoValidateShape(self):
- with self.test_session():
+ with self.cached_session():
value = self._NewShapelessTensor()
var = state_ops.variable_op([1, 2], dtypes.float32, set_shape=False)
self.assertEqual(tensor_shape.unknown_shape(), var.get_shape())
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py
index d57b79cb90..401e1ae102 100644
--- a/tensorflow/python/kernel_tests/variable_scope_test.py
+++ b/tensorflow/python/kernel_tests/variable_scope_test.py
@@ -113,7 +113,7 @@ class VariableScopeTest(test.TestCase):
self.assertEqual(w.constraint, constraint)
def testStringDefaultInitializer(self):
- with self.test_session():
+ with self.cached_session():
v = variable_scope.get_variable("string", shape=[], dtype=dtypes.string)
variables_lib.global_variables_initializer().run()
self.assertAllEqual(compat.as_bytes(v.eval()), b"")
@@ -263,7 +263,7 @@ class VariableScopeTest(test.TestCase):
# TODO(alive): support variable partitioning/caching in eager mode.
def testVarScopeCachingDevice(self):
- with self.test_session():
+ with self.cached_session():
caching_device = "/job:moo"
with variable_scope.variable_scope("tower"):
with variable_scope.variable_scope(
@@ -367,7 +367,7 @@ class VariableScopeTest(test.TestCase):
variable_scope.get_variable("s", initializer=init, dtype=dtypes.float64)
def testControlDeps(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v0 = variable_scope.get_variable(
"v0", [1], initializer=init_ops.constant_initializer(0))
with ops.control_dependencies([v0.value()]):
@@ -403,7 +403,7 @@ class VariableScopeTest(test.TestCase):
variable_scope._DEFAULT_USE_RESOURCE = old
def testControlFlow(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v0 = variable_scope.get_variable(
"v0", [], initializer=init_ops.constant_initializer(0))
var_dict = {}
@@ -513,7 +513,7 @@ class VariableScopeTest(test.TestCase):
self.assertEqual(sc2, "testVarScopeNameScope3/scope2/")
def testVarScopeOriginalNameScope(self):
- with self.test_session():
+ with self.cached_session():
with ops.name_scope("scope1"):
with variable_scope.variable_scope("tower") as tower:
self.assertEqual(tower.original_name_scope, "scope1/tower/")
@@ -536,7 +536,7 @@ class VariableScopeTest(test.TestCase):
self.assertEqual(sc3, "scope1/tower/bar_1/")
def testVarScopeObjectReuse(self):
- with self.test_session():
+ with self.cached_session():
vs = None
with variable_scope.variable_scope("jump", reuse=True) as scope:
vs = scope
@@ -563,7 +563,7 @@ class VariableScopeTest(test.TestCase):
self.assertFalse(jump_no_reuse.reuse)
def testVarScopeGetOrCreateReuse(self):
- with self.test_session():
+ with self.cached_session():
def test_value(value):
x = constant_op.constant(value)
@@ -582,7 +582,7 @@ class VariableScopeTest(test.TestCase):
test_value(17.)
def testVarOpScope(self):
- with self.test_session():
+ with self.cached_session():
with ops.name_scope("testVarOpScope1"):
with variable_scope.variable_scope("tower", "default", []):
self.assertEqual(
@@ -608,7 +608,7 @@ class VariableScopeTest(test.TestCase):
self.assertEqual(sc2, "testVarOpScope2/default_1/testVarOpScope2/")
def testVarOpScopeUniqueNamesInterleavedSubstringScopes(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope(None, "defaultScope1"):
with variable_scope.variable_scope(None, "layer"):
self.assertEqual(
@@ -631,7 +631,7 @@ class VariableScopeTest(test.TestCase):
"defaultScope1_2/layer/w:0")
def testVarOpScopeUniqueNamesWithJump(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope("default") as default:
with variable_scope.variable_scope(None, "layer"):
self.assertEqual(
@@ -647,7 +647,7 @@ class VariableScopeTest(test.TestCase):
variable_scope.get_variable("w", []).name, "default/layer_2/w:0")
def testVarOpScopeReuse(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope("outer") as outer:
with variable_scope.variable_scope("tower", "default", []):
self.assertEqual(
@@ -673,7 +673,7 @@ class VariableScopeTest(test.TestCase):
self.assertEqual(sc2, "outer_1/default/scope2/")
def testVarScopeGetVar(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope("root"):
with variable_scope.variable_scope("towerA") as tower_a:
va = variable_scope.get_variable("v", [1])
@@ -719,7 +719,7 @@ class VariableScopeTest(test.TestCase):
self.assertEqual("dtype" in str(exc.exception), True)
def testVarScopeOuterScope(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope("outer") as outer:
pass
with variable_scope.variable_scope(outer):
@@ -743,7 +743,7 @@ class VariableScopeTest(test.TestCase):
self.assertEqual(sc2, "outer_2/default/scope2/")
def testVarScopeNestedOuterScope(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope("outer") as outer:
with variable_scope.variable_scope(outer):
self.assertEqual(
@@ -768,7 +768,7 @@ class VariableScopeTest(test.TestCase):
self.assertEqual(sc2, "outer/default_1/scope2/")
def testVarOpScopeReuseParam(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope("outer") as outer:
with variable_scope.variable_scope("tower", "default", []):
self.assertEqual(
@@ -795,14 +795,14 @@ class VariableScopeTest(test.TestCase):
self.assertEqual(sc2, "outer_1/default/scope2/")
def testVarOpScopeReuseError(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
with variable_scope.variable_scope(None, "default", reuse=True):
self.assertEqual(
variable_scope.get_variable("w", []).name, "outer/tower/w:0")
def testVarOpScopeOuterScope(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope("outer") as outer:
pass
with variable_scope.variable_scope(outer, "default", []):
@@ -827,7 +827,7 @@ class VariableScopeTest(test.TestCase):
self.assertEqual(sc2, "outer_2/default/scope2/")
def testVarOpScopeNestedOuterScope(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope("outer") as outer:
with variable_scope.variable_scope(outer, "default", []):
self.assertEqual(
@@ -851,7 +851,7 @@ class VariableScopeTest(test.TestCase):
self.assertEqual(sc2, "outer_1/default/scope2/")
def testBasicWhenAuxiliaryNameScopeIsFalse(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope(
"scope", auxiliary_name_scope=False) as scope:
self.assertEqual(scope.original_name_scope, "")
@@ -886,7 +886,7 @@ class VariableScopeTest(test.TestCase):
constant_op.constant([], name="c").name, "outer/inner/c:0")
def testCreatedByDefaultNameWhenAuxiliaryNameScopeIsFalse(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope(
None, default_name="default", auxiliary_name_scope=False) as scope:
self.assertEqual(scope.original_name_scope, "")
@@ -910,7 +910,7 @@ class VariableScopeTest(test.TestCase):
constant_op.constant([], name="c").name, "outer/default/c:0")
def testReenterRootScopeWhenAuxiliaryNameScopeIsFalse(self):
- with self.test_session():
+ with self.cached_session():
root_scope = variable_scope.get_variable_scope()
with variable_scope.variable_scope(
root_scope, auxiliary_name_scope=False) as scope:
@@ -927,7 +927,7 @@ class VariableScopeTest(test.TestCase):
constant_op.constant([], name="c1").name, "outer/c1:0")
def testAuxiliaryNameScopeIsInvalid(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, "auxiliary_name_scope"):
with variable_scope.variable_scope(
None, default_name="scope", auxiliary_name_scope="invalid"):
@@ -947,7 +947,7 @@ class VariableScopeTest(test.TestCase):
def testReuseScopeWithoutNameScopeCollision(self):
# Github issue: #13429
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope("outer"):
with variable_scope.variable_scope("inner") as inner:
pass
@@ -1021,7 +1021,7 @@ class VariableScopeTest(test.TestCase):
self.assertEqual(varname_type[1], ("y", dtypes.int64))
def testGetCollection(self):
- with self.test_session():
+ with self.cached_session():
_ = variable_scope.get_variable("testGetCollection_a", [])
_ = variable_scope.get_variable(
"testGetCollection_b", [], trainable=False)
@@ -1075,7 +1075,7 @@ class VariableScopeTest(test.TestCase):
])
def testGetTrainableVariablesWithGetVariable(self):
- with self.test_session():
+ with self.cached_session():
_ = variable_scope.get_variable("testGetTrainableVariables_a", [])
with variable_scope.variable_scope(
"testGetTrainableVariables_foo") as scope:
@@ -1111,7 +1111,7 @@ class VariableScopeTest(test.TestCase):
trainable=True)
def testGetTrainableVariablesWithVariable(self):
- with self.test_session():
+ with self.cached_session():
_ = variable_scope.variable(1.0, name="testGetTrainableVariables_a")
with variable_scope.variable_scope(
"testGetTrainableVariables_foo") as scope:
@@ -1150,7 +1150,7 @@ class VariableScopeTest(test.TestCase):
trainable=True)
def testGetGlobalVariables(self):
- with self.test_session():
+ with self.cached_session():
_ = variable_scope.get_variable("testGetGlobalVariables_a", [])
with variable_scope.variable_scope("testGetGlobalVariables_foo") as scope:
_ = variable_scope.get_variable("testGetGlobalVariables_b", [])
@@ -1160,7 +1160,7 @@ class VariableScopeTest(test.TestCase):
"testGetGlobalVariables_b:0"])
def testGetLocalVariables(self):
- with self.test_session():
+ with self.cached_session():
_ = variable_scope.get_variable(
"a", [], collections=[ops.GraphKeys.LOCAL_VARIABLES])
with variable_scope.variable_scope("foo") as scope:
@@ -1396,7 +1396,7 @@ class VariableScopeWithCustomGetterTest(test.TestCase):
self.assertEqual("scope/v/0:0", true_vars[0].name)
self.assertEqual("scope/v/1:0", true_vars[1].name)
self.assertEqual("custom_getter/add:0", v.name)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
np_vars, np_v = sess.run([true_vars, v])
self.assertAllClose(np_v, sum(np_vars))
@@ -1436,7 +1436,7 @@ class VariableScopeWithCustomGetterTest(test.TestCase):
self.assertEqual(template % (1, 1, 0), true_vars[6].name)
self.assertEqual(template % (1, 1, 1), true_vars[7].name)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
np_vars, np_v = sess.run([true_vars, v])
# take products of sums of products
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py
index 2b9c62ad6f..2e7975667c 100644
--- a/tensorflow/python/kernel_tests/variables_test.py
+++ b/tensorflow/python/kernel_tests/variables_test.py
@@ -42,7 +42,7 @@ from tensorflow.python.util import compat
class VariablesTestCase(test.TestCase):
def testInitialization(self):
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable(0.0)
self.assertEqual("Variable:0", var0.name)
self.assertEqual("Variable", var0._shared_name)
@@ -69,7 +69,7 @@ class VariablesTestCase(test.TestCase):
self.assertAllClose(1.1, var1.eval())
def testInitializationOrder(self):
- with self.test_session():
+ with self.cached_session():
rnd = variables.Variable(random_ops.random_uniform([3, 6]), name="rnd")
self.assertEqual("rnd:0", rnd.name)
self.assertEqual([3, 6], rnd.get_shape())
@@ -106,7 +106,7 @@ class VariablesTestCase(test.TestCase):
pass
def testAssignments(self):
- with self.test_session():
+ with self.cached_session():
var = variables.Variable(0.0)
plus_one = var.assign_add(1.0)
minus_one = var.assign_sub(2.0)
@@ -142,7 +142,7 @@ class VariablesTestCase(test.TestCase):
self.assertAllClose(4.0, var.eval())
def testZeroSizeStringAssign(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
array = variables.Variable(
initial_value=array_ops.zeros((0,), dtype=dtypes.string),
name="foo",
@@ -154,7 +154,7 @@ class VariablesTestCase(test.TestCase):
self.assertEqual([], list(sess.run(copy_op)))
def _countUpToTest(self, dtype):
- with self.test_session():
+ with self.cached_session():
zero = constant_op.constant(0, dtype=dtype)
var = variables.Variable(zero)
count_up_to = var.count_up_to(3)
@@ -186,7 +186,7 @@ class VariablesTestCase(test.TestCase):
self._countUpToTest(dtypes.int64)
def testControlDepsNone(self):
- with self.test_session():
+ with self.cached_session():
c = constant_op.constant(1.0)
with ops.control_dependencies([c]):
# d get the control dep.
@@ -199,7 +199,7 @@ class VariablesTestCase(test.TestCase):
self.assertEqual([], var_x._ref().op.control_inputs) # pylint: disable=protected-access
def testControlFlow(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v0 = variables.Variable(0, name="v0")
var_dict = {}
@@ -248,7 +248,7 @@ class VariablesTestCase(test.TestCase):
control_flow_ops.while_loop(cond, body, [0, 0])
def testUseVariableAsTensor(self):
- with self.test_session():
+ with self.cached_session():
var_x = variables.Variable(2.0)
var_y = variables.Variable(3.0)
variables.global_variables_initializer().run()
@@ -257,7 +257,7 @@ class VariablesTestCase(test.TestCase):
self.assertAllClose(5.0, math_ops.add(var_x, var_y).eval())
def testZeroSizeVarSameAsConst(self):
- with self.test_session():
+ with self.cached_session():
zero_size_var = variables.Variable(array_ops.zeros([0, 2]))
zero_size_const = array_ops.ones([2, 0])
variable_mul = math_ops.matmul(zero_size_const, zero_size_var)
@@ -269,7 +269,7 @@ class VariablesTestCase(test.TestCase):
self.assertAllClose([[0., 0.], [0., 0.]], variable_output)
def testCachingDevice(self):
- with self.test_session():
+ with self.cached_session():
var = variables.Variable(2.0)
self.assertEqual(var.device, var.value().device)
self.assertEqual(var.device, var.initialized_value().device)
@@ -279,7 +279,7 @@ class VariablesTestCase(test.TestCase):
self.assertTrue(var_cached.value().device.startswith("/job:foo"))
def testCollections(self):
- with self.test_session():
+ with self.cached_session():
var_x = variables.Variable(2.0)
var_y = variables.Variable(2.0, trainable=False)
var_z = variables.Variable(2.0, trainable=True)
@@ -294,7 +294,7 @@ class VariablesTestCase(test.TestCase):
self.assertEqual([var_x, var_z, var_t], variables.trainable_variables())
def testCollectionsWithScope(self):
- with self.test_session():
+ with self.cached_session():
with ops.name_scope("scope_1"):
var_x = variables.Variable(2.0)
with ops.name_scope("scope_2"):
@@ -309,7 +309,7 @@ class VariablesTestCase(test.TestCase):
self.assertEqual([var_y], variables.trainable_variables("scope_2"))
def testOperators(self):
- with self.test_session():
+ with self.cached_session():
var_f = variables.Variable([2.0])
add = var_f + 0.0
radd = 1.0 + var_f
@@ -382,13 +382,13 @@ class VariablesTestCase(test.TestCase):
self.assertAllClose([[20.0, 30.0], [40.0, 60.0]], rmatmul.eval())
def testSession(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
var = variables.Variable([1, 12])
variables.global_variables_initializer().run()
self.assertAllClose([1, 12], sess.run(var))
def testDevicePlacement(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with ops.device("/cpu:0"):
var = variables.Variable([1, 12])
init_value = var.initialized_value()
@@ -408,7 +408,7 @@ class VariablesTestCase(test.TestCase):
def testInitializerFunction(self):
value = [[-42], [133.7]]
shape = [2, 1]
- with self.test_session():
+ with self.cached_session():
initializer = lambda: constant_op.constant(value)
v1 = variables.Variable(initializer, dtype=dtypes.float32)
@@ -443,7 +443,7 @@ class VariablesTestCase(test.TestCase):
constraint=constraint)
def testNoRefDataRace(self):
- with self.test_session():
+ with self.cached_session():
a = variables.Variable([1, 2, 3], dtype=dtypes.float32)
b = variables.Variable(a.initialized_value() + 2)
c = variables.Variable(b.initialized_value() + 2)
@@ -453,7 +453,7 @@ class VariablesTestCase(test.TestCase):
self.assertAllEqual(c.eval(), [5, 6, 7])
def testInitializerFunctionDevicePlacement(self):
- with self.test_session():
+ with self.cached_session():
initializer = lambda: constant_op.constant(42.0)
with ops.device("/cpu:100"):
v1 = variables.Variable(initializer, dtype=dtypes.float32, name="v1")
@@ -471,11 +471,11 @@ class VariablesTestCase(test.TestCase):
self.assertEqual(expected_group_v2, i.op.colocation_groups())
def testVariableDefInitializedInstances(self):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
v_def = variables.Variable(
initial_value=constant_op.constant(3.0)).to_proto()
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
# v describes a VariableDef-based variable without an initial value.
v = variables.Variable(variable_def=v_def)
self.assertEqual(3.0, sess.run(v.initialized_value()))
@@ -486,7 +486,7 @@ class VariablesTestCase(test.TestCase):
self.assertEqual(1.0, v.initialized_value().eval())
v_def.ClearField("initial_value_name")
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
# Restoring a legacy VariableDef proto that does not have
# initial_value_name set should still work.
v = variables.Variable(variable_def=v_def)
@@ -514,7 +514,7 @@ class VariablesTestCase(test.TestCase):
.trainable)
def testLoad(self):
- with self.test_session():
+ with self.cached_session():
var = variables.Variable(np.zeros((5, 5), np.float32))
variables.global_variables_initializer().run()
var.load(np.ones((5, 5), np.float32))
@@ -540,12 +540,12 @@ class VariablesTestCase(test.TestCase):
class IsInitializedTest(test.TestCase):
def testNoVars(self):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
uninited = variables.report_uninitialized_variables()
self.assertEqual(0, sess.run(uninited).size)
def testAssertVariablesInitialized(self):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
v = variables.Variable([1, 2], name="v")
w = variables.Variable([3, 4], name="w")
_ = v, w
@@ -555,7 +555,7 @@ class IsInitializedTest(test.TestCase):
self.assertEqual(0, sess.run(uninited).size)
def testVariableList(self):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
v = variables.Variable([1, 2], name="v")
w = variables.Variable([3, 4], name="w")
uninited = variables.report_uninitialized_variables()
@@ -566,14 +566,14 @@ class IsInitializedTest(test.TestCase):
self.assertEqual(0, sess.run(uninited).size)
def testZeroSizeVarInitialized(self):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
v = variables.Variable(array_ops.zeros([0, 2]), name="v")
uninited = variables.report_uninitialized_variables()
v.initializer.run() # not strictly necessary
self.assertEqual(0, sess.run(uninited).size)
def testTrainingWithZeroSizeVar(self):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
a = variables.Variable(array_ops.zeros([0, 2]))
b = variables.Variable(array_ops.ones([2, 2]))
objective = math_ops.reduce_sum(b + math_ops.matmul(
@@ -592,7 +592,7 @@ class ObsoleteIsInitializedTest(test.TestCase):
self.assertEqual(None, variables.assert_variables_initialized())
def testVariables(self):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
v = variables.Variable([1, 2])
w = variables.Variable([3, 4])
_ = v, w
@@ -603,7 +603,7 @@ class ObsoleteIsInitializedTest(test.TestCase):
sess.run(inited)
def testVariableList(self):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
v = variables.Variable([1, 2])
w = variables.Variable([3, 4])
inited = variables.assert_variables_initialized([v])
diff --git a/tensorflow/python/kernel_tests/weights_broadcast_test.py b/tensorflow/python/kernel_tests/weights_broadcast_test.py
index eda2856e0b..85f9abc69f 100644
--- a/tensorflow/python/kernel_tests/weights_broadcast_test.py
+++ b/tensorflow/python/kernel_tests/weights_broadcast_test.py
@@ -44,7 +44,7 @@ class AssertBroadcastableTest(test.TestCase):
values_placeholder = array_ops.placeholder(dtypes_lib.float32)
dynamic_op = weights_broadcast_ops.assert_broadcastable(
weights=weights_placeholder, values=values_placeholder)
- with self.test_session():
+ with self.cached_session():
static_op.run()
dynamic_op.run(feed_dict={
weights_placeholder: weights,
@@ -100,7 +100,7 @@ class AssertBroadcastableTest(test.TestCase):
values_placeholder = array_ops.placeholder(dtypes_lib.float32)
dynamic_op = weights_broadcast_ops.assert_broadcastable(
weights=weights_placeholder, values=values_placeholder)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors_impl.OpError, error_msg):
dynamic_op.run(feed_dict={
weights_placeholder: weights,
@@ -157,7 +157,7 @@ class BroadcastWeightsTest(test.TestCase):
values_placeholder = array_ops.placeholder(dtypes_lib.float32)
dynamic_op = weights_broadcast_ops.broadcast_weights(
weights=weights_placeholder, values=values_placeholder)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected, static_op.eval())
self.assertAllEqual(expected, dynamic_op.eval(feed_dict={
weights_placeholder: weights,
@@ -227,7 +227,7 @@ class BroadcastWeightsTest(test.TestCase):
values_placeholder = array_ops.placeholder(dtypes_lib.float32)
dynamic_op = weights_broadcast_ops.broadcast_weights(
weights=weights_placeholder, values=values_placeholder)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors_impl.OpError, error_msg):
dynamic_op.eval(feed_dict={
weights_placeholder: weights,
diff --git a/tensorflow/python/kernel_tests/xent_op_test.py b/tensorflow/python/kernel_tests/xent_op_test.py
index 60c726d54c..729885169e 100644
--- a/tensorflow/python/kernel_tests/xent_op_test.py
+++ b/tensorflow/python/kernel_tests/xent_op_test.py
@@ -153,13 +153,13 @@ class XentTest(test.TestCase):
self.assertAllCloseAccordingToType(np_backprop, tf_backprop)
def testShapeMismatch(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
gen_nn_ops.softmax_cross_entropy_with_logits(
[[0., 1.], [2., 3.]], [[0., 1., 0.], [1., 0., 0.]])
def testNotMatrix(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
gen_nn_ops.softmax_cross_entropy_with_logits([0., 1., 2., 3.],
[0., 1., 0., 1.])
@@ -180,7 +180,7 @@ class XentTest(test.TestCase):
np.array([[0., 0., 0., 1.], [0., .5, .5, 0.]]).astype(np.float64))
def testGradient(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
l = constant_op.constant(
[0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.5],
shape=[3, 4],
@@ -207,7 +207,7 @@ class XentTest(test.TestCase):
self.assertLess(err, 5e-8)
def testGradientLabelWithV2(self):
- with self.test_session():
+ with self.cached_session():
l = constant_op.constant(
[0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.5],
shape=[3, 4],
@@ -225,7 +225,7 @@ class XentTest(test.TestCase):
self.assertLess(err, 5e-8)
def testSecondGradient(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
l = constant_op.constant(
[
0.0, 0.0, 1.0 / 3, 0.0, 1.0 / 3, 0.0, 0.0, 0.0, 0.0, 0.5 / 3, 0.0,
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py
index 6ae869b89e..ade86e85bf 100644
--- a/tensorflow/python/ops/array_grad.py
+++ b/tensorflow/python/ops/array_grad.py
@@ -805,3 +805,22 @@ def _ScatterNdNonAliasingAddGrad(op, grad):
indices = op.inputs[1]
updates_grad = array_ops.gather_nd(grad, indices)
return [grad, None, updates_grad]
+
+
+@ops.RegisterGradient("BroadcastTo")
+def _BroadcastToGrad(op, grad):
+ input_value = op.inputs[0]
+ broadcast_shape = op.inputs[1]
+ # Assign ids for each position in input_value.
+ input_value_shape = array_ops.shape(input_value)
+ input_value_size = array_ops.size(input_value)
+ ids = array_ops.reshape(math_ops.range(input_value_size), input_value_shape)
+ broadcast_ids = array_ops.broadcast_to(ids, broadcast_shape)
+ # Group by ids and sum its gradients.
+ grad_flatten = array_ops.reshape(grad, [-1])
+ broadcast_ids_flatten = array_ops.reshape(broadcast_ids, [-1])
+ updates_grad_flatten = math_ops.unsorted_segment_sum(grad_flatten,
+ broadcast_ids_flatten,
+ input_value_size)
+ updates_grad = array_ops.reshape(updates_grad_flatten, input_value_shape)
+ return [updates_grad, None]
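The new `_BroadcastToGrad` reduces the incoming gradient back to the input shape by labelling every input position with an id, broadcasting the ids alongside the values, and summing the gradient entries that share an id via `unsorted_segment_sum`. A minimal NumPy sketch of that reduction (the helper name is made up for this note; it only illustrates the idea and is not the registered gradient itself):

```python
import numpy as np

def broadcast_to_grad_sketch(input_shape, broadcast_shape, grad):
    # Label each input position with a unique id, mirroring the ids built
    # from math_ops.range in the gradient above.
    input_size = int(np.prod(input_shape))
    ids = np.arange(input_size).reshape(input_shape)
    broadcast_ids = np.broadcast_to(ids, broadcast_shape)
    # Sum all gradient entries that map back to the same input position
    # (the unsorted_segment_sum step).
    reduced = np.zeros(input_size, dtype=grad.dtype)
    np.add.at(reduced, broadcast_ids.reshape(-1), grad.reshape(-1))
    return reduced.reshape(input_shape)

# Broadcasting [2, 1] -> [2, 3]: each input element collects the sum of the
# three output gradients it was broadcast into.
print(broadcast_to_grad_sketch((2, 1), (2, 3), np.ones((2, 3), np.float32)))
# [[3.], [3.]]
```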
diff --git a/tensorflow/python/ops/boosted_trees_ops.py b/tensorflow/python/ops/boosted_trees_ops.py
index f7cbfe0312..720f9f4d41 100644
--- a/tensorflow/python/ops/boosted_trees_ops.py
+++ b/tensorflow/python/ops/boosted_trees_ops.py
@@ -24,11 +24,17 @@ from tensorflow.python.ops import resources
# Re-exporting ops used by other modules.
# pylint: disable=unused-import
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_bucketize
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_gains_per_feature as calculate_best_gains_per_feature
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_center_bias as center_bias
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_create_quantile_stream_resource as create_quantile_stream_resource
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_example_debug_outputs as example_debug_outputs
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_make_quantile_summaries as make_quantile_summaries
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_make_stats_summary as make_stats_summary
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_predict as predict
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_add_summaries as quantile_add_summaries
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_flush as quantile_flush
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_get_bucket_boundaries as get_bucket_boundaries
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_training_predict as training_predict
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_update_ensemble as update_ensemble
# pylint: enable=unused-import
diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py
index 6528062f3c..c3cf6e61f2 100644
--- a/tensorflow/python/ops/check_ops.py
+++ b/tensorflow/python/ops/check_ops.py
@@ -1292,3 +1292,9 @@ def ensure_shape(x, shape, name=None):
shape = tensor_shape.TensorShape(shape)
return array_ops.ensure_shape(x, shape, name=name)
+
+
+@ops.RegisterGradient('EnsureShape')
+def _ensure_shape_grad(op, grad):
+ del op # Unused.
+ return grad
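Because `ensure_shape` is an identity on values (it only validates the static shape), the new gradient simply forwards the incoming gradient. A small graph-mode sketch of the effect, assuming the op is reachable as `tf.ensure_shape` in this build (TF 1.x session style):

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 2])
# ensure_shape checks the shape but is the identity on values, so the
# EnsureShape gradient just passes grad through: d(3*x)/dx == 3.
y = 3.0 * tf.ensure_shape(x, [None, 2])
dydx = tf.gradients(y, x)[0]

with tf.Session() as sess:
    print(sess.run(dydx, feed_dict={x: [[1.0, 2.0]]}))  # [[3. 3.]]
```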
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index e3c1aa3d5a..0e20fadb2b 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -61,7 +61,7 @@ from tensorflow.python.util import tf_should_use
from tensorflow.python.util.tf_export import tf_export
-_ENABLE_COND_V2 = os.getenv("TF_ENABLE_COND_V2", "0") != "0"
+ENABLE_COND_V2 = os.getenv("TF_ENABLE_COND_V2", "0") != "0"
# We override the 'tuple' for a control flow op, so we keep python's
@@ -2026,7 +2026,7 @@ def cond(pred,
```
"""
- if _ENABLE_COND_V2:
+ if ENABLE_COND_V2 and not context.executing_eagerly():
return cond_v2_impl.cond_v2(pred, true_fn, false_fn, name)
# We needed to make true_fn/false_fn keyword arguments for
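The renamed flag is still driven by the `TF_ENABLE_COND_V2` environment variable, and the added guard keeps `tf.cond` on the regular path when executing eagerly. Since the module reads the variable at import time, it must be set before TensorFlow is imported; a minimal graph-mode sketch:

```python
import os
# control_flow_ops reads TF_ENABLE_COND_V2 when the module is imported, so
# set it before importing tensorflow.
os.environ["TF_ENABLE_COND_V2"] = "1"

import tensorflow as tf

x = tf.constant(2.0)
# With the flag on (and not executing eagerly), this cond is lowered via cond_v2.
y = tf.cond(x > 1.0, lambda: x * 2.0, lambda: x - 1.0)
with tf.Session() as sess:
    print(sess.run(y))  # 4.0
```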
diff --git a/tensorflow/python/ops/ctc_ops.py b/tensorflow/python/ops/ctc_ops.py
index 908e793902..32d455bdad 100644
--- a/tensorflow/python/ops/ctc_ops.py
+++ b/tensorflow/python/ops/ctc_ops.py
@@ -242,11 +242,11 @@ def ctc_beam_search_decoder(inputs, sequence_length, beam_width=100,
If `merge_repeated` is `True`, merge repeated classes in the output beams.
This means that if consecutive entries in a beam are the same,
- only the first of these is emitted. That is, when the top path
- is `A B B B B`, the return value is:
+ only the first of these is emitted. That is, when the sequence is
+ `A B B * B * B` (where '*' is the blank label), the return value is:
* `A B` if `merge_repeated = True`.
- * `A B B B B` if `merge_repeated = False`.
+ * `A B B B` if `merge_repeated = False`.
Args:
inputs: 3-D `float` `Tensor`, size
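The corrected example distinguishes repeats separated by the blank label from raw repeats. A tiny pure-Python illustration of the two collapse rules described in the docstring (only a post-processing sketch, not the beam search decoder; `"*"` stands in for the blank label):

```python
def collapse(sequence, blank="*", merge_repeated=True):
    """Collapse a label sequence the way the docstring above describes."""
    out = []
    prev = None
    for label in sequence:
        if label == blank:
            if not merge_repeated:
                # A blank ends the current run, so a later identical label
                # is emitted again.
                prev = None
            continue
        if label != prev:
            out.append(label)
        prev = label
    return out

seq = ["A", "B", "B", "*", "B", "*", "B"]
print(collapse(seq, merge_repeated=True))   # ['A', 'B']
print(collapse(seq, merge_repeated=False))  # ['A', 'B', 'B', 'B']
```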
diff --git a/tensorflow/python/ops/distributions/bijector_impl.py b/tensorflow/python/ops/distributions/bijector_impl.py
index b65e64d401..2e7aa30296 100644
--- a/tensorflow/python/ops/distributions/bijector_impl.py
+++ b/tensorflow/python/ops/distributions/bijector_impl.py
@@ -1011,12 +1011,6 @@ class Bijector(object):
def _reduce_jacobian_det_over_event(
self, y, ildj, min_event_ndims, event_ndims):
"""Reduce jacobian over event_ndims - min_event_ndims."""
-
- if not self.is_constant_jacobian:
- return math_ops.reduce_sum(
- ildj,
- self._get_event_reduce_dims(min_event_ndims, event_ndims))
-
# In this case, we need to tile the Jacobian over the event and reduce.
y_rank = array_ops.rank(y)
y_shape = array_ops.shape(y)[
diff --git a/tensorflow/python/ops/distributions/categorical.py b/tensorflow/python/ops/distributions/categorical.py
index dd25fce2ec..fbbacf2521 100644
--- a/tensorflow/python/ops/distributions/categorical.py
+++ b/tensorflow/python/ops/distributions/categorical.py
@@ -69,7 +69,7 @@ class Categorical(distribution.Distribution):
The Categorical distribution is closely related to the `OneHotCategorical` and
`Multinomial` distributions. The Categorical distribution can be intuited as
generating samples according to `argmax{ OneHotCategorical(probs) }` itself
- being identical to `argmax{ Multinomial(probs, total_count=1) }.
+ being identical to `argmax{ Multinomial(probs, total_count=1) }`.
#### Mathematical Details
@@ -83,7 +83,7 @@ class Categorical(distribution.Distribution):
The number of classes, `K`, must not exceed:
- the largest integer representable by `self.dtype`, i.e.,
- `2**(mantissa_bits+1)` (IEE754),
+ `2**(mantissa_bits+1)` (IEEE 754),
- the maximum `Tensor` index, i.e., `2**31-1`.
In other words,
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 3268b38b86..196161c661 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -260,6 +260,12 @@ def _DefaultGradYs(grad_ys,
"Gradient type %s generated for complex-valued "
"tensor %s with type %s must be real" % (dtypes.as_dtype(
grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name))
+ elif y.dtype == dtypes.variant:
+ if grad_y.dtype != dtypes.variant:
+ raise TypeError(
+ "Gradient type %s generated for variant "
+ "tensor %s with type %s must be variant" % (dtypes.as_dtype(
+ grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name))
else:
raise TypeError(
"Tensor %s with type %s must be numeric "
@@ -298,7 +304,7 @@ def _IsBackpropagatable(tensor):
if _IsTrainable(tensor):
return True
dtype = dtypes.as_dtype(tensor.dtype)
- return dtype.base_dtype in (dtypes.bfloat16, dtypes.resource, dtypes.variant)
+ return dtype.base_dtype in (dtypes.bfloat16, dtypes.variant)
def _VerifyGeneratedGradients(grads, op):
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 3759d8a543..6243be6c9e 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -45,6 +45,7 @@ from tensorflow.python.ops import data_flow_ops # pylint: disable=unused-import
from tensorflow.python.ops import functional_ops # pylint: disable=unused-import
from tensorflow.python.ops import gradients
from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_grad # pylint: disable=unused-import
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_grad # pylint: disable=unused-import
@@ -1004,5 +1005,25 @@ class AggregateIndexedSlicesGradientsTest(test_util.TensorFlowTestCase):
self._assert_indexed_slices_equal(total, result)
+class TensorListGradientsTest(test_util.TensorFlowTestCase):
+
+ def testDefaultGradYs(self):
+ with ops.Graph().as_default():
+ tl = list_ops.empty_tensor_list(
+ element_dtype=dtypes.float32,
+ element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
+ a = constant(1.0)
+ tl = list_ops.tensor_list_push_back(tl, a)
+
+ grad_tl = list_ops.empty_tensor_list(
+ element_dtype=dtypes.float32,
+ element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
+ grad_tl = list_ops.tensor_list_push_back(tl, constant(5.0))
+
+ grad = gradients.gradients(tl, a, grad_ys=grad_tl)[0]
+ with self.cached_session() as sess:
+ self.assertEquals(sess.run(grad), 5.)
+
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/ops/linalg/linear_operator_addition.py b/tensorflow/python/ops/linalg/linear_operator_addition.py
new file mode 100644
index 0000000000..86130a2c07
--- /dev/null
+++ b/tensorflow/python/ops/linalg/linear_operator_addition.py
@@ -0,0 +1,432 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Add one or more `LinearOperators` efficiently."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+
+import six
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops.linalg import linear_operator
+from tensorflow.python.ops.linalg import linear_operator_diag
+from tensorflow.python.ops.linalg import linear_operator_full_matrix
+from tensorflow.python.ops.linalg import linear_operator_identity
+from tensorflow.python.ops.linalg import linear_operator_lower_triangular
+
+__all__ = []
+
+
+def add_operators(operators,
+ operator_name=None,
+ addition_tiers=None,
+ name=None):
+ """Efficiently add one or more linear operators.
+
+ Given operators `[A1, A2,...]`, this `Op` returns a possibly shorter list of
+ operators `[B1, B2,...]` such that
+
+ ```sum_k Ak.matmul(x) = sum_k Bk.matmul(x).```
+
+ The operators `Bk` result by adding some of the `Ak`, as allowed by
+ `addition_tiers`.
+
+ Example of efficient adding of diagonal operators.
+
+ ```python
+ A1 = LinearOperatorDiag(diag=[1., 1.], name="A1")
+ A2 = LinearOperatorDiag(diag=[2., 2.], name="A2")
+
+ # Use two tiers, the first contains an Adder that returns Diag. Since both
+ # A1 and A2 are Diag, they can use this Adder. The second tier will not be
+ # used.
+ addition_tiers = [
+ [_AddAndReturnDiag()],
+ [_AddAndReturnMatrix()]]
+ B_list = add_operators([A1, A2], addition_tiers=addition_tiers)
+
+ len(B_list)
+ ==> 1
+
+ B_list[0].__class__.__name__
+ ==> 'LinearOperatorDiag'
+
+ B_list[0].to_dense()
+ ==> [[3., 0.],
+ [0., 3.]]
+
+ B_list[0].name
+ ==> 'Add/A1__A2/'
+ ```
+
+ Args:
+ operators: Iterable of `LinearOperator` objects with same `dtype`, domain
+ and range dimensions, and broadcastable batch shapes.
+ operator_name: String name for returned `LinearOperator`. Defaults to
+ concatenation of "Add/A__B/" that indicates the order of addition steps.
+ addition_tiers: List tiers, like `[tier_0, tier_1, ...]`, where `tier_i`
+ is a list of `Adder` objects. This function attempts to do all additions
+ in tier `i` before trying tier `i + 1`.
+ name: A name for this `Op`. Defaults to `add_operators`.
+
+ Returns:
+ Subclass of `LinearOperator`. Class and order of addition may change as new
+ (and better) addition strategies emerge.
+
+ Raises:
+ ValueError: If `operators` argument is empty.
+ ValueError: If shapes are incompatible.
+ """
+ # Default setting
+ if addition_tiers is None:
+ addition_tiers = _DEFAULT_ADDITION_TIERS
+
+ # Argument checking.
+ check_ops.assert_proper_iterable(operators)
+ operators = list(reversed(operators))
+ if len(operators) < 1:
+ raise ValueError(
+ "Argument 'operators' must contain at least one operator. "
+ "Found: %s" % operators)
+ if not all(
+ isinstance(op, linear_operator.LinearOperator) for op in operators):
+ raise TypeError(
+ "Argument 'operators' must contain only LinearOperator instances. "
+ "Found: %s" % operators)
+ _static_check_for_same_dimensions(operators)
+ _static_check_for_broadcastable_batch_shape(operators)
+
+ graph_parents = []
+ for operator in operators:
+ graph_parents.extend(operator.graph_parents)
+
+ with ops.name_scope(name or "add_operators", values=graph_parents):
+
+ # Additions done in one of the tiers. Try tier 0, 1,...
+ ops_to_try_at_next_tier = list(operators)
+ for tier in addition_tiers:
+ ops_to_try_at_this_tier = ops_to_try_at_next_tier
+ ops_to_try_at_next_tier = []
+ while ops_to_try_at_this_tier:
+ op1 = ops_to_try_at_this_tier.pop()
+ op2, adder = _pop_a_match_at_tier(op1, ops_to_try_at_this_tier, tier)
+ if op2 is not None:
+ # Will try to add the result of this again at this same tier.
+ new_operator = adder.add(op1, op2, operator_name)
+ ops_to_try_at_this_tier.append(new_operator)
+ else:
+ ops_to_try_at_next_tier.append(op1)
+
+ return ops_to_try_at_next_tier
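+
+# A minimal usage sketch (illustrative only, using the classes imported above)
+# of how the default tiers behave: a diag plus a lower-triangular operator
+# falls through the identity and diag tiers and is absorbed by
+# `_AddAndReturnTriL`, so a single `LinearOperatorLowerTriangular` is returned.
+#
+#   diag = linear_operator_diag.LinearOperatorDiag([1., 1.])
+#   tril = linear_operator_lower_triangular.LinearOperatorLowerTriangular(
+#       [[2., 0.], [1., 2.]])
+#   result = add_operators([diag, tril])
+#   len(result)
+#   ==> 1
+#   result[0].__class__.__name__
+#   ==> 'LinearOperatorLowerTriangular'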
+
+
+def _pop_a_match_at_tier(op1, operator_list, tier):
+ # Search from the back of list to the front in order to create nice default
+ # order of operations.
+ for i in range(1, len(operator_list) + 1):
+ op2 = operator_list[-i]
+ for adder in tier:
+ if adder.can_add(op1, op2):
+ return operator_list.pop(-i), adder
+ return None, None
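+
+# Illustration of the search order above: with user input [A1, A2, A3],
+# `add_operators` reverses the list and pops A1 as `op1`; this helper then
+# scans [A3, A2] from the back, so the first pairing attempted is (A1, A2),
+# giving the default name "Add/A1__A2/".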
+
+
+def _infer_hints_allowing_override(op1, op2, hints):
+  """Infer hints from op1 and op2; the `hints` argument is an override.
+
+ Args:
+ op1: LinearOperator
+ op2: LinearOperator
+ hints: _Hints object holding "is_X" boolean hints to use for returned
+ operator.
+ If some hint is None, try to set using op1 and op2. If the
+ hint is provided, ignore op1 and op2 hints. This allows an override
+ of previous hints, but does not allow forbidden hints (e.g. you still
+      cannot say a real diagonal operator is not self-adjoint).
+
+ Returns:
+ _Hints object.
+ """
+ hints = hints or _Hints()
+ # If A, B are self-adjoint, then so is A + B.
+ if hints.is_self_adjoint is None:
+ is_self_adjoint = op1.is_self_adjoint and op2.is_self_adjoint
+ else:
+ is_self_adjoint = hints.is_self_adjoint
+
+ # If A, B are positive definite, then so is A + B.
+ if hints.is_positive_definite is None:
+ is_positive_definite = op1.is_positive_definite and op2.is_positive_definite
+ else:
+ is_positive_definite = hints.is_positive_definite
+
+ # A positive definite operator is always non-singular.
+ if is_positive_definite and hints.is_positive_definite is None:
+ is_non_singular = True
+ else:
+ is_non_singular = hints.is_non_singular
+
+ return _Hints(
+ is_non_singular=is_non_singular,
+ is_self_adjoint=is_self_adjoint,
+ is_positive_definite=is_positive_definite)
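+
+# A small worked sketch of the inference above: if both operands carry
+# is_self_adjoint=True and is_positive_definite=True and no hints are given,
+# the result is marked self-adjoint and positive definite, and positive
+# definiteness promotes is_non_singular to True. If the caller instead passes
+# hints=_Hints(is_positive_definite=False), the hint wins over the operand
+# flags and is_non_singular is left as the hint's (unset) value.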
+
+
+def _static_check_for_same_dimensions(operators):
+ """ValueError if operators determined to have different dimensions."""
+ if len(operators) < 2:
+ return
+
+ domain_dimensions = [(op.name, op.domain_dimension.value) for op in operators
+ if op.domain_dimension.value is not None]
+ if len(set(value for name, value in domain_dimensions)) > 1:
+ raise ValueError("Operators must have the same domain dimension. Found: %s"
+ % domain_dimensions)
+
+ range_dimensions = [(op.name, op.range_dimension.value) for op in operators
+ if op.range_dimension.value is not None]
+ if len(set(value for name, value in range_dimensions)) > 1:
+ raise ValueError("Operators must have the same range dimension. Found: %s" %
+ range_dimensions)
+
+
+def _static_check_for_broadcastable_batch_shape(operators):
+ """ValueError if operators determined to have non-broadcastable shapes."""
+ if len(operators) < 2:
+ return
+
+ # This will fail if they cannot be broadcast together.
+ batch_shape = operators[0].batch_shape
+ for op in operators[1:]:
+ batch_shape = array_ops.broadcast_static_shape(batch_shape, op.batch_shape)
+
+
+class _Hints(object):
+ """Holds 'is_X' flags that every LinearOperator is initialized with."""
+
+ def __init__(self,
+ is_non_singular=None,
+ is_positive_definite=None,
+ is_self_adjoint=None):
+ self.is_non_singular = is_non_singular
+ self.is_positive_definite = is_positive_definite
+ self.is_self_adjoint = is_self_adjoint
+
+
+################################################################################
+# Classes to add two linear operators.
+################################################################################
+
+
+@six.add_metaclass(abc.ABCMeta)
+class _Adder(object):
+ """Abstract base class to add two operators.
+
+ Each `Adder` acts independently, adding everything it can, paying no attention
+  to whether another `Adder` could have done the addition more efficiently.
+ """
+
+ @property
+ def name(self):
+ return self.__class__.__name__
+
+ @abc.abstractmethod
+ def can_add(self, op1, op2):
+ """Returns `True` if this `Adder` can add `op1` and `op2`. Else `False`."""
+ pass
+
+ @abc.abstractmethod
+ def _add(self, op1, op2, operator_name, hints):
+ # Derived classes can assume op1 and op2 have been validated, e.g. they have
+ # the same dtype, and their domain/range dimensions match.
+ pass
+
+ def add(self, op1, op2, operator_name, hints=None):
+ """Return new `LinearOperator` acting like `op1 + op2`.
+
+ Args:
+ op1: `LinearOperator`
+ op2: `LinearOperator`, with `shape` and `dtype` such that adding to
+ `op1` is allowed.
+ operator_name: `String` name to give to returned `LinearOperator`
+ hints: `_Hints` object. Returned `LinearOperator` will be created with
+ these hints.
+
+ Returns:
+ `LinearOperator`
+ """
+ updated_hints = _infer_hints_allowing_override(op1, op2, hints)
+
+ if operator_name is None:
+ operator_name = "Add/" + op1.name + "__" + op2.name + "/"
+
+ values = op1.graph_parents + op2.graph_parents
+ scope_name = self.name
+ if scope_name.startswith("_"):
+ scope_name = scope_name[1:]
+ with ops.name_scope(scope_name, values=values):
+ return self._add(op1, op2, operator_name, updated_hints)
+
+
+class _AddAndReturnScaledIdentity(_Adder):
+ """Handles additions resulting in an Identity family member.
+
+ The Identity (`LinearOperatorScaledIdentity`, `LinearOperatorIdentity`) family
+ is closed under addition. This `Adder` respects that, and returns an Identity
+ """
+
+ def can_add(self, op1, op2):
+ types = {_type(op1), _type(op2)}
+ return not types.difference(_IDENTITY_FAMILY)
+
+ def _add(self, op1, op2, operator_name, hints):
+ # Will build a LinearOperatorScaledIdentity.
+
+ if _type(op1) == _SCALED_IDENTITY:
+ multiplier_1 = op1.multiplier
+ else:
+ multiplier_1 = array_ops.ones(op1.batch_shape_tensor(), dtype=op1.dtype)
+
+ if _type(op2) == _SCALED_IDENTITY:
+ multiplier_2 = op2.multiplier
+ else:
+ multiplier_2 = array_ops.ones(op2.batch_shape_tensor(), dtype=op2.dtype)
+
+ return linear_operator_identity.LinearOperatorScaledIdentity(
+ num_rows=op1.range_dimension_tensor(),
+ multiplier=multiplier_1 + multiplier_2,
+ is_non_singular=hints.is_non_singular,
+ is_self_adjoint=hints.is_self_adjoint,
+ is_positive_definite=hints.is_positive_definite,
+ name=operator_name)
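+
+# Sketch with illustrative values: adding LinearOperatorIdentity(num_rows=2)
+# and LinearOperatorScaledIdentity(num_rows=2, multiplier=2.) is handled here;
+# the plain identity contributes a multiplier of ones, so the result is a
+# LinearOperatorScaledIdentity with multiplier 1. + 2. = 3.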
+
+
+class _AddAndReturnDiag(_Adder):
+ """Handles additions resulting in a Diag operator."""
+
+ def can_add(self, op1, op2):
+ types = {_type(op1), _type(op2)}
+ return not types.difference(_DIAG_LIKE)
+
+ def _add(self, op1, op2, operator_name, hints):
+ return linear_operator_diag.LinearOperatorDiag(
+ diag=op1.diag_part() + op2.diag_part(),
+ is_non_singular=hints.is_non_singular,
+ is_self_adjoint=hints.is_self_adjoint,
+ is_positive_definite=hints.is_positive_definite,
+ name=operator_name)
+
+
+class _AddAndReturnTriL(_Adder):
+ """Handles additions resulting in a TriL operator."""
+
+ def can_add(self, op1, op2):
+ types = {_type(op1), _type(op2)}
+ return not types.difference(_DIAG_LIKE.union({_TRIL}))
+
+ def _add(self, op1, op2, operator_name, hints):
+ if _type(op1) in _EFFICIENT_ADD_TO_TENSOR:
+ op_add_to_tensor, op_other = op1, op2
+ else:
+ op_add_to_tensor, op_other = op2, op1
+
+ return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
+ tril=op_add_to_tensor.add_to_tensor(op_other.to_dense()),
+ is_non_singular=hints.is_non_singular,
+ is_self_adjoint=hints.is_self_adjoint,
+ is_positive_definite=hints.is_positive_definite,
+ name=operator_name)
+
+
+class _AddAndReturnMatrix(_Adder):
+  """Handles additions resulting in a `LinearOperatorFullMatrix`."""
+
+ def can_add(self, op1, op2): # pylint: disable=unused-argument
+ return isinstance(op1, linear_operator.LinearOperator) and isinstance(
+ op2, linear_operator.LinearOperator)
+
+ def _add(self, op1, op2, operator_name, hints):
+ if _type(op1) in _EFFICIENT_ADD_TO_TENSOR:
+ op_add_to_tensor, op_other = op1, op2
+ else:
+ op_add_to_tensor, op_other = op2, op1
+ return linear_operator_full_matrix.LinearOperatorFullMatrix(
+ matrix=op_add_to_tensor.add_to_tensor(op_other.to_dense()),
+ is_non_singular=hints.is_non_singular,
+ is_self_adjoint=hints.is_self_adjoint,
+ is_positive_definite=hints.is_positive_definite,
+ name=operator_name)
+
+
+################################################################################
+# Constants designating types of LinearOperators
+################################################################################
+
+# Type name constants for LinearOperator classes.
+_IDENTITY = "identity"
+_SCALED_IDENTITY = "scaled_identity"
+_DIAG = "diag"
+_TRIL = "tril"
+_MATRIX = "matrix"
+
+# Groups of operators.
+_DIAG_LIKE = {_DIAG, _IDENTITY, _SCALED_IDENTITY}
+_IDENTITY_FAMILY = {_IDENTITY, _SCALED_IDENTITY}
+# operators with an efficient .add_to_tensor() method.
+_EFFICIENT_ADD_TO_TENSOR = _DIAG_LIKE
+
+
+def _type(operator):
+ """Returns the type name constant (e.g. _TRIL) for operator."""
+ if isinstance(operator, linear_operator_diag.LinearOperatorDiag):
+ return _DIAG
+ if isinstance(operator,
+ linear_operator_lower_triangular.LinearOperatorLowerTriangular):
+ return _TRIL
+ if isinstance(operator, linear_operator_full_matrix.LinearOperatorFullMatrix):
+ return _MATRIX
+ if isinstance(operator, linear_operator_identity.LinearOperatorIdentity):
+ return _IDENTITY
+ if isinstance(operator,
+ linear_operator_identity.LinearOperatorScaledIdentity):
+ return _SCALED_IDENTITY
+ raise TypeError("Operator type unknown: %s" % operator)
+
+
+################################################################################
+# Addition tiers:
+# We attempt to use Adders in tier K before K+1.
+#
+# Organize tiers to
+# (i) reduce O(..) complexity of forming final operator, and
+# (ii) produce the "most efficient" final operator.
+# Dev notes:
+# * Results of addition at tier K will be added at tier K or higher.
+# * Tiers may change; the docstring warns users that results may change.
+################################################################################
+
+# Note that the final tier, _AddAndReturnMatrix, will convert everything to a
+# dense matrix. So it is sometimes very inefficient.
+_DEFAULT_ADDITION_TIERS = [
+ [_AddAndReturnScaledIdentity()],
+ [_AddAndReturnDiag()],
+ [_AddAndReturnTriL()],
+ [_AddAndReturnMatrix()],
+]
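+
+# Default-tier behavior sketch: operators that never share a structured tier
+# fall through to _AddAndReturnMatrix. For example, a diag operator plus a
+# LinearOperatorFullMatrix only matches in this final tier, so the sum comes
+# back as a dense LinearOperatorFullMatrix (the diag side supplies the cheap
+# add_to_tensor on the other operand's to_dense()).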
diff --git a/tensorflow/python/ops/linalg/linear_operator_circulant.py b/tensorflow/python/ops/linalg/linear_operator_circulant.py
index c367ed25ad..021ef47383 100644
--- a/tensorflow/python/ops/linalg/linear_operator_circulant.py
+++ b/tensorflow/python/ops/linalg/linear_operator_circulant.py
@@ -160,20 +160,20 @@ class _BaseLinearOperatorCirculant(linear_operator.LinearOperator):
`block_depth = 1` means `A` is symmetric circulant. For example,
```
- A = |x y z y|
- |y x y z|
- |z y x y|
- |y z y x|
+ A = |w z y x|
+ |x w z y|
+ |y x w z|
+ |z y x w|
```
   `block_depth = 2` means `A` is block symmetric circulant with symmetric
- circulant blocks. For example, with `X`, `Y`, `Z` symmetric circulant,
+ circulant blocks. For example, with `W`, `X`, `Y`, `Z` symmetric circulant,
```
- A = |X Y Z Y|
- |Y X Y Z|
- |Z Y X Y|
- |Y Z Y X|
+ A = |W Z Y X|
+ |X W Z Y|
+ |Y X W Z|
+ |Z Y X W|
```
`block_depth = 3` means `A` is block symmetric circulant with block
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 33e7a5533b..7c59232e40 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -1088,9 +1088,6 @@ def floordiv(x, y, name=None):
`x // y` floor division in Python 3 and in Python 2.7 with
`from __future__ import division`.
- Note that for efficiency, `floordiv` uses C semantics for negative numbers
- (unlike Python and Numpy).
-
`x` and `y` must have the same type, and the result will have the same type
as well.
@@ -1100,7 +1097,7 @@ def floordiv(x, y, name=None):
name: A name for the operation (optional).
Returns:
- `x / y` rounded down (except possibly towards zero for negative integers).
+ `x / y` rounded down.
Raises:
TypeError: If the inputs are complex.
@@ -2906,22 +2903,24 @@ def tensordot(a, b, axes, name=None):
free_dims_static = None
shape_a = array_ops.shape(a)
rank_a = array_ops.rank(a)
- axes = ops.convert_to_tensor(axes, dtype=dtypes.int32, name="axes")
- axes = cast(axes >= 0, dtypes.int32) * axes + cast(
- axes < 0, dtypes.int32) * (
- axes + rank_a)
- free, _ = array_ops.setdiff1d(range(rank_a), axes)
- free_dims = array_ops.gather(shape_a, free)
- axes_dims = array_ops.gather(shape_a, axes)
- prod_free_dims = reduce_prod(free_dims)
- prod_axes_dims = reduce_prod(axes_dims)
- perm = array_ops.concat([axes_dims, free_dims], 0)
- if flipped:
- perm = array_ops.concat([axes, free], 0)
- new_shape = array_ops.stack([prod_axes_dims, prod_free_dims])
- else:
- perm = array_ops.concat([free, axes], 0)
- new_shape = array_ops.stack([prod_free_dims, prod_axes_dims])
+ # TODO(b/115583659): Automate this.
+ with ops.device("/cpu:0"):
+ axes = ops.convert_to_tensor(axes, dtype=dtypes.int32, name="axes")
+ axes = cast(axes >= 0, dtypes.int32) * axes + cast(
+ axes < 0, dtypes.int32) * (
+ axes + rank_a)
+ free, _ = array_ops.setdiff1d(range(rank_a), axes)
+ free_dims = array_ops.gather(shape_a, free)
+ axes_dims = array_ops.gather(shape_a, axes)
+ prod_free_dims = reduce_prod(free_dims)
+ prod_axes_dims = reduce_prod(axes_dims)
+ perm = array_ops.concat([axes_dims, free_dims], 0)
+ if flipped:
+ perm = array_ops.concat([axes, free], 0)
+ new_shape = array_ops.stack([prod_axes_dims, prod_free_dims])
+ else:
+ perm = array_ops.concat([free, axes], 0)
+ new_shape = array_ops.stack([prod_free_dims, prod_axes_dims])
reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape)
return reshaped_a, free_dims, free_dims_static
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index ef9afd9e8e..2526e6fee2 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -510,7 +510,7 @@ class _WithSpaceToBatch(object):
# Recover channel information for output shape if channels are not last.
if self.data_format is not None and self.data_format.startswith("NC"):
- if not result_converted.shape[1].value:
+ if not result_converted.shape[1].value and filter is not None:
output_shape = result_converted.shape.as_list()
output_shape[1] = filter.shape[-1]
result_converted.set_shape(output_shape)
diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
index c0e66cb0b8..d403b0c61a 100644
--- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
+++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
@@ -1259,7 +1259,7 @@ class SparseTest(PForTest):
[3]) # [0, 2, 0]
pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(pfor, feed_dict={num_iters: 3})
def test_sparse_result_none_stacked(self):
diff --git a/tensorflow/python/ops/parallel_for/gradients_test.py b/tensorflow/python/ops/parallel_for/gradients_test.py
index f9cf16f6a4..628c6764cd 100644
--- a/tensorflow/python/ops/parallel_for/gradients_test.py
+++ b/tensorflow/python/ops/parallel_for/gradients_test.py
@@ -356,7 +356,7 @@ class GradientsTest(test.TestCase):
self.run_and_assert_equal(answer, jacobian_while)
def test_jacobian_unknown_shape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.placeholder(dtypes.float32, shape=[None, None])
y = math_ops.matmul(x, x, transpose_a=True)
jacobian_pfor = gradients.jacobian(y, x, use_pfor=True)
@@ -381,7 +381,7 @@ class GradientsTest(test.TestCase):
gradients.batch_jacobian(y, x, use_pfor=True)
def test_batch_jacobian_bad_unknown_shapes(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.placeholder(dtypes.float32)
y = array_ops.concat([x, x], axis=0)
jacobian = gradients.batch_jacobian(y, x)
@@ -402,7 +402,7 @@ class GradientsTest(test.TestCase):
self.run_and_assert_equal(answer, batch_jacobian_while)
def test_batch_jacobian_unknown_shape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.placeholder(dtypes.float32)
y = x * x
batch_jacobian_pfor = gradients.batch_jacobian(y, x, use_pfor=True)
diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py
index 8224097ac4..bb8da3162a 100644
--- a/tensorflow/python/ops/parsing_ops.py
+++ b/tensorflow/python/ops/parsing_ops.py
@@ -1584,7 +1584,8 @@ def decode_csv(records,
record_defaults: A list of `Tensor` objects with specific types.
Acceptable types are `float32`, `float64`, `int32`, `int64`, `string`.
One tensor per column of the input record, with either a
- scalar default value for that column or empty if the column is required.
+ scalar default value for that column or an empty vector if the column is
+ required.
field_delim: An optional `string`. Defaults to `","`.
char delimiter to separate fields in a record.
use_quote_delim: An optional `bool`. Defaults to `True`.
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index 5c00d929bf..5a3a5cc225 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -709,6 +709,10 @@ def _dynamic_rnn_loop(cell,
Raises:
ValueError: If the input depth cannot be inferred via shape inference
from the inputs.
+ ValueError: If time_step is not the same for all the elements in the
+ inputs.
+ ValueError: If batch_size is not the same for all the elements in the
+ inputs.
"""
state = initial_state
assert isinstance(parallel_iterations, int), "parallel_iterations must be int"
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index c11c9ccaae..3e19183ff5 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -954,7 +954,7 @@ class LSTMCell(LayerRNNCell):
"""Run one step of LSTM.
Args:
- inputs: input Tensor, 2D, `[batch, num_units].
+ inputs: input Tensor, must be 2-D, `[batch, input_size]`.
state: if `state_is_tuple` is False, this must be a state Tensor,
`2-D, [batch, state_size]`. If `state_is_tuple` is True, this must be a
tuple of state Tensors, both `2-D`, with column sizes `c_state` and
diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py
index 29fefbe3a5..b2c6937368 100644
--- a/tensorflow/python/ops/string_ops.py
+++ b/tensorflow/python/ops/string_ops.py
@@ -90,11 +90,6 @@ def regex_replace(source, pattern, rewrite, replace_global=True):
Returns:
string `Tensor` of the same shape as `source` with specified replacements.
"""
- # TODO(b/112455102): Remove compat.forward_compatible once past the horizon.
- if not compat.forward_compatible(2018, 10, 10):
- return gen_string_ops.regex_replace(
- input=source, pattern=pattern,
- rewrite=rewrite, replace_global=replace_global)
if (isinstance(pattern, util_compat.bytes_or_text_types) and
isinstance(rewrite, util_compat.bytes_or_text_types)):
# When `pattern` and `rewrite` are static through the life of the op we can
diff --git a/tensorflow/python/ops/summary_ops_v2.py b/tensorflow/python/ops/summary_ops_v2.py
index 94c7d88b5c..a404507627 100644
--- a/tensorflow/python/ops/summary_ops_v2.py
+++ b/tensorflow/python/ops/summary_ops_v2.py
@@ -234,6 +234,7 @@ def create_file_writer(logdir,
"""
if logdir is None:
return SummaryWriter(None, None)
+ logdir = str(logdir)
with ops.device("cpu:0"):
if max_queue is None:
max_queue = constant_op.constant(10)
diff --git a/tensorflow/python/platform/gfile.py b/tensorflow/python/platform/gfile.py
index 45de047894..5927bc2409 100644
--- a/tensorflow/python/platform/gfile.py
+++ b/tensorflow/python/platform/gfile.py
@@ -33,6 +33,7 @@ from tensorflow.python.lib.io.file_io import rename as Rename
from tensorflow.python.lib.io.file_io import stat as Stat
from tensorflow.python.lib.io.file_io import walk as Walk
# pylint: enable=unused-import
+from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
@@ -62,6 +63,7 @@ class FastGFile(_FileIO):
invocations in network filesystems).
"""
+ @deprecated(None, 'Use tf.gfile.GFile.')
def __init__(self, name, mode='r'):
super(FastGFile, self).__init__(name=name, mode=mode)
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index be8f425481..c411a58b70 100755
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -188,7 +188,10 @@ limitations under the License.
"outputs of the operation)");
}
$1 = &temp;
- $1->resize(PyInt_AsLong($input), nullptr);
+ long sz = PyInt_AsLong($input);
+ if (sz > 0) {
+    $1->resize(sz, nullptr);
+ }
}
// Create new Status object.
diff --git a/tensorflow/python/saved_model/README.md b/tensorflow/python/saved_model/README.md
index 5eeaf73a43..fe69f3beb0 100644
--- a/tensorflow/python/saved_model/README.md
+++ b/tensorflow/python/saved_model/README.md
@@ -91,10 +91,17 @@ with an asset of the same name, only the first version is retained.
#### Tags
Each meta graph added to the SavedModel must be annotated with user specified
-tags. The tags provide a means to identify the specific meta graph to load and
-restore, along with the shared set of variables and assets. These tags
-typically annotate a MetaGraph with its functionality (e.g. serving or
-training), and possibly hardware specific aspects such as GPU.
+tags, which reflect the meta graph's capabilities or use-cases.
+More specifically, these tags typically annotate a meta graph with its
+functionality (e.g. serving or training), and possibly hardware-specific
+aspects such as GPU.
+In the SavedModel, the meta graph def whose tag-set exactly matches the tags
+specified in the loader API will be the one loaded by the loader.
+If no meta graph def is found matching the specified tags, an error is
+returned. For example, a loader that needs to serve on GPU hardware can load
+only the meta graph annotated with tags='serve,gpu', by specifying this set of
+tags in tensorflow::LoadSavedModel(...).
+
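+For instance, loading the meta graph tagged for GPU serving from Python might
+look like the following sketch (the export directory here is a placeholder):
+
+```python
+import tensorflow as tf
+
+export_dir = "/tmp/saved_model"  # placeholder path to a SavedModel
+with tf.Session(graph=tf.Graph()) as sess:
+  # Loads the meta graph def whose tag-set is exactly {"serve", "gpu"}.
+  tf.saved_model.loader.load(sess, ["serve", "gpu"], export_dir)
+```
+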
#### Usage
The typical usage of `builder` is as follows:
diff --git a/tensorflow/python/summary/writer/event_file_writer.py b/tensorflow/python/summary/writer/event_file_writer.py
index 2936a279bd..14dec982a6 100644
--- a/tensorflow/python/summary/writer/event_file_writer.py
+++ b/tensorflow/python/summary/writer/event_file_writer.py
@@ -62,7 +62,7 @@ class EventFileWriter(object):
filename_suffix: A string. Every event file's name is suffixed with
`filename_suffix`.
"""
- self._logdir = logdir
+ self._logdir = str(logdir)
if not gfile.IsDirectory(self._logdir):
gfile.MakeDirs(self._logdir)
self._event_queue = six.moves.queue.Queue(max_queue)
diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py
index c5289564fe..d8ba13d8d2 100644
--- a/tensorflow/python/tools/saved_model_cli.py
+++ b/tensorflow/python/tools/saved_model_cli.py
@@ -33,7 +33,6 @@ import numpy as np
from six import integer_types
from tensorflow.contrib.saved_model.python.saved_model import reader
-from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
from tensorflow.core.example import example_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.python.client import session
@@ -97,8 +96,7 @@ def _get_inputs_tensor_info_from_meta_graph_def(meta_graph_def,
Returns:
A dictionary that maps input tensor keys to TensorInfos.
"""
- return signature_def_utils.get_signature_def_by_key(meta_graph_def,
- signature_def_key).inputs
+ return meta_graph_def.signature_def[signature_def_key].inputs
def _get_outputs_tensor_info_from_meta_graph_def(meta_graph_def,
@@ -116,8 +114,7 @@ def _get_outputs_tensor_info_from_meta_graph_def(meta_graph_def,
Returns:
A dictionary that maps output tensor keys to TensorInfos.
"""
- return signature_def_utils.get_signature_def_by_key(meta_graph_def,
- signature_def_key).outputs
+ return meta_graph_def.signature_def[signature_def_key].outputs
def _show_inputs_outputs(saved_model_dir, tag_set, signature_def_key, indent=0):
diff --git a/tensorflow/python/training/adadelta_test.py b/tensorflow/python/training/adadelta_test.py
index 2678016d24..a14ac895ac 100644
--- a/tensorflow/python/training/adadelta_test.py
+++ b/tensorflow/python/training/adadelta_test.py
@@ -155,7 +155,7 @@ class AdadeltaOptimizerTest(test.TestCase):
rtol=1e-5)
def testBasic(self):
- with self.test_session():
+ with self.cached_session():
self.doTestBasic(use_resource=False)
@test_util.run_in_graph_and_eager_modes(reset_test=True)
@@ -168,7 +168,7 @@ class AdadeltaOptimizerTest(test.TestCase):
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
diff --git a/tensorflow/python/training/adagrad_da_test.py b/tensorflow/python/training/adagrad_da_test.py
index c3a242a75e..00801be3b4 100644
--- a/tensorflow/python/training/adagrad_da_test.py
+++ b/tensorflow/python/training/adagrad_da_test.py
@@ -34,7 +34,7 @@ class AdagradDAOptimizerTest(test.TestCase):
def doTestAdagradDAwithoutRegularizationBasic1(self, use_resource=False):
for dtype in [dtypes.float64, dtypes.float32]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
global_step = variables.Variable(0, dtype=dtypes.int64)
if use_resource:
var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
@@ -81,7 +81,7 @@ class AdagradDAOptimizerTest(test.TestCase):
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
global_step = resource_variable_ops.ResourceVariable(
0, dtype=dtypes.int64)
@@ -101,7 +101,7 @@ class AdagradDAOptimizerTest(test.TestCase):
def testAdagradDAwithoutRegularizationBasic2(self):
for dtype in [dtypes.float64, dtypes.float32]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
global_step = variables.Variable(0, dtype=dtypes.int64)
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([4.0, 3.0], dtype=dtype)
@@ -133,7 +133,7 @@ class AdagradDAOptimizerTest(test.TestCase):
def testAdagradDAWithL1(self):
for dtype in [dtypes.float64, dtypes.float32]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
global_step = variables.Variable(0, dtype=dtypes.int64)
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([4.0, 3.0], dtype=dtype)
@@ -165,7 +165,7 @@ class AdagradDAOptimizerTest(test.TestCase):
def testAdagradDAWithL1_L2(self):
for dtype in [dtypes.float64, dtypes.float32]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
global_step = variables.Variable(0, dtype=dtypes.int64)
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([4.0, 3.0], dtype=dtype)
diff --git a/tensorflow/python/training/adagrad_test.py b/tensorflow/python/training/adagrad_test.py
index 4e634fff84..7caf01f64d 100644
--- a/tensorflow/python/training/adagrad_test.py
+++ b/tensorflow/python/training/adagrad_test.py
@@ -98,7 +98,7 @@ class AdagradOptimizerTest(test.TestCase):
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = resource_variable_ops.ResourceVariable(
[[1.0, 2.0], [3.0, 4.0]], dtype=dtype)
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
@@ -117,7 +117,7 @@ class AdagradOptimizerTest(test.TestCase):
def testTensorLearningRate(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -141,7 +141,7 @@ class AdagradOptimizerTest(test.TestCase):
def testSparseBasic(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
var1 = variables.Variable([[3.0], [4.0]], dtype=dtype)
grads0 = ops.IndexedSlices(
@@ -172,7 +172,7 @@ class AdagradOptimizerTest(test.TestCase):
def testSparseRepeatedIndices(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
repeated_index_update_var = variables.Variable(
[[1.0], [2.0]], dtype=dtype)
aggregated_update_var = variables.Variable(
@@ -202,7 +202,7 @@ class AdagradOptimizerTest(test.TestCase):
def testSparseRepeatedIndicesResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var_repeated = resource_variable_ops.ResourceVariable(
[1.0, 2.0], dtype=dtype)
loss_repeated = math_ops.reduce_sum(
@@ -226,7 +226,7 @@ class AdagradOptimizerTest(test.TestCase):
def testSparseStability(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
shape = [1, 6]
var0 = variables.Variable(
[[
@@ -262,7 +262,7 @@ class AdagradOptimizerTest(test.TestCase):
def testSharing(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -295,7 +295,7 @@ class AdagradOptimizerTest(test.TestCase):
np.array([2.715679168701172, 3.715679168701172]), var1.eval())
def testDynamicShapeVariable_Ok(self):
- with self.test_session():
+ with self.cached_session():
v = variable_scope.get_variable("v", initializer=constant_op.constant(1.),
validate_shape=False)
self.assertFalse(v.shape.is_fully_defined())
diff --git a/tensorflow/python/training/adam_test.py b/tensorflow/python/training/adam_test.py
index 778c672077..48db6e3733 100644
--- a/tensorflow/python/training/adam_test.py
+++ b/tensorflow/python/training/adam_test.py
@@ -56,7 +56,7 @@ class AdamOptimizerTest(test.TestCase):
def doTestSparse(self, use_resource=False):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -122,7 +122,7 @@ class AdamOptimizerTest(test.TestCase):
def testSparseRepeatedIndices(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
repeated_index_update_var = variables.Variable(
[[1.0], [2.0]], dtype=dtype)
aggregated_update_var = variables.Variable(
@@ -224,7 +224,7 @@ class AdamOptimizerTest(test.TestCase):
opt.get_slot(var=var0, name="m").name)
def testBasic(self):
- with self.test_session():
+ with self.cached_session():
self.doTestBasic(use_resource=False)
@test_util.run_in_graph_and_eager_modes(reset_test=True)
@@ -237,7 +237,7 @@ class AdamOptimizerTest(test.TestCase):
def testTensorLearningRate(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -274,7 +274,7 @@ class AdamOptimizerTest(test.TestCase):
def testSharing(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
diff --git a/tensorflow/python/training/basic_session_run_hooks_test.py b/tensorflow/python/training/basic_session_run_hooks_test.py
index fe8a3e9062..2d469634e0 100644
--- a/tensorflow/python/training/basic_session_run_hooks_test.py
+++ b/tensorflow/python/training/basic_session_run_hooks_test.py
@@ -1145,7 +1145,7 @@ class SummarySaverHookTest(test.TestCase):
summary_writer=self.summary_writer,
summary_op=self.summary_op)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
hook.begin()
sess.run(variables_lib.global_variables_initializer())
mon_sess = monitored_session._HookedSession(sess, [hook])
@@ -1177,7 +1177,7 @@ class SummarySaverHookTest(test.TestCase):
summary_writer=self.summary_writer,
summary_op=[self.summary_op, self.summary_op2])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
hook.begin()
sess.run(variables_lib.global_variables_initializer())
mon_sess = monitored_session._HookedSession(sess, [hook])
@@ -1205,7 +1205,7 @@ class SummarySaverHookTest(test.TestCase):
summary_writer=self.summary_writer,
summary_op=self.summary_op)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
hook.begin()
sess.run(variables_lib.global_variables_initializer())
mon_sess = monitored_session._HookedSession(sess, [hook])
@@ -1240,7 +1240,7 @@ class SummarySaverHookTest(test.TestCase):
summary_writer=self.summary_writer,
summary_op=self.summary_op)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
hook.begin()
sess.run(variables_lib.global_variables_initializer())
mon_sess = monitored_session._HookedSession(sess, [hook])
@@ -1388,7 +1388,7 @@ class ResourceSummarySaverHookTest(test.TestCase):
summary_writer=self.summary_writer,
summary_op=self.summary_op)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
hook.begin()
sess.run(variables_lib.global_variables_initializer())
mon_sess = monitored_session._HookedSession(sess, [hook])
diff --git a/tensorflow/python/training/checkpoint_management_test.py b/tensorflow/python/training/checkpoint_management_test.py
index 8ef5048299..3a061bcb35 100644
--- a/tensorflow/python/training/checkpoint_management_test.py
+++ b/tensorflow/python/training/checkpoint_management_test.py
@@ -73,7 +73,7 @@ class LatestCheckpointWithRelativePaths(test.TestCase):
# Collides with the default name of the checkpoint state file.
filepath = os.path.join(traindir, "checkpoint")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
unused_a = variables.Variable(0.0) # So that Saver saves something.
variables.global_variables_initializer().run()
@@ -113,7 +113,7 @@ class LatestCheckpointWithRelativePaths(test.TestCase):
filename = "snapshot"
filepath = os.path.join(traindir, filename)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Build a simple graph.
v0 = variables.Variable(0.0)
inc = v0.assign_add(1.0)
@@ -128,7 +128,7 @@ class LatestCheckpointWithRelativePaths(test.TestCase):
inc.eval()
save.save(sess, filepath, global_step=2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Build a new graph with different initialization.
v0 = variables.Variable(-1.0)
diff --git a/tensorflow/python/training/checkpoint_ops_test.py b/tensorflow/python/training/checkpoint_ops_test.py
index 00611de862..dde8431497 100644
--- a/tensorflow/python/training/checkpoint_ops_test.py
+++ b/tensorflow/python/training/checkpoint_ops_test.py
@@ -43,7 +43,7 @@ class LoadAndRemapWrappersTest(test.TestCase):
# 0., 1., ..., 79. reshaped into [5, 16].
initializer = init_ops.constant_initializer(
np.reshape(np.linspace(0.0, 79, 5 * 16), (5, 16)))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope('some_scope'):
variable_scope.get_variable(name='embeddings', shape=[5, 16],
initializer=initializer)
@@ -114,7 +114,7 @@ class LoadAndRemapWrappersTest(test.TestCase):
],
axis=1)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(expected_remapped_matrix, remapped_matrix.eval())
def test_load_and_remap_output_layer_weight_initializer_linear(self):
@@ -150,7 +150,7 @@ class LoadAndRemapWrappersTest(test.TestCase):
initializer=loading_initializer,
partitioner=partitioned_variables.fixed_size_partitioner(2))
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
self.assertAllClose(expected_remapped_matrix,
remapped_matrix.as_tensor().eval())
@@ -184,7 +184,7 @@ class LoadAndRemapWrappersTest(test.TestCase):
initializer=loading_initializer,
partitioner=partitioned_variables.fixed_size_partitioner(2))
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
self.assertAllClose(expected_remapped_matrix,
remapped_matrix.as_tensor().eval())
@@ -222,7 +222,7 @@ class LoadAndRemapWrappersTest(test.TestCase):
initializer=loading_initializer,
partitioner=partitioned_variables.fixed_size_partitioner(2))
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
self.assertAllClose(expected_remapped_matrix,
remapped_matrix.as_tensor().eval())
@@ -258,7 +258,7 @@ class LoadAndRemapWrappersTest(test.TestCase):
initializer=loading_initializer,
partitioner=partitioned_variables.fixed_size_partitioner(2))
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
self.assertAllClose(expected_remapped_matrix,
remapped_matrix.as_tensor().eval())
@@ -292,7 +292,7 @@ class LoadAndRemapWrappersTest(test.TestCase):
initializer=embedding_loading_initializer,
partitioner=partitioned_variables.fixed_size_partitioner(2))
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
self.assertAllClose(expected_remapped_embeddings,
remapped_embeddings.as_tensor().eval())
@@ -338,7 +338,7 @@ class LoadAndRemapWrappersTest(test.TestCase):
initializer=embedding_loading_initializer,
partitioner=partitioned_variables.fixed_size_partitioner(2))
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
self.assertAllClose(expected_remapped_embeddings,
remapped_embeddings.as_tensor().eval())
@@ -376,7 +376,7 @@ class LoadAndRemapWrappersTest(test.TestCase):
initializer=embedding_loading_initializer,
partitioner=partitioned_variables.fixed_size_partitioner(2))
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
self.assertAllClose(expected_remapped_embeddings,
remapped_embeddings.as_tensor().eval())
diff --git a/tensorflow/python/training/checkpoint_utils_test.py b/tensorflow/python/training/checkpoint_utils_test.py
index 1aab16338a..61dcbdb2b8 100644
--- a/tensorflow/python/training/checkpoint_utils_test.py
+++ b/tensorflow/python/training/checkpoint_utils_test.py
@@ -84,7 +84,7 @@ class CheckpointsTest(test.TestCase):
def testNoTensor(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
_, _, _, _ = _create_checkpoints(session, checkpoint_dir)
with self.assertRaises(errors_impl.OpError):
self.assertAllEqual(
@@ -92,7 +92,7 @@ class CheckpointsTest(test.TestCase):
def testGetTensor(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
self.assertAllEqual(
checkpoint_utils.load_variable(checkpoint_dir, "var1"), v1)
@@ -105,7 +105,7 @@ class CheckpointsTest(test.TestCase):
def testGetAllVariables(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
_create_checkpoints(session, checkpoint_dir)
self.assertEqual(
checkpoint_utils.list_variables(checkpoint_dir),
@@ -114,7 +114,7 @@ class CheckpointsTest(test.TestCase):
def testInitFromCheckpoint(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
# New graph and session.
@@ -148,7 +148,7 @@ class CheckpointsTest(test.TestCase):
def testInitialValueComesFromCheckpoint(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1, _, _, _ = _create_checkpoints(session, checkpoint_dir)
# New graph and session.
@@ -178,7 +178,7 @@ class CheckpointsTest(test.TestCase):
def testInitWithScopeDoesNotCaptureSuffixes(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
_, _, _, v4 = _create_checkpoints(session, checkpoint_dir)
with ops.Graph().as_default() as g:
@@ -197,7 +197,7 @@ class CheckpointsTest(test.TestCase):
def testRestoreRunsOnSameDevice(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
_create_checkpoints(session, checkpoint_dir)
with ops.Graph().as_default():
@@ -213,7 +213,7 @@ class CheckpointsTest(test.TestCase):
def testInitFromRootCheckpoint(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
# New graph and session.
@@ -237,7 +237,7 @@ class CheckpointsTest(test.TestCase):
def testInitToRootCheckpoint(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
# New graph and session.
@@ -260,7 +260,7 @@ class CheckpointsTest(test.TestCase):
def testInitFromPartitionVar(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1 = _create_partition_checkpoints(session, checkpoint_dir)
# New graph and session.
@@ -322,7 +322,7 @@ class CheckpointsTest(test.TestCase):
def testInitFromCheckpointMissing(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
_, _, _, _ = _create_checkpoints(session, checkpoint_dir)
# New graph and session.
@@ -367,7 +367,7 @@ class CheckpointsTest(test.TestCase):
def testNoAdditionalReadOpsForResourceVariables(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1, _, _, _ = _create_checkpoints(session, checkpoint_dir)
# New graph and session.
diff --git a/tensorflow/python/training/checkpointable/data_structures.py b/tensorflow/python/training/checkpointable/data_structures.py
index f06cbbfa15..c29e5db075 100644
--- a/tensorflow/python/training/checkpointable/data_structures.py
+++ b/tensorflow/python/training/checkpointable/data_structures.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import copy
import six
@@ -251,6 +252,12 @@ class List(CheckpointableDataStructure, collections.Sequence):
self._storage[index] = self._track_value(
element, name=self._name_element(index))
+ def __copy__(self):
+ return type(self)(copy.copy(self._storage))
+
+ def __deepcopy__(self, memo):
+ return type(self)(copy.deepcopy(self._storage, memo))
+
def _make_storage(self, *args, **kwargs):
"""Determines the backing storage (overridden in subclasses)."""
return list(*args, **kwargs)
@@ -325,6 +332,20 @@ class _ListWrapper(List, collections.MutableSequence,
super(_ListWrapper, self).__init__(wrapped_list)
self._last_wrapped_list_snapshot = list(self._storage)
+ # pylint: disable=protected-access
+ def __copy__(self):
+ copied = super(_ListWrapper, self).__copy__()
+ copied._non_append_mutation = self._non_append_mutation
+ copied._external_modification = self._external_modification
+ return copied
+
+ def __deepcopy__(self, memo):
+ copied = super(_ListWrapper, self).__deepcopy__(memo)
+ copied._non_append_mutation = self._non_append_mutation
+ copied._external_modification = self._external_modification
+ return copied
+ # pylint: enable=protected-access
+
def _make_storage(self, wrapped_list):
"""Use the user's original list for storage."""
return wrapped_list
@@ -449,6 +470,12 @@ class Mapping(CheckpointableDataStructure, collections.Mapping):
value, name=self._name_element(key))
for key, value in self._storage.items()})
+ def __copy__(self):
+ return type(self)(copy.copy(self._storage))
+
+ def __deepcopy__(self, memo):
+ return type(self)(copy.deepcopy(self._storage, memo))
+
def _make_storage(self, *args, **kwargs):
return dict(*args, **kwargs)
@@ -525,6 +552,22 @@ class _DictWrapper(Mapping, collections.MutableMapping):
super(_DictWrapper, self).__init__(wrapped_dict)
self._update_snapshot()
+ # pylint: disable=protected-access
+ def __copy__(self):
+ copied = super(_DictWrapper, self).__copy__()
+ copied._non_append_mutation = self._non_append_mutation
+ copied._external_modification = self._external_modification
+ copied._non_string_key = self._non_string_key
+ return copied
+
+ def __deepcopy__(self, memo):
+ copied = super(_DictWrapper, self).__deepcopy__(memo)
+ copied._non_append_mutation = self._non_append_mutation
+ copied._external_modification = self._external_modification
+ copied._non_string_key = self._non_string_key
+ return copied
+ # pylint: enable=protected-access
+
def _make_storage(self, wrapped_dict):
"""Re-use the wrapped dict for storage (to force them to be in sync)."""
return wrapped_dict
diff --git a/tensorflow/python/training/checkpointable/data_structures_test.py b/tensorflow/python/training/checkpointable/data_structures_test.py
index 4638917b4c..5597c7c772 100644
--- a/tensorflow/python/training/checkpointable/data_structures_test.py
+++ b/tensorflow/python/training/checkpointable/data_structures_test.py
@@ -16,6 +16,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import copy
import os
import numpy
@@ -424,6 +425,104 @@ class MappingTests(test.TestCase):
new_dict.update(model.d)
self.assertEqual({1: 3}, new_dict)
+ def testListShallowCopy(self):
+ root = tracking.Checkpointable()
+ orig_list = [[1.]]
+ root.a = orig_list
+ copied = copy.copy(root.a)
+ self.assertAllEqual([[1.]], copied)
+ self.assertIsNot(root.a, copied)
+ self.assertIs(root.a[0], copied[0])
+
+ # Dirtiness should be inherited
+ util.list_objects(root.a)
+ orig_list.append(1.)
+ with self.assertRaises(ValueError):
+ util.list_objects(root.a)
+ with self.assertRaises(ValueError):
+ util.list_objects(copy.copy(root.a))
+
+ def testListDeepCopy(self):
+ root = tracking.Checkpointable()
+ orig_list = [[1.]]
+ root.a = orig_list
+ copied = copy.deepcopy(root.a)
+ self.assertAllEqual([[1.]], copied)
+ self.assertIsNot(root.a, copied)
+ self.assertIsNot(root.a[0], copied[0])
+
+ # Dirtiness should be inherited
+ util.list_objects(root.a)
+ orig_list.append(1.)
+ with self.assertRaises(ValueError):
+ util.list_objects(root.a)
+ with self.assertRaises(ValueError):
+ util.list_objects(copy.deepcopy(root.a))
+
+ def testDictShallowCopy(self):
+ root = tracking.Checkpointable()
+ orig_dict = {"a": [1.]}
+ root.a = orig_dict
+ copied = copy.copy(root.a)
+ self.assertAllEqual([1.], copied["a"])
+ self.assertIsNot(root.a, copied)
+ self.assertIs(root.a["a"], copied["a"])
+
+ # Dirtiness should be inherited
+ util.list_objects(root.a)
+ orig_dict["b"] = []
+ with self.assertRaises(ValueError):
+ util.list_objects(root.a)
+ with self.assertRaises(ValueError):
+ util.list_objects(copy.copy(root.a))
+
+ def testDictDeepCopy(self):
+ root = tracking.Checkpointable()
+ orig_dict = {"a": [1.]}
+ root.a = orig_dict
+ copied = copy.deepcopy(root.a)
+ self.assertAllEqual([1.], copied["a"])
+ self.assertIsNot(root.a, copied)
+ self.assertIsNot(root.a["a"], copied["a"])
+
+ # Dirtiness should be inherited
+ util.list_objects(root.a)
+ orig_dict["b"] = []
+ with self.assertRaises(ValueError):
+ util.list_objects(root.a)
+ with self.assertRaises(ValueError):
+ util.list_objects(copy.deepcopy(root.a))
+
+ def testShallowCopyCheckpointable(self):
+ original = tracking.Checkpointable()
+ original_sub = tracking.Checkpointable()
+ original.a = [[1.]]
+ original.b = {"a": original_sub}
+ shallow_copied = copy.copy(original)
+ self.assertIs(original_sub, shallow_copied.b["a"])
+ self.assertIsNot(original, shallow_copied)
+ self.assertEqual([[1.]], shallow_copied.a)
+ shallow_deps = util.list_objects(shallow_copied)
+ self.assertIn(shallow_copied.a, shallow_deps)
+ self.assertIn(shallow_copied.b, shallow_deps)
+ self.assertIn(shallow_copied.b["a"], shallow_deps)
+
+ def testDeepCopyCheckpointable(self):
+ original = tracking.Checkpointable()
+ original_sub = tracking.Checkpointable()
+ original.a = [[1.]]
+ original.b = {"a": original_sub}
+ deep_copied = copy.deepcopy(original)
+ self.assertIsNot(original, deep_copied)
+ self.assertIsNot(original_sub, deep_copied.b["a"])
+ self.assertEqual([[1.]], deep_copied.a)
+ self.assertIsInstance(deep_copied.b["a"], tracking.Checkpointable)
+ deps = util.list_objects(deep_copied)
+ self.assertIn(deep_copied.a, deps)
+ self.assertIn(deep_copied.b, deps)
+ self.assertIn(deep_copied.b["a"], deps)
+ self.assertNotIn(original_sub, deps)
+
def testConstructableFromSequence(self):
result = data_structures._DictWrapper([(1, 2), (3, 4)])
self.assertIsInstance(result, dict)
diff --git a/tensorflow/python/training/checkpointable/tracking_test.py b/tensorflow/python/training/checkpointable/tracking_test.py
index e85f812ce2..a44c570fb9 100644
--- a/tensorflow/python/training/checkpointable/tracking_test.py
+++ b/tensorflow/python/training/checkpointable/tracking_test.py
@@ -165,7 +165,7 @@ class InterfaceTests(test.TestCase):
self.assertEqual([c], a.attribute["c"].layers)
checkpoint = util.Checkpoint(a=a)
save_path = checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
- with self.test_session():
+ with self.cached_session():
checkpoint.restore(save_path).assert_consumed().initialize_or_restore()
@test_util.run_in_graph_and_eager_modes
diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py
index 0d32d21426..f8b5bd8501 100644
--- a/tensorflow/python/training/checkpointable/util_test.py
+++ b/tensorflow/python/training/checkpointable/util_test.py
@@ -384,7 +384,7 @@ class CheckpointingTests(test.TestCase):
saver = saver_lib.Saver(var_list=[v])
test_dir = self.get_temp_dir()
prefix = os.path.join(test_dir, "ckpt")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.evaluate(v.non_dep_variable.assign(42.))
save_path = saver.save(sess, prefix)
self.evaluate(v.non_dep_variable.assign(43.))
diff --git a/tensorflow/python/training/ftrl_test.py b/tensorflow/python/training/ftrl_test.py
index 76ca5b45c9..09d6fe36d3 100644
--- a/tensorflow/python/training/ftrl_test.py
+++ b/tensorflow/python/training/ftrl_test.py
@@ -37,7 +37,7 @@ class FtrlOptimizerTest(test.TestCase):
def doTestFtrlwithoutRegularization(self, use_resource=False):
for dtype in [dtypes.half, dtypes.float32]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if use_resource:
var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
@@ -76,7 +76,7 @@ class FtrlOptimizerTest(test.TestCase):
def testFtrlwithoutRegularization2(self):
for dtype in [dtypes.half, dtypes.float32]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([4.0, 3.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
@@ -105,7 +105,7 @@ class FtrlOptimizerTest(test.TestCase):
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
@@ -121,7 +121,7 @@ class FtrlOptimizerTest(test.TestCase):
def testFtrlWithL1(self):
for dtype in [dtypes.half, dtypes.float32]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([4.0, 3.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
@@ -150,7 +150,7 @@ class FtrlOptimizerTest(test.TestCase):
def testFtrlWithL1_L2(self):
for dtype in [dtypes.half, dtypes.float32]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([4.0, 3.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
@@ -186,7 +186,7 @@ class FtrlOptimizerTest(test.TestCase):
weights will tend to have smaller magnitudes with this parameter set.
"""
for dtype in [dtypes.half, dtypes.float32]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([4.0, 3.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
@@ -335,7 +335,7 @@ class FtrlOptimizerTest(test.TestCase):
# FTRL-Proximal performs same updates as Adagrad or GradientDescent.
def testEquivAdagradwithoutRegularization(self):
for dtype in [dtypes.half, dtypes.float32]:
- with self.test_session():
+ with self.cached_session():
val0, val1 = self.applyOptimizer(
ftrl.FtrlOptimizer(
3.0,
@@ -346,7 +346,7 @@ class FtrlOptimizerTest(test.TestCase):
l2_regularization_strength=0.0),
dtype)
- with self.test_session():
+ with self.cached_session():
val2, val3 = self.applyOptimizer(
adagrad.AdagradOptimizer(3.0, initial_accumulator_value=0.1), dtype)
@@ -355,7 +355,7 @@ class FtrlOptimizerTest(test.TestCase):
def testEquivSparseAdagradwithoutRegularization(self):
for dtype in [dtypes.half, dtypes.float32]:
- with self.test_session():
+ with self.cached_session():
val0, val1 = self.applyOptimizer(
ftrl.FtrlOptimizer(
3.0,
@@ -367,7 +367,7 @@ class FtrlOptimizerTest(test.TestCase):
dtype,
is_sparse=True)
- with self.test_session():
+ with self.cached_session():
val2, val3 = self.applyOptimizer(
adagrad.AdagradOptimizer(3.0, initial_accumulator_value=0.1),
dtype,
@@ -378,7 +378,7 @@ class FtrlOptimizerTest(test.TestCase):
def testEquivSparseGradientDescentwithoutRegularization(self):
for dtype in [dtypes.half, dtypes.float32]:
- with self.test_session():
+ with self.cached_session():
val0, val1 = self.applyOptimizer(
ftrl.FtrlOptimizer(
3.0,
@@ -390,7 +390,7 @@ class FtrlOptimizerTest(test.TestCase):
dtype,
is_sparse=True)
- with self.test_session():
+ with self.cached_session():
val2, val3 = self.applyOptimizer(
gradient_descent.GradientDescentOptimizer(3.0),
dtype,
@@ -401,7 +401,7 @@ class FtrlOptimizerTest(test.TestCase):
def testEquivGradientDescentwithoutRegularization(self):
for dtype in [dtypes.half, dtypes.float32]:
- with self.test_session():
+ with self.cached_session():
val0, val1 = self.applyOptimizer(
ftrl.FtrlOptimizer(
3.0,
@@ -412,7 +412,7 @@ class FtrlOptimizerTest(test.TestCase):
l2_regularization_strength=0.0),
dtype)
- with self.test_session():
+ with self.cached_session():
val2, val3 = self.applyOptimizer(
gradient_descent.GradientDescentOptimizer(3.0), dtype)
diff --git a/tensorflow/python/training/gradient_descent_test.py b/tensorflow/python/training/gradient_descent_test.py
index b304e92421..56d82a5b88 100644
--- a/tensorflow/python/training/gradient_descent_test.py
+++ b/tensorflow/python/training/gradient_descent_test.py
@@ -37,7 +37,7 @@ class GradientDescentOptimizerTest(test.TestCase):
def testBasic(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -60,7 +60,7 @@ class GradientDescentOptimizerTest(test.TestCase):
def testBasicResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -85,7 +85,7 @@ class GradientDescentOptimizerTest(test.TestCase):
def testBasicCallableParams(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -111,7 +111,7 @@ class GradientDescentOptimizerTest(test.TestCase):
def testMinimizeResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([3.0], dtype=dtype)
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
@@ -137,7 +137,7 @@ class GradientDescentOptimizerTest(test.TestCase):
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([3.0], dtype=dtype)
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
@@ -164,7 +164,7 @@ class GradientDescentOptimizerTest(test.TestCase):
def testTensorLearningRate(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -186,7 +186,7 @@ class GradientDescentOptimizerTest(test.TestCase):
def testGradWrtRef(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
opt = gradient_descent.GradientDescentOptimizer(3.0)
values = [1.0, 3.0]
vars_ = [variables.Variable([v], dtype=dtype) for v in values]
@@ -197,7 +197,7 @@ class GradientDescentOptimizerTest(test.TestCase):
def testWithGlobalStep(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
global_step = variables.Variable(0, trainable=False)
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
@@ -220,7 +220,7 @@ class GradientDescentOptimizerTest(test.TestCase):
def testSparseBasic(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
var1 = variables.Variable([[3.0], [4.0]], dtype=dtype)
grads0 = ops.IndexedSlices(
diff --git a/tensorflow/python/training/input_test.py b/tensorflow/python/training/input_test.py
index 1b1e89cb26..a9b05dcc73 100644
--- a/tensorflow/python/training/input_test.py
+++ b/tensorflow/python/training/input_test.py
@@ -51,7 +51,7 @@ class MatchFilenamesOnceTest(test_lib.TestCase):
for name in additional:
open(name, "w").write("Some contents")
filenames = list(set(filenames + additional))
- with self.test_session():
+ with self.cached_session():
star = inp.match_filenames_once(os.path.join(self.get_temp_dir(), "*"))
question = inp.match_filenames_once(
os.path.join(self.get_temp_dir(), "match_filenames.?"))
@@ -66,7 +66,7 @@ class MatchFilenamesOnceTest(test_lib.TestCase):
class LimitEpochsTest(test_lib.TestCase):
def testNoLimit(self):
- with self.test_session():
+ with self.cached_session():
seven = constant_op.constant(7)
seven_forever = inp.limit_epochs(seven)
variables.local_variables_initializer().run()
@@ -74,7 +74,7 @@ class LimitEpochsTest(test_lib.TestCase):
self.assertEqual(7, seven_forever.eval())
def testLimit(self):
- with self.test_session():
+ with self.cached_session():
love_me = constant_op.constant("Love Me")
love_me_two_times = inp.limit_epochs(love_me, num_epochs=2)
variables.global_variables_initializer().run()
@@ -88,7 +88,7 @@ class LimitEpochsTest(test_lib.TestCase):
class InputProducerTest(test_lib.TestCase):
def testNoShuffle(self):
- with self.test_session():
+ with self.cached_session():
input_tensor = [[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]]
@@ -111,7 +111,7 @@ class InputProducerTest(test_lib.TestCase):
thread.join()
def testNoShapeInference(self):
- with self.test_session():
+ with self.cached_session():
# Disable shape inference for the input.
input_value = [[1, 2, 3, 4],
[5, 6, 7, 8],
@@ -144,7 +144,7 @@ class InputProducerTest(test_lib.TestCase):
class StringInputProducerTest(test_lib.TestCase):
def testNoShuffle(self):
- with self.test_session():
+ with self.cached_session():
strings = [b"to", b"be", b"or", b"not", b"to", b"be"]
num_epochs = 3
queue = inp.string_input_producer(
@@ -166,7 +166,7 @@ class StringInputProducerTest(test_lib.TestCase):
thread.join()
def testShuffle(self):
- with self.test_session():
+ with self.cached_session():
strings = [b"a", b"b", b"c"]
num_epochs = 600
queue = inp.string_input_producer(
@@ -206,7 +206,7 @@ class StringInputProducerTest(test_lib.TestCase):
def testNullStringPython(self):
# Graph-construction time check for empty string list:
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
_ = inp.string_input_producer([])
@@ -214,7 +214,7 @@ class StringInputProducerTest(test_lib.TestCase):
# Runtime check for empty string list. This is slightly oblique:
# The queue runner should die with an assertion error on the null
# input tensor, causing the dequeue to fail with an OutOfRangeError.
- with self.test_session():
+ with self.cached_session():
coord = coordinator.Coordinator()
queue = inp.string_input_producer(
constant_op.constant(
@@ -230,7 +230,7 @@ class StringInputProducerTest(test_lib.TestCase):
thread.join()
def testSharedName(self):
- with self.test_session():
+ with self.cached_session():
strings = [b"to", b"be", b"or", b"not", b"to", b"be"]
queue = inp.string_input_producer(
strings, shared_name="SHARED_NAME_XYZ", name="Q")
@@ -238,7 +238,7 @@ class StringInputProducerTest(test_lib.TestCase):
queue.queue_ref.op.node_def.attr["shared_name"])
def testConstructionRace(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
strings = [b"to", b"be", b"or", b"not", b"to", b"be"]
queue = inp.string_input_producer(strings, shuffle=False)
coord = coordinator.Coordinator()
@@ -260,7 +260,7 @@ class StringInputProducerTest(test_lib.TestCase):
class RangeInputProducerTest(test_lib.TestCase):
def testNoShuffle(self):
- with self.test_session():
+ with self.cached_session():
num_epochs = 3
range_size = 5
queue = inp.range_input_producer(
@@ -282,7 +282,7 @@ class RangeInputProducerTest(test_lib.TestCase):
thread.join()
def testShuffle(self):
- with self.test_session():
+ with self.cached_session():
num_epochs = 200
range_size = 2
queue = inp.range_input_producer(
@@ -321,7 +321,7 @@ class RangeInputProducerTest(test_lib.TestCase):
thread.join()
def testSharedName(self):
- with self.test_session():
+ with self.cached_session():
range_size = 5
queue = inp.range_input_producer(
range_size, shared_name="SHARED_NAME_XYZ", name="Q")
@@ -332,7 +332,7 @@ class RangeInputProducerTest(test_lib.TestCase):
class SliceInputProducerTest(test_lib.TestCase):
def testNoShuffle(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_epochs = 3
source_strings = [b"Alpha", b"Beta", b"Delta", b"Gamma"]
source_ints = [2, 3, 5, 7]
@@ -356,7 +356,7 @@ class SliceInputProducerTest(test_lib.TestCase):
thread.join()
def testShuffle(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_epochs = 1200
source_strings = ["A", "B", "D", "G"]
source_ints = [7, 3, 5, 2]
@@ -400,7 +400,7 @@ class SliceInputProducerTest(test_lib.TestCase):
thread.join()
def testSharedName(self):
- with self.test_session():
+ with self.cached_session():
source_strings = ["A", "B", "D", "G"]
source_ints = [7, 3, 5, 2]
slices = inp.slice_input_producer(
@@ -440,7 +440,7 @@ class DictHelperTest(test_lib.TestCase):
class BatchTest(test_lib.TestCase):
def _testOneThreadHelper(self, use_dict):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -500,7 +500,7 @@ class BatchTest(test_lib.TestCase):
def testUint32DataTypes(self):
values = constant_op.constant([0, 1, 2, 3, 4, 5], dtype=dtypes.uint32)
batched = inp.batch([values], batch_size=2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
sess.run(batched)
@@ -511,7 +511,7 @@ class BatchTest(test_lib.TestCase):
def testUint64DataTypes(self):
values = constant_op.constant([0, 1, 2, 3, 4, 5], dtype=dtypes.uint64)
batched = inp.batch([values], batch_size=2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
sess.run(batched)
@@ -520,7 +520,7 @@ class BatchTest(test_lib.TestCase):
thread.join()
def testOneThreadDynamicPad(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -550,7 +550,7 @@ class BatchTest(test_lib.TestCase):
thread.join()
def testOneThreadEnqueueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -585,7 +585,7 @@ class BatchTest(test_lib.TestCase):
thread.join()
def testManyThreads(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -625,7 +625,7 @@ class BatchTest(test_lib.TestCase):
thread.join()
def testOneThreadSmallerBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
num_batches = 3
extra_elements = 5
@@ -682,7 +682,7 @@ class BatchTest(test_lib.TestCase):
thread.join()
def testManyThreadsSmallerBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
num_batches = 3
extra_elements = 5
@@ -737,7 +737,7 @@ class BatchTest(test_lib.TestCase):
thread.join()
def testSharedName(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -754,7 +754,7 @@ class BatchTest(test_lib.TestCase):
batched[0].op.inputs[0].op.node_def.attr["shared_name"])
def testCannotInferRankError(self):
- with self.test_session():
+ with self.cached_session():
x = array_ops.placeholder(dtype=dtypes.int64)
with self.assertRaisesRegexp(ValueError, "Cannot infer Tensor's rank"):
inp.batch([x], batch_size=2)
@@ -797,7 +797,7 @@ class BatchTest(test_lib.TestCase):
def _testKeepInputHelper(self, num_threads, enqueue_many,
keep_input_vector=False):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 5
num_batches = 4
examples = variables.Variable(0)
@@ -934,7 +934,7 @@ class BatchTest(test_lib.TestCase):
batched = inp.maybe_batch(
[sparse_t], keep_input=keep, batch_size=1, enqueue_many=True)
- with self.test_session():
+ with self.cached_session():
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(coord=coord)
@@ -952,7 +952,7 @@ class BatchTest(test_lib.TestCase):
class BatchJoinTest(test_lib.TestCase):
def _testTwoThreadsHelper(self, use_dict):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Two threads, the first generates (0..69, "a").
num_a = 70
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -1069,7 +1069,7 @@ class BatchJoinTest(test_lib.TestCase):
batch_size=8)
def DISABLED_testTwoThreadsDynamicPad(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Two threads, the first generates (0..69, ["a"] * 1..70).
num_a = 70
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -1144,7 +1144,7 @@ class BatchJoinTest(test_lib.TestCase):
thread.join()
def DISABLED_testTwoThreadsSmallerBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
extra_elements = 2
# Two threads, the first generates (0..69, "a").
num_a = 70 + extra_elements
@@ -1243,7 +1243,7 @@ class BatchJoinTest(test_lib.TestCase):
thread.join()
def DISABLED_testTwoThreadsDynamicPadSmallerBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
extra_elements = 2
# Two threads, the first generates (0..69, ["a"] * 1..70).
num_a = 70 + extra_elements
@@ -1338,7 +1338,7 @@ class BatchJoinTest(test_lib.TestCase):
thread.join()
def testSharedName(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -1360,7 +1360,7 @@ class BatchJoinTest(test_lib.TestCase):
batched[0].op.inputs[0].op.node_def.attr["shared_name"])
def testCannotInferRankError(self):
- with self.test_session():
+ with self.cached_session():
x = array_ops.placeholder(dtype=dtypes.int64)
with self.assertRaisesRegexp(ValueError, "Cannot infer Tensor's rank"):
inp.batch_join([[x]], batch_size=2)
@@ -1371,7 +1371,7 @@ class BatchJoinTest(test_lib.TestCase):
def _testKeepInputHelper(self, num_threads, enqueue_many,
keep_input_vector=False):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 5
num_batches = 4
examples = variables.Variable(0)
@@ -1511,7 +1511,7 @@ class BatchJoinTest(test_lib.TestCase):
batched = inp.maybe_batch_join(
[[sparse]], keep_input=keep, batch_size=1, enqueue_many=True)
- with self.test_session():
+ with self.cached_session():
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(coord=coord)
@@ -1529,7 +1529,7 @@ class BatchJoinTest(test_lib.TestCase):
class ShuffleBatchTest(test_lib.TestCase):
def _testOneThreadHelper(self, use_dict):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -1594,7 +1594,7 @@ class ShuffleBatchTest(test_lib.TestCase):
self._testOneThreadHelper(use_dict=True)
def testOneThreadSmallerBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
num_batches = 3
extra_elements = 5
@@ -1650,7 +1650,7 @@ class ShuffleBatchTest(test_lib.TestCase):
thread.join()
def testManyThreads(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -1697,7 +1697,7 @@ class ShuffleBatchTest(test_lib.TestCase):
thread.join()
def testManyThreadsSmallerBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
num_batches = 3
extra_elements = 5
@@ -1755,7 +1755,7 @@ class ShuffleBatchTest(test_lib.TestCase):
thread.join()
def testSharedName(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -1775,7 +1775,7 @@ class ShuffleBatchTest(test_lib.TestCase):
def _testKeepInputHelper(self, num_threads, enqueue_many,
keep_input_vector=False):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 5
num_batches = 4
examples = variables.Variable(0)
@@ -1906,7 +1906,7 @@ class ShuffleBatchTest(test_lib.TestCase):
class ShuffleBatchJoinTest(test_lib.TestCase):
def _testTwoThreadsHelper(self, use_dict):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Two threads, the first generates (0..24, "a").
num_a = 25
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -2017,7 +2017,7 @@ class ShuffleBatchJoinTest(test_lib.TestCase):
self._testTwoThreadsHelper(use_dict=True)
def testTwoThreadsSmallerBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Two threads, the first generates (0..26, "a").
extra_elements = 2
num_a = 25 + extra_elements
@@ -2137,7 +2137,7 @@ class ShuffleBatchJoinTest(test_lib.TestCase):
seed=223607)
def testSharedName(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -2162,7 +2162,7 @@ class ShuffleBatchJoinTest(test_lib.TestCase):
def _testKeepInputHelper(self, num_threads, enqueue_many,
keep_input_vector=False):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 5
num_batches = 4
examples = variables.Variable(0)
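
The input_test.py hunks apply the same rename inside the producer/queue-runner pattern the file is built around: create a producer under the session context, start queue runners under a Coordinator, consume, then stop and join the threads. A hypothetical fragment in that shape, reusing the module aliases the patched file already imports (inp, coordinator, queue_runner_impl, variables) and illustrative strings:

  def testProducerQueueRunnerSketch(self):  # hypothetical test_lib.TestCase method
    with self.cached_session() as sess:
      queue = inp.string_input_producer(
          [b"to", b"be", b"or"], num_epochs=1, shuffle=False)
      dequeued = queue.dequeue()
      # num_epochs is tracked in a local variable, hence the local init.
      variables.local_variables_initializer().run()
      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
      for expected in [b"to", b"be", b"or"]:
        self.assertEqual(expected, sess.run(dequeued))
      coord.request_stop()
      for thread in threads:
        thread.join()
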
diff --git a/tensorflow/python/training/learning_rate_decay_test.py b/tensorflow/python/training/learning_rate_decay_test.py
index 4f3cf01822..5a9215730e 100644
--- a/tensorflow/python/training/learning_rate_decay_test.py
+++ b/tensorflow/python/training/learning_rate_decay_test.py
@@ -62,7 +62,7 @@ class LRDecayTest(test_util.TensorFlowTestCase):
self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
def testVariables(self):
- with self.test_session():
+ with self.cached_session():
step = variables.Variable(1)
assign_1 = step.assign(1)
assign_2 = step.assign(2)
diff --git a/tensorflow/python/training/momentum_test.py b/tensorflow/python/training/momentum_test.py
index f7e78071d8..8a21c39d32 100644
--- a/tensorflow/python/training/momentum_test.py
+++ b/tensorflow/python/training/momentum_test.py
@@ -123,7 +123,7 @@ class MomentumOptimizerTest(test.TestCase):
]), self.evaluate(var1))
def testBasic(self):
- with self.test_session():
+ with self.cached_session():
self.doTestBasic(use_resource=False)
@test_util.run_in_graph_and_eager_modes(reset_test=True)
@@ -162,7 +162,7 @@ class MomentumOptimizerTest(test.TestCase):
def testNesterovMomentum(self):
for dtype in [dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -188,7 +188,7 @@ class MomentumOptimizerTest(test.TestCase):
def testSparseNesterovMomentum(self):
for dtype in [dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
@@ -282,7 +282,7 @@ class MomentumOptimizerTest(test.TestCase):
def testTensorLearningRateAndMomentum(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -435,7 +435,7 @@ class MomentumOptimizerTest(test.TestCase):
return db_grad, db_out
def testLikeDistBeliefMom01(self):
- with self.test_session():
+ with self.cached_session():
db_grad, db_out = self._dbParamsMom01()
num_samples = len(db_grad)
var0 = variables.Variable([0.0] * num_samples)
@@ -449,7 +449,7 @@ class MomentumOptimizerTest(test.TestCase):
def testSparse(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable(array_ops.zeros([4, 2], dtype=dtype))
var1 = variables.Variable(constant_op.constant(1.0, dtype, [4, 2]))
grads0 = ops.IndexedSlices(
@@ -518,7 +518,7 @@ class MomentumOptimizerTest(test.TestCase):
def testSharing(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
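
momentum_test.py gets the identical treatment; the expected values it checks follow the conventional momentum accumulator rule (an assumption stated here, not shown in these hunks): accum = accum * momentum + grad, then var -= lr * accum. Worked with illustrative numbers:

# Assumed rule: accum = accum * momentum + grad; var -= lr * accum.
lr, momentum = 2.0, 0.9
var, accum, grad = 1.0, 0.0, 0.1
accum = accum * momentum + grad   # 0.1
var -= lr * accum                 # 1.0 - 2.0 * 0.1  = 0.8
accum = accum * momentum + grad   # 0.1 * 0.9 + 0.1  = 0.19
var -= lr * accum                 # 0.8 - 2.0 * 0.19 = 0.42
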
diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py
index ff586b6c03..2d7799d66a 100644
--- a/tensorflow/python/training/monitored_session_test.py
+++ b/tensorflow/python/training/monitored_session_test.py
@@ -80,7 +80,7 @@ class ScaffoldTest(test.TestCase):
self.assertTrue(isinstance(scaffold.ready_for_local_init_op, ops.Tensor))
self.assertTrue(isinstance(scaffold.local_init_op, ops.Operation))
self.assertTrue(isinstance(scaffold.saver, saver_lib.Saver))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertItemsEqual([b'my_var', b'my_local_var'],
sess.run(scaffold.ready_op))
self.assertItemsEqual([b'my_var'],
@@ -513,21 +513,21 @@ class WrappedSessionTest(test.TestCase):
"""_WrappedSession tests."""
def test_properties(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
constant_op.constant(0.0)
wrapped_sess = monitored_session._WrappedSession(sess)
self.assertEquals(sess.graph, wrapped_sess.graph)
self.assertEquals(sess.sess_str, wrapped_sess.sess_str)
def test_should_stop_on_close(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
wrapped_sess = monitored_session._WrappedSession(sess)
self.assertFalse(wrapped_sess.should_stop())
wrapped_sess.close()
self.assertTrue(wrapped_sess.should_stop())
def test_should_stop_uses_check_stop(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
wrapped_sess = StopAtNSession(sess, 3)
self.assertFalse(wrapped_sess.should_stop())
self.assertFalse(wrapped_sess.should_stop())
@@ -535,7 +535,7 @@ class WrappedSessionTest(test.TestCase):
self.assertTrue(wrapped_sess.should_stop())
def test_should_stop_delegates_to_wrapped_session(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
wrapped_sess0 = StopAtNSession(sess, 4)
wrapped_sess1 = monitored_session._WrappedSession(wrapped_sess0)
self.assertFalse(wrapped_sess1.should_stop())
@@ -545,7 +545,7 @@ class WrappedSessionTest(test.TestCase):
self.assertTrue(wrapped_sess1.should_stop())
def test_close_twice(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
wrapped_sess = monitored_session._WrappedSession(sess)
wrapped_sess.close()
self.assertTrue(wrapped_sess.should_stop())
@@ -553,7 +553,7 @@ class WrappedSessionTest(test.TestCase):
self.assertTrue(wrapped_sess.should_stop())
def test_run(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
c = constant_op.constant(0)
v = array_ops.identity(c)
self.assertEqual(42, sess.run(v, feed_dict={c: 42}))
@@ -570,7 +570,7 @@ class CoordinatedSessionTest(test.TestCase):
"""_CoordinatedSession tests."""
def test_properties(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
constant_op.constant(0.0)
coord = coordinator.Coordinator()
coord_sess = monitored_session._CoordinatedSession(sess, coord)
@@ -578,7 +578,7 @@ class CoordinatedSessionTest(test.TestCase):
self.assertEquals(sess.sess_str, coord_sess.sess_str)
def test_run(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
c = constant_op.constant(0)
v = array_ops.identity(c)
coord = coordinator.Coordinator()
@@ -586,7 +586,7 @@ class CoordinatedSessionTest(test.TestCase):
self.assertEqual(42, coord_sess.run(v, feed_dict={c: 42}))
def test_should_stop_on_close(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator()
coord_sess = monitored_session._CoordinatedSession(sess, coord)
self.assertFalse(coord_sess.should_stop())
@@ -594,7 +594,7 @@ class CoordinatedSessionTest(test.TestCase):
self.assertTrue(coord_sess.should_stop())
def test_should_stop_on_coord_stop(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator()
coord_sess = monitored_session._CoordinatedSession(sess, coord)
self.assertFalse(coord_sess.should_stop())
@@ -602,7 +602,7 @@ class CoordinatedSessionTest(test.TestCase):
self.assertTrue(coord_sess.should_stop())
def test_dont_request_stop_on_exception_in_main_thread(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
c = constant_op.constant(0)
v = array_ops.identity(c)
coord = coordinator.Coordinator()
@@ -616,7 +616,7 @@ class CoordinatedSessionTest(test.TestCase):
self.assertFalse(coord_sess.should_stop())
def test_stop_threads_on_close_after_exception(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
c = constant_op.constant(0)
v = array_ops.identity(c)
coord = coordinator.Coordinator()
@@ -646,7 +646,7 @@ class CoordinatedSessionTest(test.TestCase):
self.assertTrue(coord_sess.should_stop())
def test_stop_threads_on_close(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator()
threads = [
threading.Thread(
@@ -664,7 +664,7 @@ class CoordinatedSessionTest(test.TestCase):
def test_propagates_exception_trace(self):
assertion = control_flow_ops.Assert(False, ['This should fail.'])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator(clean_stop_exception_types=())
coord_sess = monitored_session._CoordinatedSession(sess, coord)
try:
@@ -810,7 +810,7 @@ class RecoverableSessionTest(test.TestCase):
return self._sess
def test_properties(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
constant_op.constant(0.0)
recoverable_sess = monitored_session._RecoverableSession(
self._SessionReturner(sess))
@@ -818,7 +818,7 @@ class RecoverableSessionTest(test.TestCase):
self.assertEquals(sess.sess_str, recoverable_sess.sess_str)
def test_run(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
c = constant_op.constant(0)
v = array_ops.identity(c)
recoverable_sess = monitored_session._RecoverableSession(
@@ -826,7 +826,7 @@ class RecoverableSessionTest(test.TestCase):
self.assertEqual(51, recoverable_sess.run(v, feed_dict={c: 51}))
def test_recovery(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
class StackSessionCreator(object):
@@ -872,7 +872,7 @@ class RecoverableSessionTest(test.TestCase):
recoverable_sess.run(v, feed_dict={c: -12})
def test_recovery_from_coordinator_exception(self):
- with self.test_session() as test_session:
+ with self.cached_session() as test_session:
session_creator = CountingSessionCreator(test_session)
session = monitored_session.MonitoredSession(
session_creator,
@@ -897,7 +897,7 @@ class RecoverableSessionTest(test.TestCase):
self.assertEqual(2, session_creator.number_of_sessions_created)
def test_recovery_from_non_preemption_in_coordinator(self):
- with self.test_session() as test_session:
+ with self.cached_session() as test_session:
session_creator = CountingSessionCreator(test_session)
hook = StopCoordinatorWithException(
calls_before_stopping=2,
@@ -926,7 +926,7 @@ class RecoverableSessionTest(test.TestCase):
session.close()
def test_recovery_from_session_getting_stuck(self):
- with self.test_session() as test_session:
+ with self.cached_session() as test_session:
session_creator = CountingSessionCreator(test_session)
session = monitored_session.MonitoredSession(
session_creator,
@@ -950,7 +950,7 @@ class RecoverableSessionTest(test.TestCase):
self.assertEqual(2, session_creator.number_of_sessions_created)
def test_step_fn_recovery_from_coordinator_exception_when_run_hooks(self):
- with self.test_session() as test_session:
+ with self.cached_session() as test_session:
session_creator = CountingSessionCreator(test_session)
session = monitored_session.MonitoredSession(
session_creator,
@@ -980,7 +980,7 @@ class RecoverableSessionTest(test.TestCase):
self.assertEqual(2, session_creator.number_of_sessions_created)
def test_recovery_from_non_preemption_in_coordinator_when_run_hooks(self):
- with self.test_session() as test_session:
+ with self.cached_session() as test_session:
session_creator = CountingSessionCreator(test_session)
hook = StopCoordinatorWithException(
calls_before_stopping=2,
@@ -1014,7 +1014,7 @@ class RecoverableSessionTest(test.TestCase):
session.close()
def test_recovery_from_session_getting_stuck_when_run_hooks(self):
- with self.test_session() as test_session:
+ with self.cached_session() as test_session:
session_creator = CountingSessionCreator(test_session)
session = monitored_session.MonitoredSession(
session_creator,
@@ -1058,7 +1058,7 @@ class RecoverableSessionTest(test.TestCase):
return session
def test_step_fn_recovery_from_coordinator_exception_with_raw_session(self):
- with self.test_session() as test_session:
+ with self.cached_session() as test_session:
session_creator = CountingSessionCreator(test_session)
session = self.create_raw_session_with_failing_coordinator(
session_creator,
@@ -1090,7 +1090,7 @@ class RecoverableSessionTest(test.TestCase):
self.assertEqual(2, session_creator.number_of_sessions_created)
def test_recovery_from_non_preemption_in_coordinator_with_raw_session(self):
- with self.test_session() as test_session:
+ with self.cached_session() as test_session:
session_creator = CountingSessionCreator(test_session)
session = self.create_raw_session_with_failing_coordinator(
session_creator,
@@ -1127,7 +1127,7 @@ class RecoverableSessionTest(test.TestCase):
session.close()
def test_recovery_from_session_getting_stuck_with_raw_session(self):
- with self.test_session() as test_session:
+ with self.cached_session() as test_session:
session_creator = CountingSessionCreator(test_session)
session = self.create_raw_session_with_failing_coordinator(
session_creator,
@@ -2047,7 +2047,7 @@ class MonitoredSessionTest(test.TestCase):
return value
- with self.test_session() as test_session:
+ with self.cached_session() as test_session:
with monitored_session.MonitoredSession(
CountingSessionCreator(test_session)) as session:
session.run(variables.global_variables_initializer())
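
In these monitored_session_test.py hunks the renamed helper only supplies the raw session that the wrappers under test (_WrappedSession, _CoordinatedSession, _RecoverableSession, MonitoredSession) are layered on; the MonitoredSession entry point keeps the shape just shown. A hypothetical, condensed variant that uses the stock ChiefSessionCreator from the same module in place of the test-local CountingSessionCreator:

  def testMonitoredSessionSketch(self):  # hypothetical method with the file's imports
    c = constant_op.constant(7)
    with monitored_session.MonitoredSession(
        session_creator=monitored_session.ChiefSessionCreator()) as session:
      self.assertEqual(7, session.run(c))
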
@@ -2110,7 +2110,7 @@ class MonitoredSessionTest(test.TestCase):
step_context.session.run(graph_side_effect)
return step_context.run_with_hooks(fetches=v, feed_dict={c: 1.3})
- with self.test_session() as test_session:
+ with self.cached_session() as test_session:
with monitored_session.MonitoredSession(
CountingSessionCreator(test_session),
hooks=[Hook(self)]) as session:
diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py
index fdb8d795c3..93991d0e14 100644
--- a/tensorflow/python/training/moving_averages_test.py
+++ b/tensorflow/python/training/moving_averages_test.py
@@ -35,7 +35,7 @@ from tensorflow.python.training import saver as saver_lib
class MovingAveragesTest(test.TestCase):
def testAssignMovingAverageWithoutZeroDebias(self):
- with self.test_session():
+ with self.cached_session():
var = variables.Variable([10.0, 11.0])
val = constant_op.constant([1.0, 2.0], dtypes.float32)
decay = 0.25
@@ -49,7 +49,7 @@ class MovingAveragesTest(test.TestCase):
var.eval())
def testAssignMovingAverage(self):
- with self.test_session():
+ with self.cached_session():
var = variables.Variable([0.0, 0.0])
val = constant_op.constant([1.0, 2.0], dtypes.float32)
decay = 0.25
@@ -86,7 +86,7 @@ class MovingAveragesTest(test.TestCase):
moving_averages.assign_moving_average(var, 0.0, 0.99)
def testWeightedMovingAverage(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
decay = 0.5
weight = array_ops.placeholder(dtypes.float32, [])
val = array_ops.placeholder(dtypes.float32, [])
@@ -187,53 +187,53 @@ class ExponentialMovingAverageTest(test.TestCase):
self.assertAllClose(expected, avg2.eval())
def testAverageVariablesNoNumUpdates_Scalar(self):
- with self.test_session():
+ with self.cached_session():
ema = moving_averages.ExponentialMovingAverage(0.25)
self._CheckDecay(ema, actual_decay=0.25, dim=1)
def testAverageVariablesNoNumUpdates_Scalar_Debias(self):
- with self.test_session():
+ with self.cached_session():
ema = moving_averages.ExponentialMovingAverage(0.25, zero_debias=True)
self._CheckDecay(ema, actual_decay=0.25, dim=1)
def testAverageVariablesNoNumUpdates_Vector(self):
- with self.test_session():
+ with self.cached_session():
ema = moving_averages.ExponentialMovingAverage(0.25)
self._CheckDecay(ema, actual_decay=0.25, dim=5)
def testAverageVariablesNoNumUpdates_Vector_Debias(self):
- with self.test_session():
+ with self.cached_session():
ema = moving_averages.ExponentialMovingAverage(0.25, zero_debias=True)
self._CheckDecay(ema, actual_decay=0.25, dim=5)
def testAverageVariablesNumUpdates_Scalar(self):
- with self.test_session():
+ with self.cached_session():
# With num_updates 1, the decay applied is 0.1818
ema = moving_averages.ExponentialMovingAverage(0.25, num_updates=1)
self._CheckDecay(ema, actual_decay=0.181818, dim=1)
def testAverageVariablesNumUpdates_Scalar_Debias(self):
- with self.test_session():
+ with self.cached_session():
# With num_updates 1, the decay applied is 0.1818
ema = moving_averages.ExponentialMovingAverage(
0.25, num_updates=1, zero_debias=True)
self._CheckDecay(ema, actual_decay=0.181818, dim=1)
def testAverageVariablesNumUpdates_Vector(self):
- with self.test_session():
+ with self.cached_session():
# With num_updates 1, the decay applied is 0.1818
ema = moving_averages.ExponentialMovingAverage(0.25, num_updates=1)
self._CheckDecay(ema, actual_decay=0.181818, dim=5)
def testAverageVariablesNumUpdates_Vector_Debias(self):
- with self.test_session():
+ with self.cached_session():
# With num_updates 1, the decay applied is 0.1818
ema = moving_averages.ExponentialMovingAverage(
0.25, num_updates=1, zero_debias=True)
self._CheckDecay(ema, actual_decay=0.181818, dim=5)
def testAverageVariablesWithControlDeps(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v0 = variables.Variable(0, name="v0")
add_to_v0 = v0.assign_add(1)
v1 = variables.Variable([10.0], name="v1")
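
The actual_decay=0.181818 expectations in the hunk above come from the num_updates correction that ExponentialMovingAverage applies; assuming the documented rule min(decay, (1 + num_updates) / (10 + num_updates)), the arithmetic is:

decay = 0.25
num_updates = 1
effective_decay = min(decay, (1.0 + num_updates) / (10.0 + num_updates))
# min(0.25, 2 / 11) = 0.181818..., the actual_decay the tests assert
print(effective_decay)
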
@@ -276,7 +276,7 @@ class ExponentialMovingAverageTest(test.TestCase):
self.assertAllEqual(self.evaluate(ema.average(v1)), 3.5)
def averageVariablesNamesHelper(self, zero_debias):
- with self.test_session():
+ with self.cached_session():
v0 = variables.Variable(10.0, name="v0")
v1 = variables.Variable(30.0, name="v1")
# Add a non-trainable variable.
@@ -320,7 +320,7 @@ class ExponentialMovingAverageTest(test.TestCase):
def averageVariablesNamesRespectScopeHelper(self, zero_debias):
# See discussion on #2740.
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope("scope1"):
v0 = variables.Variable(10.0, name="v0")
v1 = variables.Variable(30.0, name="v1")
@@ -367,7 +367,7 @@ class ExponentialMovingAverageTest(test.TestCase):
self.averageVariablesNamesRespectScopeHelper(zero_debias=False)
def testSubsetAverageVariablesNames(self):
- with self.test_session():
+ with self.cached_session():
v0 = variables.Variable(10.0, name="v0")
v1 = variables.Variable(30.0, name="v1")
# Add a non-trainable variable.
diff --git a/tensorflow/python/training/optimizer_test.py b/tensorflow/python/training/optimizer_test.py
index dfe9176bea..7a7d01d50e 100644
--- a/tensorflow/python/training/optimizer_test.py
+++ b/tensorflow/python/training/optimizer_test.py
@@ -64,7 +64,7 @@ class OptimizerTest(test.TestCase):
def testAggregationMethod(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
cost = 5 * var0 + 3 * var1
@@ -89,7 +89,7 @@ class OptimizerTest(test.TestCase):
def testPrecomputedGradient(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
cost = 5 * var0 + 3 * var1
@@ -231,7 +231,7 @@ class OptimizerTest(test.TestCase):
sgd_op.apply_gradients(grads_and_vars)
def testTrainOp(self):
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0])
var1 = variables.Variable([3.0, 4.0])
cost = 5 * var0 + 3 * var1
@@ -244,7 +244,7 @@ class OptimizerTest(test.TestCase):
def testConstraint(self):
constraint_01 = lambda x: clip_ops.clip_by_value(x, -0.1, 0.)
constraint_0 = lambda x: clip_ops.clip_by_value(x, 0., 1.)
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0],
constraint=constraint_01)
var1 = variables.Variable([3.0, 4.0],
diff --git a/tensorflow/python/training/proximal_adagrad_test.py b/tensorflow/python/training/proximal_adagrad_test.py
index 430c16b351..74e06a5e2e 100644
--- a/tensorflow/python/training/proximal_adagrad_test.py
+++ b/tensorflow/python/training/proximal_adagrad_test.py
@@ -35,7 +35,7 @@ from tensorflow.python.training import proximal_adagrad
class ProximalAdagradOptimizerTest(test.TestCase):
def doTestProximalAdagradwithoutRegularization(self, use_resource=False):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
var0 = variables.Variable([0.0, 0.0])
var1 = variables.Variable([0.0, 0.0])
grads0 = constant_op.constant([0.1, 0.2])
@@ -71,7 +71,7 @@ class ProximalAdagradOptimizerTest(test.TestCase):
self.doTestProximalAdagradwithoutRegularization(use_resource=True)
def testProximalAdagradwithoutRegularization2(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
var0 = variables.Variable([1.0, 2.0])
var1 = variables.Variable([4.0, 3.0])
grads0 = constant_op.constant([0.1, 0.2])
@@ -98,7 +98,7 @@ class ProximalAdagradOptimizerTest(test.TestCase):
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
@@ -114,7 +114,7 @@ class ProximalAdagradOptimizerTest(test.TestCase):
[[0, 1]], var0.eval(), atol=0.01)
def testProximalAdagradWithL1(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
var0 = variables.Variable([1.0, 2.0])
var1 = variables.Variable([4.0, 3.0])
grads0 = constant_op.constant([0.1, 0.2])
@@ -140,7 +140,7 @@ class ProximalAdagradOptimizerTest(test.TestCase):
self.assertAllClose(np.array([2.959304, 1.029232]), v1_val)
def testProximalAdagradWithL1_L2(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
var0 = variables.Variable([1.0, 2.0])
var1 = variables.Variable([4.0, 3.0])
grads0 = constant_op.constant([0.1, 0.2])
@@ -206,7 +206,7 @@ class ProximalAdagradOptimizerTest(test.TestCase):
return v0_val, v1_val
def testEquivAdagradwithoutRegularization(self):
- with self.test_session():
+ with self.cached_session():
val0, val1 = self.applyOptimizer(
proximal_adagrad.ProximalAdagradOptimizer(
3.0,
@@ -214,7 +214,7 @@ class ProximalAdagradOptimizerTest(test.TestCase):
l1_regularization_strength=0.0,
l2_regularization_strength=0.0))
- with self.test_session():
+ with self.cached_session():
val2, val3 = self.applyOptimizer(
adagrad.AdagradOptimizer(
3.0, initial_accumulator_value=0.1))
@@ -223,7 +223,7 @@ class ProximalAdagradOptimizerTest(test.TestCase):
self.assertAllClose(val1, val3)
def testEquivSparseAdagradwithoutRegularization(self):
- with self.test_session():
+ with self.cached_session():
val0, val1 = self.applyOptimizer(
proximal_adagrad.ProximalAdagradOptimizer(
3.0,
@@ -232,7 +232,7 @@ class ProximalAdagradOptimizerTest(test.TestCase):
l2_regularization_strength=0.0),
is_sparse=True)
- with self.test_session():
+ with self.cached_session():
val2, val3 = self.applyOptimizer(
adagrad.AdagradOptimizer(
3.0, initial_accumulator_value=0.1),
diff --git a/tensorflow/python/training/proximal_gradient_descent_test.py b/tensorflow/python/training/proximal_gradient_descent_test.py
index 4e4812fe60..f77f68b234 100644
--- a/tensorflow/python/training/proximal_gradient_descent_test.py
+++ b/tensorflow/python/training/proximal_gradient_descent_test.py
@@ -36,7 +36,7 @@ class ProximalGradientDescentOptimizerTest(test.TestCase):
def doTestProximalGradientDescentwithoutRegularization(
self, use_resource=False):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if use_resource:
var0 = resource_variable_ops.ResourceVariable([0.0, 0.0])
var1 = resource_variable_ops.ResourceVariable([0.0, 0.0])
@@ -69,7 +69,7 @@ class ProximalGradientDescentOptimizerTest(test.TestCase):
self.doTestProximalGradientDescentwithoutRegularization(use_resource=True)
def testProximalGradientDescentwithoutRegularization2(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
var0 = variables.Variable([1.0, 2.0])
var1 = variables.Variable([4.0, 3.0])
grads0 = constant_op.constant([0.1, 0.2])
@@ -94,7 +94,7 @@ class ProximalGradientDescentOptimizerTest(test.TestCase):
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
@@ -111,7 +111,7 @@ class ProximalGradientDescentOptimizerTest(test.TestCase):
[[-111, -138]], var0.eval(), atol=0.01)
def testProximalGradientDescentWithL1_L2(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
var0 = variables.Variable([1.0, 2.0])
var1 = variables.Variable([4.0, 3.0])
grads0 = constant_op.constant([0.1, 0.2])
@@ -174,7 +174,7 @@ class ProximalGradientDescentOptimizerTest(test.TestCase):
return v0_val, v1_val
def testEquivSparseGradientDescentwithoutRegularization(self):
- with self.test_session():
+ with self.cached_session():
val0, val1 = self.applyOptimizer(
proximal_gradient_descent.ProximalGradientDescentOptimizer(
3.0,
@@ -182,7 +182,7 @@ class ProximalGradientDescentOptimizerTest(test.TestCase):
l2_regularization_strength=0.0),
is_sparse=True)
- with self.test_session():
+ with self.cached_session():
val2, val3 = self.applyOptimizer(
gradient_descent.GradientDescentOptimizer(3.0), is_sparse=True)
@@ -190,14 +190,14 @@ class ProximalGradientDescentOptimizerTest(test.TestCase):
self.assertAllClose(val1, val3)
def testEquivGradientDescentwithoutRegularization(self):
- with self.test_session():
+ with self.cached_session():
val0, val1 = self.applyOptimizer(
proximal_gradient_descent.ProximalGradientDescentOptimizer(
3.0,
l1_regularization_strength=0.0,
l2_regularization_strength=0.0))
- with self.test_session():
+ with self.cached_session():
val2, val3 = self.applyOptimizer(
gradient_descent.GradientDescentOptimizer(3.0))
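
The equivalence tests in this file, like their FTRL counterparts earlier, open cached_session() twice within one test method, once per optimizer. The hypothetical check below illustrates the caching semantics the new helper's name advertises (an assumption about cached_session, not something these hunks assert): within a single test the same session is handed back, so state initialized in the first block is still visible in the second.

  def testCachedSessionIsReused(self):  # hypothetical illustration only
    with self.cached_session():
      counter = variables.Variable(3)
      variables.global_variables_initializer().run()
    with self.cached_session():
      # The same underlying session comes back, so counter is still initialized.
      self.assertEqual(3, counter.eval())
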
diff --git a/tensorflow/python/training/queue_runner_test.py b/tensorflow/python/training/queue_runner_test.py
index 900f9706ac..9b9e28af2b 100644
--- a/tensorflow/python/training/queue_runner_test.py
+++ b/tensorflow/python/training/queue_runner_test.py
@@ -41,7 +41,7 @@ _MockOp = collections.namedtuple("MockOp", ["name"])
class QueueRunnerTest(test.TestCase):
def testBasic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# CountUpTo will raise OUT_OF_RANGE when it reaches the count.
zero64 = constant_op.constant(0, dtype=dtypes.int64)
var = variables.Variable(zero64)
@@ -61,7 +61,7 @@ class QueueRunnerTest(test.TestCase):
self.assertEqual(3, var.eval())
def testTwoOps(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# CountUpTo will raise OUT_OF_RANGE when it reaches the count.
zero64 = constant_op.constant(0, dtype=dtypes.int64)
var0 = variables.Variable(zero64)
@@ -84,7 +84,7 @@ class QueueRunnerTest(test.TestCase):
self.assertEqual(30, var1.eval())
def testExceptionsCaptured(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
qr = queue_runner_impl.QueueRunner(queue, [_MockOp("i fail"),
_MockOp("so fail")])
@@ -100,7 +100,7 @@ class QueueRunnerTest(test.TestCase):
self.assertTrue("Operation not in the graph" in str(exceptions[1]))
def testRealDequeueEnqueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q0 = data_flow_ops.FIFOQueue(3, dtypes.float32)
enqueue0 = q0.enqueue((10.0,))
close0 = q0.close()
@@ -128,7 +128,7 @@ class QueueRunnerTest(test.TestCase):
dequeue1.eval()
def testRespectCoordShouldStop(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# CountUpTo will raise OUT_OF_RANGE when it reaches the count.
zero64 = constant_op.constant(0, dtype=dtypes.int64)
var = variables.Variable(zero64)
@@ -152,7 +152,7 @@ class QueueRunnerTest(test.TestCase):
self.assertEqual(0, var.eval())
def testRequestStopOnException(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
qr = queue_runner_impl.QueueRunner(queue, [_MockOp("not an op")])
coord = coordinator.Coordinator()
@@ -164,7 +164,7 @@ class QueueRunnerTest(test.TestCase):
coord.join()
def testGracePeriod(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The enqueue will quickly block.
queue = data_flow_ops.FIFOQueue(2, dtypes.float32)
enqueue = queue.enqueue((10.0,))
@@ -181,7 +181,7 @@ class QueueRunnerTest(test.TestCase):
coord.join(stop_grace_period_secs=1.0)
def testMultipleSessions(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with session.Session() as other_sess:
zero64 = constant_op.constant(0, dtype=dtypes.int64)
var = variables.Variable(zero64)
@@ -196,7 +196,7 @@ class QueueRunnerTest(test.TestCase):
self.assertEqual(len(threads), len(other_threads))
def testIgnoreMultiStarts(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# CountUpTo will raise OUT_OF_RANGE when it reaches the count.
zero64 = constant_op.constant(0, dtype=dtypes.int64)
var = variables.Variable(zero64)
@@ -212,7 +212,7 @@ class QueueRunnerTest(test.TestCase):
self.assertEqual([], new_threads)
def testThreads(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# CountUpTo will raise OUT_OF_RANGE when it reaches the count.
zero64 = constant_op.constant(0, dtype=dtypes.int64)
var = variables.Variable(zero64)
@@ -256,7 +256,7 @@ class QueueRunnerTest(test.TestCase):
init_op = variables.global_variables_initializer()
qr = queue_runner_impl.QueueRunner(queue, [count_up_to])
queue_runner_impl.add_queue_runner(qr)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
init_op.run()
threads = queue_runner_impl.start_queue_runners(sess)
for t in threads:
@@ -273,7 +273,7 @@ class QueueRunnerTest(test.TestCase):
init_op = variables.global_variables_initializer()
qr = queue_runner_impl.QueueRunner(queue, [count_up_to])
queue_runner_impl.add_queue_runner(qr)
- with self.test_session():
+ with self.cached_session():
init_op.run()
with self.assertRaisesRegexp(TypeError, "tf.Session"):
queue_runner_impl.start_queue_runners("NotASession")
@@ -286,7 +286,7 @@ class QueueRunnerTest(test.TestCase):
init_op = variables.global_variables_initializer()
qr = queue_runner_impl.QueueRunner(queue, [count_up_to])
queue_runner_impl.add_queue_runner(qr)
- with self.test_session():
+ with self.cached_session():
init_op.run()
threads = queue_runner_impl.start_queue_runners(
monitored_session.MonitoredSession())
diff --git a/tensorflow/python/training/rmsprop_test.py b/tensorflow/python/training/rmsprop_test.py
index 6043327384..4f5f96e2b4 100644
--- a/tensorflow/python/training/rmsprop_test.py
+++ b/tensorflow/python/training/rmsprop_test.py
@@ -165,7 +165,7 @@ class RMSPropOptimizerTest(test.TestCase):
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
@@ -187,7 +187,7 @@ class RMSPropOptimizerTest(test.TestCase):
def testMinimizeSparseResourceVariableCentered(self):
for dtype in [dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index f5b2a22327..0ac84813c8 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -324,7 +324,7 @@ class SaverTest(test.TestCase):
save_relative_paths=True)
init_all_op = [variables.global_variables_initializer(), v2_init]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Initialize all variables
sess.run(init_all_op)
@@ -349,7 +349,7 @@ class SaverTest(test.TestCase):
# Start a second session. In that session the parameter nodes
# have not been initialized either.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v0 = variables.Variable(-1.0, name="v0")
v1 = variables.Variable(-1.0, name="v1")
v2 = saver_test_utils.CheckpointedOp(name="v2")
@@ -373,7 +373,7 @@ class SaverTest(test.TestCase):
v0 = variables.Variable(0, name="v0")
filename = b"somerandomfilename"
save = saver_module.Saver({"v0": v0}, filename=filename)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tensor = sess.graph.get_tensor_by_name(
save.saver_def.filename_tensor_name)
self.assertEqual(sess.run(tensor), filename)
@@ -381,7 +381,7 @@ class SaverTest(test.TestCase):
def testInvalidPath(self):
v0 = variables.Variable(0, name="v0")
for ver in (saver_pb2.SaverDef.V1, saver_pb2.SaverDef.V2):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
save = saver_module.Saver({"v0": v0}, write_version=ver)
with self.assertRaisesRegexp(
ValueError, "The passed save_path is not a valid checkpoint:"):
@@ -390,7 +390,7 @@ class SaverTest(test.TestCase):
def testInt64(self):
save_path = os.path.join(self.get_temp_dir(), "int64")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Build a graph with 1 node, and save and restore for them.
v = variables.Variable(np.int64(15), name="v")
save = saver_module.Saver({"v": v}, restore_sequentially=True)
@@ -401,7 +401,7 @@ class SaverTest(test.TestCase):
self.assertTrue(isinstance(val, six.string_types))
self.assertEqual(save_path, val)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v = variables.Variable(np.int64(-1), name="v")
save = saver_module.Saver({"v": v})
@@ -559,12 +559,12 @@ class SaverTest(test.TestCase):
def testAllowEmpty(self):
save_path = os.path.join(self.get_temp_dir(), "allow_empty")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_ = constant_op.constant(1)
save = saver_module.Saver(allow_empty=True)
val = save.save(sess, save_path)
self.assertIsNone(val)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
save = saver_module.Saver(allow_empty=True)
save.restore(sess, save_path)
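
The saver_test.py hunks swap the helper around complete save/restore round trips like testAllowEmpty just above. A condensed hypothetical sketch of such a round trip inside one cached session, using only calls that appear in these hunks (saver_module.Saver, save.save, save.restore) plus an illustrative variable and an assign to clobber it before restoring:

  def testSaveRestoreSketch(self):  # hypothetical method with saver_test.py's imports
    save_path = os.path.join(self.get_temp_dir(), "cached_session_sketch")
    with self.cached_session() as sess:
      v = variables.Variable(10.0, name="v")
      save = saver_module.Saver({"v": v})
      variables.global_variables_initializer().run()
      save.save(sess, save_path)
      sess.run(v.assign(-1.0))       # clobber the in-memory value
      save.restore(sess, save_path)  # bring back the checkpointed 10.0
      self.assertEqual(10.0, v.eval())
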
@@ -740,7 +740,7 @@ class SaverTest(test.TestCase):
# save succeeds or fails is implementation dependent. Therefore we allow
# both cases.
try:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Initialize all variables
sess.run(init_all_op)
@@ -751,7 +751,7 @@ class SaverTest(test.TestCase):
# Save the graph.
save.save(sess, save_path)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Restore the saved values in the parameter nodes.
save.restore(sess, save_path)
# Check that the parameter nodes have been restored.
@@ -775,7 +775,7 @@ class SaverTest(test.TestCase):
save = saver_module.Saver({"v0": v0, "v1": v1}, restore_sequentially=True)
init_all_op = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Initialize all variables
sess.run(init_all_op)
@@ -983,7 +983,7 @@ class SaveRestoreShardedTest(test.TestCase):
os.path.join(self.get_temp_dir(), "sharded_basics"))
def testSaverDef(self):
- with self.test_session():
+ with self.cached_session():
v0 = variables.Variable(123, name="v0")
save = saver_module.Saver({"v0": v0}, sharded=True)
sd = save.as_saver_def()
@@ -1209,7 +1209,7 @@ class MaxToKeepTest(test.TestCase):
def testNonSharded(self):
save_dir = self._get_test_dir("max_to_keep_non_sharded")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v = variables.Variable(10.0, name="v")
save = saver_module.Saver({"v": v}, max_to_keep=2)
variables.global_variables_initializer().run()
@@ -1447,7 +1447,7 @@ class MaxToKeepTest(test.TestCase):
save_dir = self._get_test_dir("no_max_to_keep")
save_dir2 = self._get_test_dir("max_to_keep_0")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v = variables.Variable(10.0, name="v")
variables.global_variables_initializer().run()
@@ -1474,7 +1474,7 @@ class MaxToKeepTest(test.TestCase):
def testNoMetaGraph(self):
save_dir = self._get_test_dir("no_meta_graph")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v = variables.Variable(10.0, name="v")
save = saver_module.Saver({"v": v})
variables.global_variables_initializer().run()
@@ -1497,7 +1497,7 @@ class KeepCheckpointEveryNHoursTest(test.TestCase):
def testNonSharded(self, mock_time):
save_dir = self._get_test_dir("keep_checkpoint_every_n_hours")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v = variable_scope.variable([10.0], name="v")
# Run the initializer NOW to avoid the 0.5s overhead of the first Run()
# call, which throws the test timing off in fastbuild mode.
@@ -1630,7 +1630,7 @@ class MetaGraphTest(test.TestCase):
def testAddCollectionDef(self):
test_dir = self._get_test_dir("good_collection")
filename = os.path.join(test_dir, "metafile")
- with self.test_session():
+ with self.cached_session():
# Creates a graph.
v0 = variables.Variable(1.0, name="v0")
control_flow_ops.cond(
@@ -1685,7 +1685,7 @@ class MetaGraphTest(test.TestCase):
self, meta_graph_def, new_meta_graph_def)
def testAddCollectionDefFails(self):
- with self.test_session():
+ with self.cached_session():
# Creates a graph.
v0 = variables.Variable(10.0, name="v0")
# Creates a saver.
@@ -1870,7 +1870,7 @@ class MetaGraphTest(test.TestCase):
def testSliceVariable(self):
test_dir = self._get_test_dir("slice_saver")
filename = os.path.join(test_dir, "metafile")
- with self.test_session():
+ with self.cached_session():
v1 = variables.Variable([20.0], name="v1")
v2 = variables.Variable([20.0], name="v2")
v2._set_save_slice_info(
@@ -1946,7 +1946,7 @@ class MetaGraphTest(test.TestCase):
ops_lib.add_to_collection("logits", logits)
init_all_op = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Initializes all the variables.
sess.run(init_all_op)
# Runs to logit.
@@ -2120,7 +2120,7 @@ class MetaGraphTest(test.TestCase):
# pylint: enable=g-long-lambda
def testStrippedOpListDef(self):
- with self.test_session():
+ with self.cached_session():
# Creates a graph.
v0 = variables.Variable(0.0)
var = variables.Variable(10.0)
@@ -2160,7 +2160,7 @@ class MetaGraphTest(test.TestCase):
# With strip_default_attrs enabled, attributes "T" (float32) and "Tout"
# (complex64) in the "Complex" op must be removed.
- with self.test_session():
+ with self.cached_session():
real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
math_ops.complex(real_num, imag_num, name="complex")
@@ -2397,7 +2397,7 @@ class CheckpointReaderTest(test.TestCase):
}, write_version=self._WRITE_VERSION)
save_path = os.path.join(self.get_temp_dir(),
"ckpt_for_debug_string" + str(self._WRITE_VERSION))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_all_op)
# Saves a checkpoint.
save.save(sess, save_path)
@@ -2853,7 +2853,7 @@ class CheckpointableCompatibilityTests(test.TestCase):
saver = saver_module.Saver(var_list=[v])
test_dir = self.get_temp_dir()
prefix = os.path.join(test_dir, "ckpt")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.evaluate(v.non_dep_variable.assign(42.))
save_path = saver.save(sess, prefix)
self.evaluate(v.non_dep_variable.assign(43.))
@@ -2867,7 +2867,7 @@ class CheckpointableCompatibilityTests(test.TestCase):
test_dir = self.get_temp_dir()
prefix = os.path.join(test_dir, "ckpt")
self.evaluate(v.non_dep_variable.assign(42.))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
save_path = saver.save(sess, prefix)
self.evaluate(v.non_dep_variable.assign(43.))
self.evaluate(v.mirrored.assign(44.))
@@ -2900,7 +2900,7 @@ class CheckpointableCompatibilityTests(test.TestCase):
saver = saver_module.Saver(var_list=[v])
test_dir = self.get_temp_dir()
prefix = os.path.join(test_dir, "ckpt")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
save_path = saver.save(sess, prefix)
self.assertEqual(1, v.eval_count)
saver.restore(sess, save_path)
@@ -2957,7 +2957,7 @@ class CheckpointableCompatibilityTests(test.TestCase):
b = resource_variable_ops.ResourceVariable(1., name="b")
a_saver = saver_module.Saver([a])
b_saver = saver_module.Saver([b])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(a.initializer)
save_path = a_saver.save(sess=sess, save_path=checkpoint_prefix)
with self.assertRaisesRegexp(
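Illustration (not part of the patch): the hunks above, like most of the test changes in this merge, swap self.test_session() for self.cached_session(), which returns one session per test that is reused across calls. A minimal Python sketch of the resulting pattern, assuming the usual TF test imports:

    from tensorflow.python.ops import variables
    from tensorflow.python.platform import test

    class CachedSessionExampleTest(test.TestCase):

      def testReadVariable(self):
        # cached_session() hands back a session that is cached and reused on
        # later calls within this test, instead of building a fresh one.
        with self.cached_session() as sess:
          v = variables.Variable(10.0, name="v")
          sess.run(variables.global_variables_initializer())
          self.assertEqual(10.0, sess.run(v))

    if __name__ == "__main__":
      test.main()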
diff --git a/tensorflow/python/training/session_manager_test.py b/tensorflow/python/training/session_manager_test.py
index d7e6dac95b..f1d18f7704 100644
--- a/tensorflow/python/training/session_manager_test.py
+++ b/tensorflow/python/training/session_manager_test.py
@@ -98,7 +98,7 @@ class SessionManagerTest(test.TestCase):
os.rename(checkpoint_dir, checkpoint_dir2)
gfile.MakeDirs(checkpoint_dir)
v = variables.Variable([6.0, 7.0, 8.0], name="v")
- with self.test_session():
+ with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
session_manager.SessionManager(
ready_op=variables.report_uninitialized_variables())
@@ -236,7 +236,7 @@ class SessionManagerTest(test.TestCase):
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="w")
- with self.test_session():
+ with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
self.assertEqual(False, variables.is_variable_initialized(w).eval())
sm2 = session_manager.SessionManager(
@@ -294,7 +294,7 @@ class SessionManagerTest(test.TestCase):
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="w")
- with self.test_session():
+ with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
self.assertEqual(False, variables.is_variable_initialized(w).eval())
sm2 = session_manager.SessionManager(
@@ -326,7 +326,7 @@ class SessionManagerTest(test.TestCase):
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="w")
- with self.test_session():
+ with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(w).eval())
sm2 = session_manager.SessionManager(
ready_op=variables.report_uninitialized_variables(),
@@ -362,7 +362,7 @@ class SessionManagerTest(test.TestCase):
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="w")
- with self.test_session():
+ with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
self.assertEqual(False, variables.is_variable_initialized(w).eval())
sm2 = session_manager.SessionManager(
@@ -467,7 +467,7 @@ class SessionManagerTest(test.TestCase):
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="x")
- with self.test_session():
+ with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
self.assertEqual(False, variables.is_variable_initialized(w).eval())
self.assertEqual(False, variables.is_variable_initialized(x).eval())
@@ -519,7 +519,7 @@ class SessionManagerTest(test.TestCase):
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="x_res")
- with self.test_session():
+ with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
self.assertEqual(False, variables.is_variable_initialized(w).eval())
self.assertEqual(False, variables.is_variable_initialized(x).eval())
@@ -566,7 +566,7 @@ class SessionManagerTest(test.TestCase):
with ops.Graph().as_default():
i = control_flow_ops.while_loop(lambda i: i < 1, lambda i: i + 1, [0])
v = variables.Variable(array_ops.identity(i), name="v")
- with self.test_session():
+ with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
sm = session_manager.SessionManager(
ready_op=variables.report_uninitialized_variables())
@@ -585,7 +585,7 @@ class SessionManagerTest(test.TestCase):
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="w")
- with self.test_session():
+ with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
self.assertEqual(False, variables.is_variable_initialized(w).eval())
sm2 = session_manager.SessionManager(
@@ -602,7 +602,7 @@ class SessionManagerTest(test.TestCase):
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="w")
- with self.test_session():
+ with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
self.assertEqual(False, variables.is_variable_initialized(w).eval())
sm2 = session_manager.SessionManager(
@@ -619,7 +619,7 @@ class SessionManagerTest(test.TestCase):
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="w")
- with self.test_session():
+ with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
self.assertEqual(False, variables.is_variable_initialized(w).eval())
sm2 = session_manager.SessionManager(
@@ -640,7 +640,7 @@ class SessionManagerTest(test.TestCase):
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="w")
- with self.test_session():
+ with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
self.assertEqual(False, variables.is_variable_initialized(w).eval())
sm2 = session_manager.SessionManager(
@@ -714,7 +714,7 @@ class ObsoleteSessionManagerTest(test.TestCase):
os.rename(checkpoint_dir, checkpoint_dir2)
gfile.MakeDirs(checkpoint_dir)
v = variables.Variable([6.0, 7.0, 8.0], name="v")
- with self.test_session():
+ with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
session_manager.SessionManager(
ready_op=variables.assert_variables_initialized())
@@ -769,7 +769,7 @@ class ObsoleteSessionManagerTest(test.TestCase):
# Create a new Graph and SessionManager and recover.
with ops.Graph().as_default():
v = variables.Variable(2, name="v")
- with self.test_session():
+ with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
sm2 = session_manager.SessionManager(
ready_op=variables.assert_variables_initialized())
diff --git a/tensorflow/python/training/slot_creator_test.py b/tensorflow/python/training/slot_creator_test.py
index 08a3c8dc53..6d6364169f 100644
--- a/tensorflow/python/training/slot_creator_test.py
+++ b/tensorflow/python/training/slot_creator_test.py
@@ -32,7 +32,7 @@ from tensorflow.python.training import slot_creator
class SlotCreatorTest(test.TestCase):
def testCreateSlotFromVariable(self):
- with self.test_session():
+ with self.cached_session():
v = variables.Variable([1.0, 2.5], name="var")
slot = slot_creator.create_slot(v, v.initialized_value(), name="slot")
@@ -44,7 +44,7 @@ class SlotCreatorTest(test.TestCase):
self.assertAllEqual([1.0, 2.5], slot.eval())
def testCreateSlotFromTensor(self):
- with self.test_session():
+ with self.cached_session():
v = constant_op.constant([1.0, 2.5], name="const")
slot = slot_creator.create_slot(v, v * 2, name="slot")
@@ -56,7 +56,7 @@ class SlotCreatorTest(test.TestCase):
self.assertAllEqual([2.0, 5.0], slot.eval())
def testCreateZerosSlotFromVariable(self):
- with self.test_session():
+ with self.cached_session():
v = variables.Variable([1.0, 2.5], name="var")
with ops.control_dependencies(None):
slot = slot_creator.create_zeros_slot(
@@ -70,7 +70,7 @@ class SlotCreatorTest(test.TestCase):
self.assertAllEqual([0.0, 0.0], slot.eval())
def testCreateZerosSlotFromDynamicShapedVariable(self):
- with self.test_session():
+ with self.cached_session():
dyn_shape = constant_op.constant([2], dtype=dtypes.int32)
dyn_shape = array_ops.placeholder_with_default(dyn_shape,
shape=[None])
@@ -91,7 +91,7 @@ class SlotCreatorTest(test.TestCase):
self.assertAllEqual([0.0, 0.0], slot.eval())
def testCreateZerosSlotFromTensor(self):
- with self.test_session():
+ with self.cached_session():
v = constant_op.constant([1.0, 2.5], name="const")
with ops.control_dependencies(None):
slot = slot_creator.create_zeros_slot(v, name="slot")
@@ -104,7 +104,7 @@ class SlotCreatorTest(test.TestCase):
self.assertAllEqual([0.0, 0.0], slot.eval())
def testCreateZerosSlotFromDynamicShapedTensor(self):
- with self.test_session():
+ with self.cached_session():
v = random_ops.random_uniform([2], dtype=dtypes.float64)
v = array_ops.placeholder_with_default(v, shape=[None], name="const")
with ops.control_dependencies(None):
@@ -120,7 +120,7 @@ class SlotCreatorTest(test.TestCase):
def testCreateSlotFromVariableRespectsScope(self):
# See discussion on #2740.
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope("scope"):
v = variables.Variable([1.0, 2.5], name="var")
slot = slot_creator.create_slot(v, v.initialized_value(), name="slot")
diff --git a/tensorflow/python/training/supervisor_test.py b/tensorflow/python/training/supervisor_test.py
index 71ed88093a..caf6eba3e0 100644
--- a/tensorflow/python/training/supervisor_test.py
+++ b/tensorflow/python/training/supervisor_test.py
@@ -795,7 +795,7 @@ class SupervisorTest(test.TestCase):
self.assertRaises(StopIteration, lambda: next(rr))
# There should be a checkpoint file with the variable "foo"
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
v = variables.Variable([10.10], name="foo")
sav = saver_lib.Saver([v])
sav.restore(sess, save_path)
@@ -859,14 +859,14 @@ class SupervisorTest(test.TestCase):
self.assertEquals(event_pb2.SessionLog.STOP, ev.session_log.status)
self.assertRaises(StopIteration, lambda: next(rr))
# There should be a checkpoint file with the variable "foo"
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
v = variables.Variable([-12], name="global_step")
sav = saver_lib.Saver([v])
sav.restore(sess, save_path)
self.assertEqual(123, v.eval()[0])
def testNoQueueRunners(self):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
sv = supervisor.Supervisor(logdir=self._test_dir("no_queue_runners"))
self.assertEqual(0, len(sv.start_queue_runners(sess)))
sv.stop()
diff --git a/tensorflow/python/training/warm_starting_util_test.py b/tensorflow/python/training/warm_starting_util_test.py
index 3ee0f6aaa2..6c860cd452 100644
--- a/tensorflow/python/training/warm_starting_util_test.py
+++ b/tensorflow/python/training/warm_starting_util_test.py
@@ -1133,7 +1133,7 @@ class WarmStartingUtilTest(test.TestCase):
# Unused variable names raises ValueError.
with ops.Graph().as_default():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = variable_scope.get_variable(
"x",
shape=[4, 1],
diff --git a/tensorflow/python/util/memory.py b/tensorflow/python/util/memory.py
new file mode 100644
index 0000000000..e78f6d509a
--- /dev/null
+++ b/tensorflow/python/util/memory.py
@@ -0,0 +1,45 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Functions related to Python memory management."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+# TODO(b/115366440): Delete this function when a custom OrderedDict is added
+def dismantle_ordered_dict(ordered_dict):
+ """Remove reference cycle in OrderedDict `ordered_dict`.
+
+ Helpful for making sure the garbage collector doesn't need to run after
+ using an OrderedDict.
+
+ Args:
+ ordered_dict: An `OrderedDict` object to destroy. This object is unusable
+ after this function runs.
+ """
+ # OrderedDict makes a simple reference loop and hides it behind a
+ # name-mangled __root attribute in some Python versions. We don't need to
+ # throw an error if we can't find it, but if we do find it we can break the
+ # loop to avoid creating work for the garbage collector.
+ problematic_cycle = ordered_dict.__dict__.get("_OrderedDict__root", None) # pylint: disable=protected-access
+ if problematic_cycle:
+ try:
+ del problematic_cycle[0][:]
+ except TypeError:
+ # This is probably not one of the problematic Python versions. Continue
+ # with the rest of our cleanup.
+ pass
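Illustration (not part of the patch): the helper above is meant to be called on an OrderedDict that is about to be discarded, so implementations that keep a __root sentinel do not leave a cycle for the garbage collector. A minimal usage sketch, assuming the new module is importable from the path shown above:

    import collections

    from tensorflow.python.util import memory

    d = collections.OrderedDict([("a", 1), ("b", 2)])
    # ... use d ...
    # Break the internal __root cycle (a harmless no-op on Python versions
    # whose OrderedDict does not expose one), then drop the last reference.
    memory.dismantle_ordered_dict(d)
    del d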
diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py
index 2369eb610e..ef503137d1 100644
--- a/tensorflow/python/util/nest_test.py
+++ b/tensorflow/python/util/nest_test.py
@@ -461,7 +461,7 @@ class NestTest(parameterized.TestCase, test.TestCase):
inp_b: (np.random.randn(3, 4), np.random.randn(3, 7))
}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
output_np = sess.run(output, feed_dict=feed_dict)
self.assertAllClose(output_np[0],
feed_dict[inp_a][0] + feed_dict[inp_b][0])
diff --git a/tensorflow/python/util/tf_inspect.py b/tensorflow/python/util/tf_inspect.py
index 778121e15b..967c872c2a 100644
--- a/tensorflow/python/util/tf_inspect.py
+++ b/tensorflow/python/util/tf_inspect.py
@@ -325,6 +325,11 @@ def isfunction(object): # pylint: disable=redefined-builtin
return _inspect.isfunction(tf_decorator.unwrap(object)[1])
+def isgenerator(object): # pylint: disable=redefined-builtin
+ """TFDecorator-aware replacement for inspect.isgenerator."""
+ return _inspect.isgenerator(tf_decorator.unwrap(object)[1])
+
+
def ismethod(object): # pylint: disable=redefined-builtin
"""TFDecorator-aware replacement for inspect.ismethod."""
return _inspect.ismethod(tf_decorator.unwrap(object)[1])
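Illustration (not part of the patch): like the neighbouring wrappers, the new isgenerator unwraps any TFDecorator layers and then defers to the standard inspect module, so it reports on the underlying object:

    from tensorflow.python.util import tf_inspect

    def counter(n):
      for i in range(n):
        yield i

    # inspect.isgenerator is applied to the unwrapped target: True for a
    # generator object, False for the generator function itself.
    print(tf_inspect.isgenerator(counter(3)))  # True
    print(tf_inspect.isgenerator(counter))     # False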
diff --git a/tensorflow/python/util/tf_should_use_test.py b/tensorflow/python/util/tf_should_use_test.py
index 16fa1f547d..fedbe1dff6 100644
--- a/tensorflow/python/util/tf_should_use_test.py
+++ b/tensorflow/python/util/tf_should_use_test.py
@@ -106,7 +106,7 @@ class TfShouldUseTest(test.TestCase):
def return_const(value):
return constant_op.constant(value, name='blah3')
with reroute_error() as (error, _):
- with self.test_session():
+ with self.cached_session():
return_const(0.0)
# Creating another op and executing it does not mark the
# unused op as being "used".
@@ -124,7 +124,8 @@ class TfShouldUseTest(test.TestCase):
@tf_should_use.should_use_result
def return_const(value):
return constant_op.constant(value, name='blah3')
- with self.test_session():
+
+ with self.cached_session():
return_const(0.0).mark_used()
if __name__ == '__main__':
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index 9515d8e62a..10bf006787 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <atomic>
#include <utility>
+#include "tensorflow/core/util/env_var.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/fft.h"
#include "tensorflow/stream_executor/lib/env.h"
@@ -163,6 +164,15 @@ StreamExecutor::StreamExecutor(PlatformKind platform_kind,
CheckPlatformKindIsValid(platform_kind);
}
+// Get the per-device memory limit in bytes. Returns 0 if the
+// TF_PER_DEVICE_MEMORY_LIMIT_MB environment variable is not set.
+static int64 GetMemoryLimitBytes() {
+ int64 value;
+ SE_CHECK_OK(tensorflow::ReadInt64FromEnvVar("TF_PER_DEVICE_MEMORY_LIMIT_MB",
+ 0, &value));
+ return value * (1ll << 20);
+}
+
StreamExecutor::StreamExecutor(
const Platform *platform,
std::unique_ptr<internal::StreamExecutorInterface> implementation)
@@ -172,7 +182,9 @@ StreamExecutor::StreamExecutor(
background_threads_(new port::ThreadPool(
port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
live_stream_count_(0),
- tracing_enabled_(false) {
+ tracing_enabled_(false),
+ mem_alloc_bytes_(0),
+ memory_limit_bytes_(GetMemoryLimitBytes()) {
if (port::Lowercase(platform_->Name()) == "cuda") {
platform_kind_ = PlatformKind::kCuda;
} else if (port::Lowercase(platform_->Name()) == "opencl") {
@@ -460,6 +472,14 @@ port::Status StreamExecutor::BlockHostUntilDone(Stream *stream) {
}
void *StreamExecutor::Allocate(uint64 size) {
+ if (memory_limit_bytes_ > 0 &&
+ mem_alloc_bytes_ + size > memory_limit_bytes_) {
+ LOG(WARNING) << "Not enough memory to allocate " << size << " on device "
+ << device_ordinal_
+ << " within provided limit. [used=" << mem_alloc_bytes_
+ << ", limit=" << memory_limit_bytes_ << "]";
+ return nullptr;
+ }
void *buf = implementation_->Allocate(size);
VLOG(1) << "Called StreamExecutor::Allocate(size=" << size << ") returns "
<< buf << StackTraceIfVLOG10();
@@ -779,6 +799,7 @@ void StreamExecutor::CreateAllocRecord(void *opaque, uint64 bytes) {
mutex_lock lock(mu_);
mem_allocs_[opaque] = AllocRecord{
bytes, ""};
+ mem_alloc_bytes_ += bytes;
}
}
@@ -789,6 +810,7 @@ void StreamExecutor::EraseAllocRecord(void *opaque) {
LOG(ERROR) << "Deallocating unknown pointer: "
<< port::Printf("0x%p", opaque);
} else {
+ mem_alloc_bytes_ -= mem_allocs_[opaque].bytes;
mem_allocs_.erase(opaque);
}
}
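Illustration (not part of the patch): the accounting added above boils down to reading TF_PER_DEVICE_MEMORY_LIMIT_MB once, converting megabytes to bytes, and rejecting any allocation that would push the per-device running total past the limit. A small runnable Python analogue of that logic:

    import os

    def memory_limit_bytes():
      # 0 (the default) means "no limit", matching the C++ above.
      return int(os.environ.get("TF_PER_DEVICE_MEMORY_LIMIT_MB", "0")) * (1 << 20)

    class AllocatorSketch(object):

      def __init__(self):
        self.limit = memory_limit_bytes()
        self.in_use = 0

      def allocate(self, size):
        if self.limit > 0 and self.in_use + size > self.limit:
          return None  # over the limit: fail the allocation
        self.in_use += size
        return bytearray(size)  # stand-in for the real device allocation

      def deallocate(self, buf):
        self.in_use -= len(buf)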
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index 437f298616..d04025b681 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -699,6 +699,13 @@ class StreamExecutor {
// The set of TraceListeners registered for this StreamExecutor.
std::set<TraceListener*> listeners_ GUARDED_BY(mu_);
+ // Allocated memory in bytes.
+ int64 mem_alloc_bytes_;
+
+ // Memory limit in bytes. A value less than or equal to 0 indicates there
+ // is no limit.
+ int64 memory_limit_bytes_;
+
SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutor);
};
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.feature_column.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.feature_column.pbtxt
index 24a58fb118..f06e798953 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.feature_column.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.feature_column.pbtxt
@@ -34,7 +34,7 @@ tf_module {
}
member_method {
name: "input_layer"
- argspec: "args=[\'features\', \'feature_columns\', \'weight_collections\', \'trainable\', \'cols_to_vars\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\'], "
+ argspec: "args=[\'features\', \'feature_columns\', \'weight_collections\', \'trainable\', \'cols_to_vars\', \'cols_to_output_tensors\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'None\'], "
}
member_method {
name: "linear_model"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt
index d843194ef0..0869de0243 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt
@@ -151,7 +151,7 @@ tf_class {
}
member_method {
name: "evaluate"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "evaluate_generator"
@@ -159,7 +159,7 @@ tf_class {
}
member_method {
name: "fit"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "fit_generator"
@@ -219,7 +219,7 @@ tf_class {
}
member_method {
name: "predict"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "predict_generator"
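Illustration (not part of the patch): this golden file, together with the Sequential and v2 variants below, records that fit, evaluate and predict now take max_queue_size, workers and use_multiprocessing. A hedged usage sketch, with MySequence standing in for any keras.utils.Sequence subclass (illustrative, not from this patch):

    import numpy as np
    from tensorflow import keras

    class MySequence(keras.utils.Sequence):
      """Stands in for any Sequence that yields (x, y) batches."""

      def __len__(self):
        return 8

      def __getitem__(self, idx):
        x = np.random.rand(32, 10).astype("float32")
        y = np.random.rand(32, 1).astype("float32")
        return x, y

    model = keras.Sequential([keras.layers.Dense(1, input_shape=(10,))])
    model.compile(optimizer="sgd", loss="mse")

    # The new arguments mirror the *_generator methods and only take effect
    # when the input is a Sequence or generator.
    model.fit(MySequence(), epochs=1,
              max_queue_size=10, workers=2, use_multiprocessing=False)
    model.evaluate(MySequence(), max_queue_size=10, workers=2,
                   use_multiprocessing=False)
    preds = model.predict(MySequence(), max_queue_size=10, workers=2,
                          use_multiprocessing=False)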
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt
index b8e9baca71..20f39fae1e 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt
@@ -156,7 +156,7 @@ tf_class {
}
member_method {
name: "evaluate"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "evaluate_generator"
@@ -164,7 +164,7 @@ tf_class {
}
member_method {
name: "fit"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "fit_generator"
@@ -228,7 +228,7 @@ tf_class {
}
member_method {
name: "predict"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "predict_classes"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt
index 472b9818df..4011719317 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt
@@ -151,7 +151,7 @@ tf_class {
}
member_method {
name: "evaluate"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "evaluate_generator"
@@ -159,7 +159,7 @@ tf_class {
}
member_method {
name: "fit"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "fit_generator"
@@ -219,7 +219,7 @@ tf_class {
}
member_method {
name: "predict"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "predict_generator"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt
index 937516eff1..8a12ac1ad8 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt
@@ -156,7 +156,7 @@ tf_class {
}
member_method {
name: "evaluate"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "evaluate_generator"
@@ -164,7 +164,7 @@ tf_class {
}
member_method {
name: "fit"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "fit_generator"
@@ -228,7 +228,7 @@ tf_class {
}
member_method {
name: "predict"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "predict_classes"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-ordered-enqueuer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-ordered-enqueuer.pbtxt
new file mode 100644
index 0000000000..e7e7d2839b
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-ordered-enqueuer.pbtxt
@@ -0,0 +1,26 @@
+path: "tensorflow.keras.utils.OrderedEnqueuer"
+tf_class {
+ is_instance: "<class \'tensorflow.python.keras.utils.data_utils.OrderedEnqueuer\'>"
+ is_instance: "<class \'tensorflow.python.keras.utils.data_utils.SequenceEnqueuer\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'sequence\', \'use_multiprocessing\', \'shuffle\'], varargs=None, keywords=None, defaults=[\'False\', \'False\'], "
+ }
+ member_method {
+ name: "get"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "is_running"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "start"
+ argspec: "args=[\'self\', \'workers\', \'max_queue_size\'], varargs=None, keywords=None, defaults=[\'1\', \'10\'], "
+ }
+ member_method {
+ name: "stop"
+ argspec: "args=[\'self\', \'timeout\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+}
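Illustration (not part of the patch): the newly exported tf.keras.utils.OrderedEnqueuer has exactly the surface listed above. Reusing the hypothetical MySequence from the earlier sketch, typical usage looks like:

    from tensorflow import keras

    enqueuer = keras.utils.OrderedEnqueuer(
        MySequence(), use_multiprocessing=False, shuffle=False)
    enqueuer.start(workers=2, max_queue_size=10)
    batches = enqueuer.get()  # a generator yielding batches in order
    for _ in range(4):
      x, y = next(batches)
    enqueuer.stop()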
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt
index 4d7a1519ce..81b91d2780 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt
@@ -13,6 +13,10 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "OrderedEnqueuer"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "Progbar"
mtype: "<type \'type\'>"
}
@@ -45,6 +49,10 @@ tf_module {
argspec: "args=[\'fname\', \'origin\', \'untar\', \'md5_hash\', \'file_hash\', \'cache_subdir\', \'hash_algorithm\', \'extract\', \'archive_format\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'datasets\', \'auto\', \'False\', \'auto\', \'None\'], "
}
member_method {
+ name: "get_source_inputs"
+ argspec: "args=[\'tensor\', \'layer\', \'node_index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
name: "multi_gpu_model"
argspec: "args=[\'model\', \'gpus\', \'cpu_merge\', \'cpu_relocation\'], varargs=None, keywords=None, defaults=[\'True\', \'False\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.feature_column.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.feature_column.pbtxt
index 24a58fb118..f06e798953 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.feature_column.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.feature_column.pbtxt
@@ -34,7 +34,7 @@ tf_module {
}
member_method {
name: "input_layer"
- argspec: "args=[\'features\', \'feature_columns\', \'weight_collections\', \'trainable\', \'cols_to_vars\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\'], "
+ argspec: "args=[\'features\', \'feature_columns\', \'weight_collections\', \'trainable\', \'cols_to_vars\', \'cols_to_output_tensors\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'None\'], "
}
member_method {
name: "linear_model"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt
index d843194ef0..0869de0243 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt
@@ -151,7 +151,7 @@ tf_class {
}
member_method {
name: "evaluate"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "evaluate_generator"
@@ -159,7 +159,7 @@ tf_class {
}
member_method {
name: "fit"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "fit_generator"
@@ -219,7 +219,7 @@ tf_class {
}
member_method {
name: "predict"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "predict_generator"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt
index b8e9baca71..20f39fae1e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt
@@ -156,7 +156,7 @@ tf_class {
}
member_method {
name: "evaluate"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "evaluate_generator"
@@ -164,7 +164,7 @@ tf_class {
}
member_method {
name: "fit"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "fit_generator"
@@ -228,7 +228,7 @@ tf_class {
}
member_method {
name: "predict"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "predict_classes"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt
index 472b9818df..4011719317 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt
@@ -151,7 +151,7 @@ tf_class {
}
member_method {
name: "evaluate"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "evaluate_generator"
@@ -159,7 +159,7 @@ tf_class {
}
member_method {
name: "fit"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "fit_generator"
@@ -219,7 +219,7 @@ tf_class {
}
member_method {
name: "predict"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "predict_generator"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt
index 937516eff1..8a12ac1ad8 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt
@@ -156,7 +156,7 @@ tf_class {
}
member_method {
name: "evaluate"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "evaluate_generator"
@@ -164,7 +164,7 @@ tf_class {
}
member_method {
name: "fit"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "fit_generator"
@@ -228,7 +228,7 @@ tf_class {
}
member_method {
name: "predict"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "predict_classes"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-ordered-enqueuer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-ordered-enqueuer.pbtxt
new file mode 100644
index 0000000000..e7e7d2839b
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-ordered-enqueuer.pbtxt
@@ -0,0 +1,26 @@
+path: "tensorflow.keras.utils.OrderedEnqueuer"
+tf_class {
+ is_instance: "<class \'tensorflow.python.keras.utils.data_utils.OrderedEnqueuer\'>"
+ is_instance: "<class \'tensorflow.python.keras.utils.data_utils.SequenceEnqueuer\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'sequence\', \'use_multiprocessing\', \'shuffle\'], varargs=None, keywords=None, defaults=[\'False\', \'False\'], "
+ }
+ member_method {
+ name: "get"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "is_running"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "start"
+ argspec: "args=[\'self\', \'workers\', \'max_queue_size\'], varargs=None, keywords=None, defaults=[\'1\', \'10\'], "
+ }
+ member_method {
+ name: "stop"
+ argspec: "args=[\'self\', \'timeout\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt
index 4d7a1519ce..81b91d2780 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt
@@ -13,6 +13,10 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "OrderedEnqueuer"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "Progbar"
mtype: "<type \'type\'>"
}
@@ -45,6 +49,10 @@ tf_module {
argspec: "args=[\'fname\', \'origin\', \'untar\', \'md5_hash\', \'file_hash\', \'cache_subdir\', \'hash_algorithm\', \'extract\', \'archive_format\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'datasets\', \'auto\', \'False\', \'auto\', \'None\'], "
}
member_method {
+ name: "get_source_inputs"
+ argspec: "args=[\'tensor\', \'layer\', \'node_index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
name: "multi_gpu_model"
argspec: "args=[\'model\', \'gpus\', \'cpu_merge\', \'cpu_relocation\'], varargs=None, keywords=None, defaults=[\'True\', \'False\'], "
}
diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py
index 99bed5714f..d06c7f2d49 100644
--- a/tensorflow/tools/api/tests/api_compatibility_test.py
+++ b/tensorflow/tools/api/tests/api_compatibility_test.py
@@ -174,7 +174,7 @@ class ApiCompatibilityTest(test.TestCase):
verbose_diff_message = diff_message
else:
# Do not truncate diff
- self.maxDiffs = None # pylint: disable=invalid-name
+ self.maxDiff = None # pylint: disable=invalid-name
# Now we can run an actual proto diff.
try:
self.assertProtoEquals(expected_dict[key], actual_dict[key])
diff --git a/tensorflow/tools/ci_build/Dockerfile.rbe.cuda9.0-cudnn7-ubuntu14.04 b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda9.0-cudnn7-ubuntu14.04
new file mode 100644
index 0000000000..a30858db82
--- /dev/null
+++ b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda9.0-cudnn7-ubuntu14.04
@@ -0,0 +1,83 @@
+# To push a new version, run:
+# $ docker build -f Dockerfile.rbe.cuda9.0-cudnn7-ubuntu14.04 \
+# --tag "gcr.io/asci-toolchain/nosla-cuda9.0-cudnn7-ubuntu14.04" .
+# $ docker push gcr.io/asci-toolchain/nosla-cuda9.0-cudnn7-ubuntu14.04
+#
+# TODO(klimek): Include clang in this image so we can also target clang
+# builds.
+
+FROM ubuntu:14.04
+LABEL maintainer="Manuel Klimek <klimek@google.com>"
+
+RUN apt-get update && apt-get install -y --no-install-recommends ca-certificates apt-transport-https gnupg-curl && \
+ rm -rf /var/lib/apt/lists/* && \
+ NVIDIA_GPGKEY_SUM=d1be581509378368edeec8c1eb2958702feedf3bc3d17011adbf24efacce4ab5 && \
+ NVIDIA_GPGKEY_FPR=ae09fe4bbd223a84b2ccfce3f60f4b3d7fa2af80 && \
+ apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1404/x86_64/7fa2af80.pub && \
+ apt-key adv --export --no-emit-version -a $NVIDIA_GPGKEY_FPR | tail -n +2 > cudasign.pub && \
+ echo "$NVIDIA_GPGKEY_SUM cudasign.pub" | sha256sum -c --strict - && rm cudasign.pub && \
+ echo "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64 /" > /etc/apt/sources.list.d/cuda.list && \
+ echo "deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64 /" > /etc/apt/sources.list.d/nvidia-ml.list
+
+ENV CUDA_VERSION 9.0.176
+ENV CUDA_PKG_VERSION 9-0=$CUDA_VERSION-1
+ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:${PATH}
+ENV NVIDIA_VISIBLE_DEVICES all
+ENV NVIDIA_DRIVER_CAPABILITIES compute,utility
+ENV NVIDIA_REQUIRE_CUDA "cuda>=9.0"
+ENV NCCL_VERSION 2.2.13
+ENV CUDNN_VERSION 7.2.1.38
+
+# TODO(b/110903506): /usr/local/cuda/lib64/stubs should not be needed in
+# LD_LIBRARY_PATH. The stubs/libcuda.so is not meant to be used at runtime. The
+# correct way to pass the path to bfd-ld is to pass
+# -Wl,-rpath-link=/usr/local/cuda/lib64/stubs to all binaries transitively
+# depending on libcuda. Optimally, builds targeting cuda would do that
+# internally.
+ENV LD_LIBRARY_PATH /usr/local/nvidia/lib:/usr/local/nvidia/lib64:/usr/local/cuda/lib64/stubs
+
+LABEL com.nvidia.volumes.needed="nvidia_driver"
+LABEL com.nvidia.cuda.version="${CUDA_VERSION}"
+LABEL com.nvidia.cudnn.version="${CUDNN_VERSION}"
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ cuda-cudart-$CUDA_PKG_VERSION \
+ cuda-libraries-$CUDA_PKG_VERSION \
+ cuda-cublas-9-0=9.0.176.4-1 \
+ libnccl2=$NCCL_VERSION-1+cuda9.0 \
+ cuda-libraries-dev-$CUDA_PKG_VERSION \
+ cuda-nvml-dev-$CUDA_PKG_VERSION \
+ cuda-minimal-build-$CUDA_PKG_VERSION \
+ cuda-command-line-tools-$CUDA_PKG_VERSION \
+ cuda-core-9-0=9.0.176.3-1 \
+ cuda-cublas-dev-9-0=9.0.176.4-1 \
+ libnccl-dev=$NCCL_VERSION-1+cuda9.0 \
+ libcudnn7-dev=$CUDNN_VERSION-1+cuda9.0 \
+ libcudnn7=$CUDNN_VERSION-1+cuda9.0 && \
+ ln -s cuda-9.0 /usr/local/cuda && \
+ apt-mark hold libnccl2 && \
+ apt-mark hold libcudnn7 libcudnn7-dev && \
+ rm -rf /var/lib/apt/lists/*
+
+RUN echo "/usr/local/nvidia/lib" >> /etc/ld.so.conf.d/nvidia.conf && \
+ echo "/usr/local/nvidia/lib64" >> /etc/ld.so.conf.d/nvidia.conf
+
+# TODO(b/110903506): Provide a link to the SONAME of libcuda.so.
+# https://github.com/NVIDIA/nvidia-docker/issues/775
+RUN ln -s libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1
+
+# TODO(klimek): Once the TODO in tensorflow's configure.py to correctly find
+# libnccl is resolved, delete this block.
+RUN ln -s /usr/lib/x86_64-linux-gnu/libnccl.so /usr/lib/libnccl.so \
+ && ln -s /usr/lib/x86_64-linux-gnu/libnccl.so /usr/lib/libnccl.so.2
+
+# Copy and run the install scripts.
+COPY install/*.sh /install/
+ARG DEBIAN_FRONTEND=noninteractive
+RUN /install/install_bootstrap_deb_packages.sh
+RUN add-apt-repository -y ppa:openjdk-r/ppa && \
+ add-apt-repository -y ppa:george-edison55/cmake-3.x
+RUN /install/install_deb_packages.sh
+RUN /install/install_pip_packages.sh
+RUN /install/install_golang.sh
+
diff --git a/tensorflow/tools/ci_build/Dockerfile.rbe.gcc.gpu b/tensorflow/tools/ci_build/Dockerfile.rbe.gcc.gpu
deleted file mode 100644
index 08dc026328..0000000000
--- a/tensorflow/tools/ci_build/Dockerfile.rbe.gcc.gpu
+++ /dev/null
@@ -1,43 +0,0 @@
-# To push a new version, run:
-# $ docker build -f Dockerfile.rbe.gcc.gpu \
-# --tag "gcr.io/asci-toolchain/nosla-nvidia-gcc" .
-# $ docker push gcr.io/asci-toolchain/nosla-nvidia-gcc
-FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04
-
-LABEL maintainer="Manuel Klimek <klimek@google.com>"
-
-# TODO(b/110903506): Fix the nvidia docker image by providing a link to the
-# SONAME of libcuda.so. Alternatively, consider using gold or lld which do not
-# run into the same problem - that will only work once the tensorflow build does
-# not link to libcuda from generators anymore.
-# https://github.com/NVIDIA/nvidia-docker/issues/775
-RUN ln -s libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1
-
-# TODO(klimek): Once the TODO in tensorflow's configure.py to correctly find
-# libnccl is resolved, delete this block.
-RUN ln -s /usr/lib/x86_64-linux-gnu/libnccl.so /usr/lib/libnccl.so \
- && ln -s /usr/lib/x86_64-linux-gnu/libnccl.so /usr/lib/libnccl.so.2
-
-# TODO(b/110903506): Fix tensorflow to not require the use of LD_LIBRARY_PATH.
-# The stubs/libcuda.so is not meant to used at runtime. The correct way to
-# pass the path to bfd-ld is to pass -Wl,-rpath-link=/usr/local/cuda/lib64/stubs
-# to all binaries transitively depending on libcuda. Optimally the tensorflow
-# build would do that internally.
-ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64/stubs
-
-# Copy and run the install scripts.
-COPY install/*.sh /install/
-ARG DEBIAN_FRONTEND=noninteractive
-RUN /install/install_bootstrap_deb_packages.sh
-RUN add-apt-repository -y ppa:openjdk-r/ppa && \
- add-apt-repository -y ppa:george-edison55/cmake-3.x
-RUN /install/install_deb_packages.sh
-RUN /install/install_pip_packages.sh
-RUN /install/install_golang.sh
-
-# Install nccl2.
-RUN apt-get update && apt-get install -y \
- libnccl2 \
- libnccl-dev \
- && rm -rf /var/lib/apt-lists/*
-
diff --git a/tensorflow/tools/ci_build/builds/run_pip_tests.sh b/tensorflow/tools/ci_build/builds/run_pip_tests.sh
index bbaf59c69a..4b762bf258 100755
--- a/tensorflow/tools/ci_build/builds/run_pip_tests.sh
+++ b/tensorflow/tools/ci_build/builds/run_pip_tests.sh
@@ -76,7 +76,7 @@ ln -s $(pwd)/tensorflow ${PIP_TEST_ROOT}/tensorflow
# Do not run tests with "no_pip" tag. If running GPU tests, also do not run
# tests with no_pip_gpu tag.
-PIP_TEST_FILTER_TAG="-no_pip,-no_oss"
+PIP_TEST_FILTER_TAG="-no_pip,-no_oss,-benchmark-test"
if [[ ${IS_OSS_SERIAL} == "1" ]]; then
PIP_TEST_FILTER_TAG="$(echo "${PIP_TEST_FILTER_TAG}" | sed s/-no_oss//)"
PIP_TEST_FILTER_TAG="${PIP_TEST_FILTER_TAG},oss_serial"
@@ -85,7 +85,7 @@ else
fi
if [[ ${IS_GPU} == "1" ]]; then
- PIP_TEST_FILTER_TAG="-no_pip_gpu,${PIP_TEST_FILTER_TAG}"
+ PIP_TEST_FILTER_TAG="-no_gpu,-no_pip_gpu,${PIP_TEST_FILTER_TAG}"
fi
if [[ ${IS_MAC} == "1" ]]; then
PIP_TEST_FILTER_TAG="-nomac,${PIP_TEST_FILTER_TAG}"
diff --git a/tensorflow/tools/ci_build/ci_parameterized_build.sh b/tensorflow/tools/ci_build/ci_parameterized_build.sh
index c8472102cb..cc09784c1d 100755
--- a/tensorflow/tools/ci_build/ci_parameterized_build.sh
+++ b/tensorflow/tools/ci_build/ci_parameterized_build.sh
@@ -127,17 +127,19 @@ NO_DOCKER_OPT_FLAG="--genrule_strategy=standalone"
DO_DOCKER=1
-BAZEL_CMD="bazel test"
-BAZEL_BUILD_ONLY_CMD="bazel build"
-BAZEL_CLEAN_CMD="bazel clean"
-# Default flags:
+# Helpful flags:
# --test_summary=detailed: Tell us more about which targets are being built
# --keep_going: Don't stop at the first failure; tell us all the failures
# --build_tests_only: Don't build targets depended on by tests if the test is
# disabled. Also saves some compilation time. Otherwise,
# tries to build everything.
-DEFAULT_BAZEL_CONFIGS="--test_summary=detailed --build_tests_only --keep_going"
+BAZEL_TEST_FLAGS="--test_summary=detailed --build_tests_only --keep_going"
+BAZEL_BUILD_FLAGS="--keep_going"
+
+BAZEL_CMD="bazel test ${BAZEL_TEST_FLAGS}"
+BAZEL_BUILD_ONLY_CMD="bazel build ${BAZEL_BUILD_FLAGS}"
+BAZEL_CLEAN_CMD="bazel clean"
PIP_CMD="${CI_BUILD_DIR}/builds/pip.sh"
PIP_TEST_TUTORIALS_FLAG="--test_tutorials"
@@ -393,7 +395,7 @@ fi
EXTRA_ARGS="${EXTRA_ARGS} --distinct_host_configuration=false"
if [[ ! -z "${TF_BAZEL_BUILD_ONLY}" ]] &&
- [[ "${TF_BAZEL_BUILD_ONLY}" != "0" ]];then
+ [[ "${TF_BAZEL_BUILD_ONLY}" != "0" ]];then
BAZEL_CMD=${BAZEL_BUILD_ONLY_CMD}
fi
diff --git a/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh b/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh
index 75da9bb835..03a2a07fb1 100755
--- a/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh
+++ b/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh
@@ -16,29 +16,25 @@
#
#
# A script to run multiple GPU tests in parallel controlled with an environment
-# variable. This script will assume that when it runs, one of the locks are
-# already released. So the program calling this script is expected to make sure
-# that only $TF_GPU_COUNT processes are running at any gien time.
+# variable.
#
# Required environment variables:
-# TF_GPU_COUNT = Number of GPUs available. This HAS TO BE IN SYNC with the
-# value of --local_test_jobs flag for bazel.
+# TF_GPU_COUNT = Number of GPUs available.
-BASH_VER_MAJOR=$(echo ${BASH_VERSION} | cut -d '.' -f 1)
-BASH_VER_MINOR=$(echo ${BASH_VERSION} | cut -d '.' -f 2)
-
-if [[ ${BASH_VER_MAJOR} -lt 4 ]]; then
- echo "Insufficient bash version: ${BASH_VERSION} < 4.2" >&2
- exit 1
-elif [[ ${BASH_VER_MAJOR} -eq 4 ]] && [[ ${BASH_VER_MINOR} -lt 2 ]]; then
- echo "Insufficient bash version: ${BASH_VERSION} < 4.2" >&2
- exit 1
-fi
-
-function is_absolute {
- [[ "$1" = /* ]] || [[ "$1" =~ ^[a-zA-Z]:[/\\].* ]]
-}
+TF_GPU_COUNT=${TF_GPU_COUNT:-8}
+TF_TESTS_PER_GPU=${TF_TESTS_PER_GPU:-4}
+# We want to allow running one of the following configs:
+# - 4 tests per GPU on k80
+# - 8 tests per GPU on p100
+# A p100 has at least 12G of memory, so with 8 tests per GPU each test could
+# use up to 1.5G. To leave some room for running more tests in parallel in
+# the future, and to use a rounder number, we set the limit to 1G.
+export TF_PER_DEVICE_MEMORY_LIMIT_MB=1024
+# *******************************************************************
+# This section of the script is needed to
+# make things work on windows under msys.
+# *******************************************************************
RUNFILES_MANIFEST_FILE="${TEST_SRCDIR}/MANIFEST"
function rlocation() {
if is_absolute "$1" ; then
@@ -55,29 +51,32 @@ function rlocation() {
TEST_BINARY="$(rlocation $TEST_WORKSPACE/${1#./})"
shift
+# *******************************************************************
-# Make sure /var/lock exists, this may not be true under MSYS
mkdir -p /var/lock
-
-TF_GPU_COUNT=${TF_GPU_COUNT:-8}
-
-for i in `seq 0 $((TF_GPU_COUNT-1))`; do
- exec {lock_fd}>/var/lock/gpulock$i || exit 1
- if flock -n "$lock_fd";
- then
- (
- # This export only works within the brackets, so it is isolated to one
- # single command.
- export CUDA_VISIBLE_DEVICES=$i
- echo "Running test $TEST_BINARY $* on GPU $CUDA_VISIBLE_DEVICES"
- "$TEST_BINARY" $@
- )
- return_code=$?
- flock -u "$lock_fd"
- exit $return_code
- fi
+# Try to acquire any one of the TF_GPU_COUNT * TF_TESTS_PER_GPU
+# slots in which to run a test.
+#
+# Prefer to allocate 1 test per GPU over 4 tests on 1 GPU, so the outer
+# loop is over TF_TESTS_PER_GPU and the inner loop sweeps all GPUs.
+for j in `seq 0 $((TF_TESTS_PER_GPU-1))`; do
+ for i in `seq 0 $((TF_GPU_COUNT-1))`; do
+ exec {lock_fd}>/var/lock/gpulock${i}_${j} || exit 1
+ if flock -n "$lock_fd";
+ then
+ (
+ # This export only works within the brackets, so it is isolated to one
+ # single command.
+ export CUDA_VISIBLE_DEVICES=$i
+ echo "Running test $TEST_BINARY $* on GPU $CUDA_VISIBLE_DEVICES"
+ "$TEST_BINARY" $@
+ )
+ return_code=$?
+ flock -u "$lock_fd"
+ exit $return_code
+ fi
+ done
done
echo "Cannot find a free GPU to run the test $* on, exiting with failure..."
exit 1
-
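The rewritten loop above hands out TF_GPU_COUNT * TF_TESTS_PER_GPU lock-file slots via flock, trying slot 0 on every GPU before slot 1 on any GPU so tests spread across GPUs first. A rough Python equivalent of the slot acquisition; the lock-file naming follows the script, but the runner function itself is only an illustration:

    import fcntl
    import os
    import subprocess
    import sys

    TF_GPU_COUNT = int(os.environ.get("TF_GPU_COUNT", 8))
    TF_TESTS_PER_GPU = int(os.environ.get("TF_TESTS_PER_GPU", 4))

    def run_on_free_gpu(test_binary, *args):
        # Outer loop over per-GPU slots, inner loop over GPUs: slot 0 of every
        # GPU is tried before slot 1 of any GPU, preferring 1 test per GPU.
        for slot in range(TF_TESTS_PER_GPU):
            for gpu in range(TF_GPU_COUNT):
                lock_file = open("/var/lock/gpulock%d_%d" % (gpu, slot), "w")
                try:
                    fcntl.flock(lock_file, fcntl.LOCK_EX | fcntl.LOCK_NB)
                except OSError:
                    lock_file.close()
                    continue  # Slot already taken; try the next one.
                try:
                    env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(gpu))
                    return subprocess.call([test_binary] + list(args), env=env)
                finally:
                    fcntl.flock(lock_file, fcntl.LOCK_UN)
                    lock_file.close()
        print("Cannot find a free GPU slot for %s" % test_binary, file=sys.stderr)
        return 1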
diff --git a/tensorflow/tools/ci_build/linux/cpu/run_mkl.sh b/tensorflow/tools/ci_build/linux/cpu/run_mkl.sh
index 2a9f295188..7be5f454ec 100755
--- a/tensorflow/tools/ci_build/linux/cpu/run_mkl.sh
+++ b/tensorflow/tools/ci_build/linux/cpu/run_mkl.sh
@@ -33,7 +33,7 @@ yes "" | $PYTHON_BIN_PATH configure.py
# Setting KMP_BLOCKTIME to 0 lets OpenMP threads to sleep right after parallel execution
# in an MKL primitive. This reduces the effects of an oversubscription of OpenMP threads
# caused by executing multiple tests concurrently.
-bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test --test_lang_filters=py -k \
+bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test --test_lang_filters=cc,py -k \
--jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --build_tests_only \
--config=mkl --test_env=KMP_BLOCKTIME=0 --config=opt --test_output=errors -- \
//tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/...
diff --git a/tensorflow/tools/compatibility/testdata/test_file_v0_11.py b/tensorflow/tools/compatibility/testdata/test_file_v0_11.py
index 01f37d8768..35a74c9664 100644
--- a/tensorflow/tools/compatibility/testdata/test_file_v0_11.py
+++ b/tensorflow/tools/compatibility/testdata/test_file_v0_11.py
@@ -35,7 +35,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
"""
def testArgRenames(self):
- with self.test_session():
+ with self.cached_session():
a = [[1., 2., 3.], [4., 5., 6.]]
b = [[True, False, False], [False, True, True]]
@@ -98,7 +98,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
[[[1, 2]], [[3, 4]]])
def testArgMinMax(self):
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(
tf.argmin([[1, 2, 3], [4, 1, 0]], dimension=1).eval(),
[0, 2])
@@ -113,7 +113,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
[1, 0, 0])
def testExpandAndSqueeze(self):
- with self.test_session():
+ with self.cached_session():
# TODO(aselle): sparse_split, sparse_reduce_sum,
# sparse_reduce_sum_sparse, reduce_join
@@ -140,7 +140,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
a)
def testArithmeticRenames(self):
- with self.test_session() as s:
+ with self.cached_session() as s:
stuff = tf.split(1, 2, [[1, 2, 3, 4], [4, 5, 6, 7]])
vals = s.run(stuff)
self.assertAllEqual(vals,
@@ -164,7 +164,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
# ]
def testBatchAndSvd(self):
- with self.test_session():
+ with self.cached_session():
mat = [[1., 2.], [2., 3.]]
batched_mat = tf.expand_dims(mat, [0])
result = tf.matmul(mat, mat).eval()
@@ -176,7 +176,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
def testCrossEntropy(self):
# TODO(aselle): Test sparse_softmax_...
- with self.test_session():
+ with self.cached_session():
labels = [.8, .5, .2, .1]
logits = [.9, .1, .3, .1]
self.assertAllEqual(
@@ -191,7 +191,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
labels=labels, logits=logits).eval())
def testVariables(self):
- with self.test_session() as s:
+ with self.cached_session() as s:
# make some variables
_ = [tf.Variable([1, 2, 3], dtype=tf.float32),
@@ -201,7 +201,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
_ = [v.name for v in tf.local_variables()]
def testSummaries(self):
- with self.test_session() as s:
+ with self.cached_session() as s:
var = tf.Variable([1, 2, 3], dtype=tf.float32)
s.run(tf.initialize_all_variables())
x, y = np.meshgrid(np.linspace(-10, 10, 256), np.linspace(-10, 10, 256))
diff --git a/tensorflow/tools/compatibility/testdata/test_file_v1_10.py b/tensorflow/tools/compatibility/testdata/test_file_v1_10.py
index a49035a1a0..e5ca8d3e2e 100644
--- a/tensorflow/tools/compatibility/testdata/test_file_v1_10.py
+++ b/tensorflow/tools/compatibility/testdata/test_file_v1_10.py
@@ -26,7 +26,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
"""Test various APIs that have been changed in 2.0."""
def testRenames(self):
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(1.04719755, tf.acos(0.5).eval())
self.assertAllClose(0.5, tf.rsqrt(4.0).eval())
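The two testdata files above switch every self.test_session() call to self.cached_session(). A minimal sketch of the resulting pattern, assuming a TensorFlow 1.x environment where tf.test.TestCase provides cached_session():

    import tensorflow as tf

    class RenameStyleTest(tf.test.TestCase):

      def testAcos(self):
        # cached_session() reuses one session across calls within the test,
        # which is the style the updated testdata files now follow.
        with self.cached_session():
          self.assertAllClose(1.04719755, tf.acos(0.5).eval())

    if __name__ == "__main__":
      tf.test.main()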
diff --git a/tensorflow/tools/dockerfiles/README.md b/tensorflow/tools/dockerfiles/README.md
index d64db35afb..5996573cf1 100644
--- a/tensorflow/tools/dockerfiles/README.md
+++ b/tensorflow/tools/dockerfiles/README.md
@@ -34,13 +34,13 @@ documentation](https://docs.docker.com/engine/reference/run/).
# User permissions (-u) are required if you use (-v).
# CPU-based images
-$ docker run -u $(id -u):$(id -g) -v $(PWD):/my-devel -it tf
+$ docker run -u $(id -u):$(id -g) -v $(pwd):/my-devel -it tf
# GPU-based images (set up nvidia-docker2 first)
-$ docker run --runtime=nvidia -u $(id -u):$(id -g) -v $(PWD):/my-devel -it tf
+$ docker run --runtime=nvidia -u $(id -u):$(id -g) -v $(pwd):/my-devel -it tf
# Images with Jupyter run on port 8888 and need a volume for notebooks
-$ docker run --user $(id -u):$(id -g) -p 8888:8888 -v $(PWD):/notebooks -it tf
+$ docker run --user $(id -u):$(id -g) -p 8888:8888 -v $(pwd):/notebooks -it tf
```
These images do not come with the TensorFlow source code -- but the development
diff --git a/tensorflow/tools/docs/parser.py b/tensorflow/tools/docs/parser.py
index a6159fa692..83b4bf8128 100644
--- a/tensorflow/tools/docs/parser.py
+++ b/tensorflow/tools/docs/parser.py
@@ -1479,7 +1479,7 @@ class ParserConfig(object):
self.base_dir = base_dir
self.defined_in_prefix = 'tensorflow/'
self.code_url_prefix = (
- 'https://www.tensorflow.org/code/tensorflow/') # pylint: disable=line-too-long
+ '/code/stable/tensorflow/') # pylint: disable=line-too-long
def py_name_to_object(self, full_name):
"""Return the Python object for a Python symbol name."""
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index 91c5cd094c..50515b04a9 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -60,16 +60,6 @@ COMMON_PIP_DEPS = [
":included_headers",
"//tensorflow:tensorflow_py",
"//tensorflow/contrib/autograph:autograph",
- "//tensorflow/contrib/autograph/converters:converters",
- "//tensorflow/contrib/autograph/core:core",
- "//tensorflow/contrib/autograph/core:test_lib",
- "//tensorflow/contrib/autograph/impl:impl",
- "//tensorflow/contrib/autograph/lang:lang",
- "//tensorflow/contrib/autograph/operators:operators",
- "//tensorflow/contrib/autograph/pyct:pyct",
- "//tensorflow/contrib/autograph/pyct/testing:testing",
- "//tensorflow/contrib/autograph/pyct/static_analysis:static_analysis",
- "//tensorflow/contrib/autograph/pyct/common_transformers:common_transformers",
"//tensorflow/contrib/boosted_trees:boosted_trees_pip",
"//tensorflow/contrib/cluster_resolver:cluster_resolver_pip",
"//tensorflow/contrib/constrained_optimization:constrained_optimization_pip",
@@ -102,6 +92,16 @@ COMMON_PIP_DEPS = [
"//tensorflow/contrib/timeseries:timeseries_pip",
"//tensorflow/contrib/tpu",
"//tensorflow/examples/tutorials/mnist:package",
+ # "//tensorflow/python/autograph/converters:converters",
+ # "//tensorflow/python/autograph/core:core",
+ "//tensorflow/python/autograph/core:test_lib",
+ # "//tensorflow/python/autograph/impl:impl",
+ # "//tensorflow/python/autograph/lang:lang",
+ # "//tensorflow/python/autograph/operators:operators",
+ # "//tensorflow/python/autograph/pyct:pyct",
+ # "//tensorflow/python/autograph/pyct/testing:testing",
+ # "//tensorflow/python/autograph/pyct/static_analysis:static_analysis",
+ "//tensorflow/python/autograph/pyct/common_transformers:common_transformers",
"//tensorflow/python:cond_v2",
"//tensorflow/python:distributed_framework_test_lib",
"//tensorflow/python:meta_graph_testdata",
diff --git a/tensorflow/tools/pip_package/build_pip_package.sh b/tensorflow/tools/pip_package/build_pip_package.sh
index 666ea75d46..c62271c5cb 100755
--- a/tensorflow/tools/pip_package/build_pip_package.sh
+++ b/tensorflow/tools/pip_package/build_pip_package.sh
@@ -43,8 +43,7 @@ function cp_external() {
PLATFORM="$(uname -s | tr 'A-Z' 'a-z')"
function is_windows() {
- # On windows, the shell script is actually running in msys
- if [[ "${PLATFORM}" =~ (mingw64|msys)_nt* ]]; then
+ if [[ "${PLATFORM}" =~ (cygwin|mingw32|mingw64|msys)_nt* ]]; then
true
else
false
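The widened is_windows() check above now also recognizes Cygwin and MinGW32 environments. A tiny illustration of the same match against `uname -s` output; this is a sketch only, not part of build_pip_package.sh:

    import re
    import subprocess

    def is_windows():
        # Mirror the shell test: lower-case `uname -s` and look for the
        # MSYS/MinGW/Cygwin "*_NT-*" platform strings.
        platform = subprocess.check_output(["uname", "-s"]).decode().lower()
        return re.match(r"(cygwin|mingw32|mingw64|msys)_nt", platform) is not None

    print(is_windows())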
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 9a82c724b7..4c5aedba36 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -106,11 +106,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "com_google_absl",
urls = [
- "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/fb462224c058487763f263b7995d70efd0242c17.tar.gz",
- "https://github.com/abseil/abseil-cpp/archive/fb462224c058487763f263b7995d70efd0242c17.tar.gz",
+ "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/8ff1374008259719b54a8cb128ef951c02da164c.tar.gz",
+ "https://github.com/abseil/abseil-cpp/archive/8ff1374008259719b54a8cb128ef951c02da164c.tar.gz",
],
- sha256 = "f4f34f90083d5259f9a1a4067749d842599748d8ca03c1d9fe723124a7045c63",
- strip_prefix = "abseil-cpp-fb462224c058487763f263b7995d70efd0242c17",
+ sha256 = "006931f9705484041eed65189038f87931a87cff200bb296f94b3d42339c4cd9",
+ strip_prefix = "abseil-cpp-8ff1374008259719b54a8cb128ef951c02da164c",
build_file = clean_dep("//third_party:com_google_absl.BUILD"),
)
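The workspace.bzl change above moves com_google_absl to a newer commit, which means the sha256 and strip_prefix must be updated to match the new archive. A small sketch of how such a checksum can be computed before editing the rule; the helper below is not part of the build, and the URL is taken from the rule above:

    import hashlib
    import urllib.request

    def archive_sha256(url):
        # Stream the tarball and hash it; this is the digest Bazel compares
        # against the sha256 attribute of tf_http_archive.
        digest = hashlib.sha256()
        with urllib.request.urlopen(url) as response:
            for chunk in iter(lambda: response.read(1 << 20), b""):
                digest.update(chunk)
        return digest.hexdigest()

    commit = "8ff1374008259719b54a8cb128ef951c02da164c"
    url = "https://github.com/abseil/abseil-cpp/archive/%s.tar.gz" % commit
    print(archive_sha256(url))       # should match the sha256 in the rule
    print("abseil-cpp-" + commit)    # the matching strip_prefix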
diff --git a/third_party/flatbuffers/BUILD.bazel b/third_party/flatbuffers/BUILD.bazel
index 9d233a30d6..934c0d9650 100644
--- a/third_party/flatbuffers/BUILD.bazel
+++ b/third_party/flatbuffers/BUILD.bazel
@@ -142,6 +142,7 @@ filegroup(
srcs = [
"include/flatbuffers/base.h",
"include/flatbuffers/flatbuffers.h",
+ "include/flatbuffers/minireflect.h",
"include/flatbuffers/stl_emulation.h",
"include/flatbuffers/util.h",
],
diff --git a/third_party/flatbuffers/build_defs.bzl b/third_party/flatbuffers/build_defs.bzl
index 2f25156668..235b44f7cf 100644
--- a/third_party/flatbuffers/build_defs.bzl
+++ b/third_party/flatbuffers/build_defs.bzl
@@ -92,14 +92,17 @@ def flatbuffer_library_public(
cmd = reflection_genrule_cmd,
message = "Generating flatbuffer reflection binary for %s:" % (name),
)
- native.Fileset(
- name = reflection_name,
- out = "%s_out" % reflection_name,
- entries = [
- native.FilesetEntry(files = reflection_outs),
- ],
- visibility = reflection_visiblity,
- )
+ # TODO(b/114456773): Make bazel rules proper and supported by flatbuffer
+ # Have to comment this since FilesetEntry is not supported in bazel
+ # skylark.
+ # native.Fileset(
+ # name = reflection_name,
+ # out = "%s_out" % reflection_name,
+ # entries = [
+ # native.FilesetEntry(files = reflection_outs),
+ # ],
+ # visibility = reflection_visiblity,
+ # )
def flatbuffer_cc_library(
name,
diff --git a/third_party/jpeg/jpeg.BUILD b/third_party/jpeg/jpeg.BUILD
index 5edf4f8120..1b9b9bf2f5 100644
--- a/third_party/jpeg/jpeg.BUILD
+++ b/third_party/jpeg/jpeg.BUILD
@@ -11,8 +11,8 @@ libjpegturbo_nocopts = "-[W]error"
WIN_COPTS = [
"/Ox",
- "/w14711", # function 'function' selected for inline expansion
- "/w14710", # 'function' : function not inlined
+ "-DWITH_SIMD",
+ "-wd4996",
]
libjpegturbo_copts = select({
@@ -127,6 +127,7 @@ cc_library(
":armeabi-v7a": [":simd_armv7a"],
":arm64-v8a": [":simd_armv8a"],
":linux_ppc64le": [":simd_altivec"],
+ ":windows": [":simd_win_x86_64"],
"//conditions:default": [":simd_none"],
}),
)
@@ -351,6 +352,140 @@ cc_library(
)
cc_library(
+ name = "simd_win_x86_64",
+ srcs = [
+ "jchuff.h",
+ "jconfig.h",
+ "jconfigint.h",
+ "jdct.h",
+ "jerror.h",
+ "jinclude.h",
+ "jmorecfg.h",
+ "jpegint.h",
+ "jpeglib.h",
+ "jsimd.h",
+ "jsimddct.h",
+ "simd/jsimd.h",
+ "simd/x86_64/jsimd.c",
+ "simd/x86_64/jccolor-avx2.obj",
+ "simd/x86_64/jccolor-sse2.obj",
+ "simd/x86_64/jcgray-avx2.obj",
+ "simd/x86_64/jcgray-sse2.obj",
+ "simd/x86_64/jchuff-sse2.obj",
+ "simd/x86_64/jcphuff-sse2.obj",
+ "simd/x86_64/jcsample-avx2.obj",
+ "simd/x86_64/jcsample-sse2.obj",
+ "simd/x86_64/jdcolor-avx2.obj",
+ "simd/x86_64/jdcolor-sse2.obj",
+ "simd/x86_64/jdmerge-avx2.obj",
+ "simd/x86_64/jdmerge-sse2.obj",
+ "simd/x86_64/jdsample-avx2.obj",
+ "simd/x86_64/jdsample-sse2.obj",
+ "simd/x86_64/jfdctflt-sse.obj",
+ "simd/x86_64/jfdctfst-sse2.obj",
+ "simd/x86_64/jfdctint-avx2.obj",
+ "simd/x86_64/jfdctint-sse2.obj",
+ "simd/x86_64/jidctflt-sse2.obj",
+ "simd/x86_64/jidctfst-sse2.obj",
+ "simd/x86_64/jidctint-avx2.obj",
+ "simd/x86_64/jidctint-sse2.obj",
+ "simd/x86_64/jidctred-sse2.obj",
+ "simd/x86_64/jquantf-sse2.obj",
+ "simd/x86_64/jquanti-avx2.obj",
+ "simd/x86_64/jquanti-sse2.obj",
+ "simd/x86_64/jsimdcpu.obj",
+ ],
+ copts = libjpegturbo_copts,
+)
+
+genrule(
+ name = "simd_win_x86_64_assemble",
+ srcs = [
+ "jconfig.h",
+ "jconfigint.h",
+ "simd/x86_64/jccolext-avx2.asm",
+ "simd/x86_64/jccolext-sse2.asm",
+ "simd/x86_64/jccolor-avx2.asm",
+ "simd/x86_64/jccolor-sse2.asm",
+ "simd/x86_64/jcgray-avx2.asm",
+ "simd/x86_64/jcgray-sse2.asm",
+ "simd/x86_64/jcgryext-avx2.asm",
+ "simd/x86_64/jcgryext-sse2.asm",
+ "simd/x86_64/jchuff-sse2.asm",
+ "simd/x86_64/jcphuff-sse2.asm",
+ "simd/x86_64/jcsample-avx2.asm",
+ "simd/x86_64/jcsample-sse2.asm",
+ "simd/x86_64/jdcolext-avx2.asm",
+ "simd/x86_64/jdcolext-sse2.asm",
+ "simd/x86_64/jdcolor-avx2.asm",
+ "simd/x86_64/jdcolor-sse2.asm",
+ "simd/x86_64/jdmerge-avx2.asm",
+ "simd/x86_64/jdmerge-sse2.asm",
+ "simd/x86_64/jdmrgext-avx2.asm",
+ "simd/x86_64/jdmrgext-sse2.asm",
+ "simd/x86_64/jdsample-avx2.asm",
+ "simd/x86_64/jdsample-sse2.asm",
+ "simd/x86_64/jfdctflt-sse.asm",
+ "simd/x86_64/jfdctfst-sse2.asm",
+ "simd/x86_64/jfdctint-avx2.asm",
+ "simd/x86_64/jfdctint-sse2.asm",
+ "simd/x86_64/jidctflt-sse2.asm",
+ "simd/x86_64/jidctfst-sse2.asm",
+ "simd/x86_64/jidctint-avx2.asm",
+ "simd/x86_64/jidctint-sse2.asm",
+ "simd/x86_64/jidctred-sse2.asm",
+ "simd/x86_64/jquantf-sse2.asm",
+ "simd/x86_64/jquanti-avx2.asm",
+ "simd/x86_64/jquanti-sse2.asm",
+ "simd/x86_64/jsimdcpu.asm",
+ "simd/nasm/jcolsamp.inc",
+ "simd/nasm/jdct.inc",
+ "simd/nasm/jpeg_nbits_table.inc",
+ "simd/nasm/jsimdcfg.inc",
+ "simd/nasm/jsimdcfg.inc.h",
+ "simd/nasm/jsimdext.inc",
+ ],
+ outs = [
+ "simd/x86_64/jccolor-avx2.obj",
+ "simd/x86_64/jccolor-sse2.obj",
+ "simd/x86_64/jcgray-avx2.obj",
+ "simd/x86_64/jcgray-sse2.obj",
+ "simd/x86_64/jchuff-sse2.obj",
+ "simd/x86_64/jcphuff-sse2.obj",
+ "simd/x86_64/jcsample-avx2.obj",
+ "simd/x86_64/jcsample-sse2.obj",
+ "simd/x86_64/jdcolor-avx2.obj",
+ "simd/x86_64/jdcolor-sse2.obj",
+ "simd/x86_64/jdmerge-avx2.obj",
+ "simd/x86_64/jdmerge-sse2.obj",
+ "simd/x86_64/jdsample-avx2.obj",
+ "simd/x86_64/jdsample-sse2.obj",
+ "simd/x86_64/jfdctflt-sse.obj",
+ "simd/x86_64/jfdctfst-sse2.obj",
+ "simd/x86_64/jfdctint-avx2.obj",
+ "simd/x86_64/jfdctint-sse2.obj",
+ "simd/x86_64/jidctflt-sse2.obj",
+ "simd/x86_64/jidctfst-sse2.obj",
+ "simd/x86_64/jidctint-avx2.obj",
+ "simd/x86_64/jidctint-sse2.obj",
+ "simd/x86_64/jidctred-sse2.obj",
+ "simd/x86_64/jquantf-sse2.obj",
+ "simd/x86_64/jquanti-avx2.obj",
+ "simd/x86_64/jquanti-sse2.obj",
+ "simd/x86_64/jsimdcpu.obj",
+ ],
+ cmd = "for out in $(OUTS); do\n" +
+ " $(location @nasm//:nasm) -fwin64 -DWIN64 -D__x86_64__" +
+ " -I $$(dirname $(location simd/x86_64/jccolext-sse2.asm))/" +
+ " -I $$(dirname $(location simd/nasm/jdct.inc))/" +
+ " -I $$(dirname $(location simd/nasm/jdct.inc))/../../win/" +
+ " -o $$out" +
+ " $$(dirname $(location simd/x86_64/jccolext-sse2.asm))/$$(basename $${out%.obj}.asm)\n" +
+ "done",
+ tools = ["@nasm"],
+)
+
+cc_library(
name = "simd_none",
srcs = [
"jchuff.h",
diff --git a/third_party/llvm/llvm.bzl b/third_party/llvm/llvm.bzl
index d493a3c476..54ca86f327 100644
--- a/third_party/llvm/llvm.bzl
+++ b/third_party/llvm/llvm.bzl
@@ -150,6 +150,35 @@ def expand_cmake_vars(name, src, dst, cmake_vars):
# The set of CMake variables common to all targets.
cmake_vars = {
+ # LLVM features
+ "ENABLE_BACKTRACES": 1,
+ "LLVM_BINDIR": "/dev/null",
+ "LLVM_DISABLE_ABI_BREAKING_CHECKS_ENFORCING": 0,
+ "LLVM_ENABLE_ABI_BREAKING_CHECKS": 0,
+ "LLVM_ENABLE_THREADS": 1,
+ "LLVM_ENABLE_ZLIB": 1,
+ "LLVM_HAS_ATOMICS": 1,
+ "LLVM_INCLUDEDIR": "/dev/null",
+ "LLVM_INFODIR": "/dev/null",
+ "LLVM_MANDIR": "/dev/null",
+ "LLVM_NATIVE_TARGET": 1,
+ "LLVM_NATIVE_TARGETINFO": 1,
+ "LLVM_NATIVE_TARGETMC": 1,
+ "LLVM_NATIVE_ASMPRINTER": 1,
+ "LLVM_NATIVE_ASMPARSER": 1,
+ "LLVM_NATIVE_DISASSEMBLER": 1,
+ "LLVM_PREFIX": "/dev/null",
+ "LLVM_VERSION_MAJOR": 0,
+ "LLVM_VERSION_MINOR": 0,
+ "LLVM_VERSION_PATCH": 0,
+ "PACKAGE_NAME": "llvm",
+ "PACKAGE_STRING": "llvm tensorflow-trunk",
+ "PACKAGE_VERSION": "tensorflow-trunk",
+ "RETSIGTYPE": "void",
+}
+
+# The set of CMake variables common to POSIX targets.
+posix_cmake_vars = {
# Headers
"HAVE_DIRENT_H": 1,
"HAVE_DLFCN_H": 1,
@@ -206,32 +235,8 @@ cmake_vars = {
"HAVE__UNWIND_BACKTRACE": 1,
# LLVM features
- "ENABLE_BACKTRACES": 1,
- "LLVM_BINDIR": "/dev/null",
- "LLVM_DISABLE_ABI_BREAKING_CHECKS_ENFORCING": 0,
- "LLVM_ENABLE_ABI_BREAKING_CHECKS": 0,
- "LLVM_ENABLE_THREADS": 1,
- "LLVM_ENABLE_ZLIB": 1,
- "LLVM_HAS_ATOMICS": 1,
- "LLVM_INCLUDEDIR": "/dev/null",
- "LLVM_INFODIR": "/dev/null",
- "LLVM_MANDIR": "/dev/null",
- "LLVM_NATIVE_TARGET": 1,
- "LLVM_NATIVE_TARGETINFO": 1,
- "LLVM_NATIVE_TARGETMC": 1,
- "LLVM_NATIVE_ASMPRINTER": 1,
- "LLVM_NATIVE_ASMPARSER": 1,
- "LLVM_NATIVE_DISASSEMBLER": 1,
"LLVM_ON_UNIX": 1,
- "LLVM_PREFIX": "/dev/null",
- "LLVM_VERSION_MAJOR": 0,
- "LLVM_VERSION_MINOR": 0,
- "LLVM_VERSION_PATCH": 0,
"LTDL_SHLIB_EXT": ".so",
- "PACKAGE_NAME": "llvm",
- "PACKAGE_STRING": "llvm tensorflow-trunk",
- "PACKAGE_VERSION": "tensorflow-trunk",
- "RETSIGTYPE": "void",
}
# CMake variables specific to the Linux platform
@@ -247,6 +252,40 @@ darwin_cmake_vars = {
"HAVE_MALLOC_MALLOC_H": 1,
}
+# CMake variables specific to the Windows platform.
+win32_cmake_vars = {
+ # Headers
+ "HAVE_ERRNO_H": 1,
+ "HAVE_EXECINFO_H": 1,
+ "HAVE_FCNTL_H": 1,
+ "HAVE_FENV_H": 1,
+ "HAVE_INTTYPES_H": 1,
+ "HAVE_MALLOC_H": 1,
+ "HAVE_SIGNAL_H": 1,
+ "HAVE_STDINT_H": 1,
+ "HAVE_SYS_STAT_H": 1,
+ "HAVE_SYS_TYPES_H": 1,
+ "HAVE_ZLIB_H": 1,
+
+ # Features
+ "BACKTRACE_HEADER": "execinfo.h",
+ "HAVE_GETCWD": 1,
+ "HAVE_INT64_T": 1,
+ "HAVE_STRERROR": 1,
+ "HAVE_STRTOLL": 1,
+ "HAVE_SYSCONF": 1,
+ "HAVE_UINT64_T": 1,
+ "HAVE__CHSIZE_S": 1,
+ "HAVE___CHKSTK": 1,
+
+ # MSVC specific
+ "stricmp": "_stricmp",
+ "strdup": "_strdup",
+
+ # LLVM features
+ "LTDL_SHLIB_EXT": ".dll",
+}
+
# Select a set of CMake variables based on the platform.
# TODO(phawkins): use a better method to select the right host triple, rather
# than hardcoding x86_64.
@@ -255,6 +294,7 @@ llvm_all_cmake_vars = select({
_dict_add(
cmake_vars,
llvm_target_cmake_vars("X86", "x86_64-apple-darwin"),
+ posix_cmake_vars,
darwin_cmake_vars,
),
),
@@ -262,35 +302,111 @@ llvm_all_cmake_vars = select({
_dict_add(
cmake_vars,
llvm_target_cmake_vars("PowerPC", "powerpc64le-unknown-linux_gnu"),
+ posix_cmake_vars,
linux_cmake_vars,
),
),
+ "@org_tensorflow//tensorflow:windows": cmake_var_string(
+ _dict_add(
+ cmake_vars,
+ llvm_target_cmake_vars("X86", "x86_64-pc-win32"),
+ win32_cmake_vars,
+ ),
+ ),
"//conditions:default": cmake_var_string(
_dict_add(
cmake_vars,
llvm_target_cmake_vars("X86", "x86_64-unknown-linux_gnu"),
+ posix_cmake_vars,
linux_cmake_vars,
),
),
})
-llvm_linkopts = ["-ldl", "-lm", "-lpthread"]
+llvm_linkopts = select({
+ "@org_tensorflow//tensorflow:windows": [],
+ "//conditions:default": ["-ldl", "-lm", "-lpthread"],
+})
-llvm_defines = [
+llvm_defines = select({
+ "@org_tensorflow//tensorflow:windows": [
+ "_CRT_SECURE_NO_DEPRECATE",
+ "_CRT_SECURE_NO_WARNINGS",
+ "_CRT_NONSTDC_NO_DEPRECATE",
+ "_CRT_NONSTDC_NO_WARNINGS",
+ "_SCL_SECURE_NO_DEPRECATE",
+ "_SCL_SECURE_NO_WARNINGS",
+ "UNICODE",
+ "_UNICODE",
+ ],
+ "//conditions:default": ["_DEBUG"],
+}) + [
"LLVM_ENABLE_STATS",
"__STDC_LIMIT_MACROS",
"__STDC_CONSTANT_MACROS",
"__STDC_FORMAT_MACROS",
- "_DEBUG",
"LLVM_BUILD_GLOBAL_ISEL",
]
-llvm_copts = []
+llvm_copts = select({
+ "@org_tensorflow//tensorflow:windows": [
+ "-Zc:inline",
+ "-Zc:strictStrings",
+ "-Zc:rvalueCast",
+ "-Oi",
+ "-wd4141",
+ "-wd4146",
+ "-wd4180",
+ "-wd4244",
+ "-wd4258",
+ "-wd4267",
+ "-wd4291",
+ "-wd4345",
+ "-wd4351",
+ "-wd4355",
+ "-wd4456",
+ "-wd4457",
+ "-wd4458",
+ "-wd4459",
+ "-wd4503",
+ "-wd4624",
+ "-wd4722",
+ "-wd4800",
+ "-wd4100",
+ "-wd4127",
+ "-wd4512",
+ "-wd4505",
+ "-wd4610",
+ "-wd4510",
+ "-wd4702",
+ "-wd4245",
+ "-wd4706",
+ "-wd4310",
+ "-wd4701",
+ "-wd4703",
+ "-wd4389",
+ "-wd4611",
+ "-wd4805",
+ "-wd4204",
+ "-wd4577",
+ "-wd4091",
+ "-wd4592",
+ "-wd4319",
+ "-wd4324",
+ "-w14062",
+ "-we4238",
+ ],
+ "//conditions:default": [],
+})
# Platform specific sources for libSupport.
def llvm_support_platform_specific_srcs_glob():
return select({
+ "@org_tensorflow//tensorflow:windows": native.glob([
+ "lib/Support/Windows/*.inc",
+ "lib/Support/Windows/*.h",
+ ]),
"//conditions:default": native.glob([
"lib/Support/Unix/*.inc",
"lib/Support/Unix/*.h",
diff --git a/third_party/nasm.BUILD b/third_party/nasm.BUILD
index 2b877883b9..d746a65e7e 100644
--- a/third_party/nasm.BUILD
+++ b/third_party/nasm.BUILD
@@ -133,7 +133,10 @@ cc_binary(
"x86/regs.c",
"x86/regs.h",
"x86/regvals.c",
- ],
+ ] + select({
+ ":windows": ["config/msvc.h"],
+ "//conditions:default": [],
+ }),
includes = [
"asm",
"include",
diff --git a/third_party/toolchains/BUILD b/third_party/toolchains/BUILD
index ec1006fe23..4303751452 100644
--- a/third_party/toolchains/BUILD
+++ b/third_party/toolchains/BUILD
@@ -20,3 +20,18 @@ platform(
value:"docker://gcr.io/asci-toolchain/nosla-ubuntu16_04-tf@sha256:495a025ed5e273cfa5d53357ef93ac20500c008994e0be106c509f51555fb93c"
}""",
)
+
+platform(
+ name = "rbe_cuda9.0-cudnn7-ubuntu14.04",
+ constraint_values = [
+ "@bazel_tools//platforms:x86_64",
+ "@bazel_tools//platforms:linux",
+ "@bazel_tools//tools/cpp:clang",
+ "@bazel_toolchains//constraints:xenial",
+ ],
+ remote_execution_properties = """
+ properties: {
+ name: "container-image"
+ value:"docker://gcr.io/asci-toolchain/nosla-cuda9.0-cudnn7-ubuntu14.04@sha256:ae58329b961e7c17d89725bf8fd72dfbd5850f4f3313de58e0cafbf5b0343735"
+ }""",
+)
diff --git a/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/WORKSPACE b/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/WORKSPACE
new file mode 100644
index 0000000000..b61f572d6d
--- /dev/null
+++ b/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/WORKSPACE
@@ -0,0 +1,2 @@
+# DO NOT EDIT: automatically generated WORKSPACE file for cuda_configure rule
+workspace(name = "local_config_cuda")
diff --git a/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/BUILD b/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/BUILD
new file mode 100755
index 0000000000..2d3e41127d
--- /dev/null
+++ b/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/BUILD
@@ -0,0 +1,1268 @@
+licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like
+
+package(default_visibility = ["//visibility:public"])
+
+config_setting(
+ name = "using_nvcc",
+ values = {
+ "define": "using_cuda_nvcc=true",
+ },
+)
+
+config_setting(
+ name = "using_clang",
+ values = {
+ "define": "using_cuda_clang=true",
+ },
+)
+
+# Equivalent to using_clang && -c opt.
+config_setting(
+ name = "using_clang_opt",
+ values = {
+ "define": "using_cuda_clang=true",
+ "compilation_mode": "opt",
+ },
+)
+
+config_setting(
+ name = "darwin",
+ values = {"cpu": "darwin"},
+ visibility = ["//visibility:public"],
+)
+
+config_setting(
+ name = "freebsd",
+ values = {"cpu": "freebsd"},
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cuda_headers",
+ hdrs = [
+ "cuda/cuda_config.h",
+ ":cuda-include",
+ ":cudnn-include",
+ ],
+ includes = [
+ ".",
+ "cuda/include",
+ "cuda/include/crt",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cudart_static",
+ srcs = ["cuda/lib/libcudart_static.a"],
+ includes = [
+ ".",
+ "cuda/include",
+ ],
+ linkopts = select({
+ ":freebsd": [],
+ "//conditions:default": ["-ldl"],
+ }) + [
+ "-lpthread",
+ "-lrt",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cuda_driver",
+ srcs = ["cuda/lib/libcuda.so"],
+ includes = [
+ ".",
+ "cuda/include",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cudart",
+ srcs = ["cuda/lib/libcudart.so.9.0"],
+ data = ["cuda/lib/libcudart.so.9.0"],
+ includes = [
+ ".",
+ "cuda/include",
+ ],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cublas",
+ srcs = ["cuda/lib/libcublas.so.9.0"],
+ data = ["cuda/lib/libcublas.so.9.0"],
+ includes = [
+ ".",
+ "cuda/include",
+ ],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cusolver",
+ srcs = ["cuda/lib/libcusolver.so.9.0"],
+ data = ["cuda/lib/libcusolver.so.9.0"],
+ includes = [
+ ".",
+ "cuda/include",
+ ],
+ linkopts = ["-lgomp"],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cudnn",
+ srcs = ["cuda/lib/libcudnn.so.7"],
+ data = ["cuda/lib/libcudnn.so.7"],
+ includes = [
+ ".",
+ "cuda/include",
+ ],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cudnn_header",
+ includes = [
+ ".",
+ "cuda/include",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cufft",
+ srcs = ["cuda/lib/libcufft.so.9.0"],
+ data = ["cuda/lib/libcufft.so.9.0"],
+ includes = [
+ ".",
+ "cuda/include",
+ ],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "curand",
+ srcs = ["cuda/lib/libcurand.so.9.0"],
+ data = ["cuda/lib/libcurand.so.9.0"],
+ includes = [
+ ".",
+ "cuda/include",
+ ],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cuda",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":cublas",
+ ":cuda_headers",
+ ":cudart",
+ ":cudnn",
+ ":cufft",
+ ":curand",
+ ],
+)
+
+cc_library(
+ name = "cupti_headers",
+ hdrs = [
+ "cuda/cuda_config.h",
+ ":cuda-extras",
+ ],
+ includes = [
+ ".",
+ "cuda/extras/CUPTI/include/",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cupti_dsos",
+ data = ["cuda/lib/libcupti.so.9.0"],
+ includes = [
+ ".",
+ "cuda/include",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "libdevice_root",
+ data = [":cuda-nvvm"],
+ visibility = ["//visibility:public"],
+)
+
+genrule(
+ name = "cuda-include",
+ outs = [
+ "cuda/include/CL/cl.h",
+ "cuda/include/CL/cl.hpp",
+ "cuda/include/CL/cl_egl.h",
+ "cuda/include/CL/cl_ext.h",
+ "cuda/include/CL/cl_gl.h",
+ "cuda/include/CL/cl_gl_ext.h",
+ "cuda/include/CL/cl_platform.h",
+ "cuda/include/CL/opencl.h",
+ "cuda/include/builtin_types.h",
+ "cuda/include/channel_descriptor.h",
+ "cuda/include/common_functions.h",
+ "cuda/include/cooperative_groups.h",
+ "cuda/include/cooperative_groups_helpers.h",
+ "cuda/include/crt/common_functions.h",
+ "cuda/include/crt/device_double_functions.h",
+ "cuda/include/crt/device_double_functions.hpp",
+ "cuda/include/crt/device_functions.h",
+ "cuda/include/crt/device_functions.hpp",
+ "cuda/include/crt/func_macro.h",
+ "cuda/include/crt/host_config.h",
+ "cuda/include/crt/host_defines.h",
+ "cuda/include/crt/host_runtime.h",
+ "cuda/include/crt/math_functions.h",
+ "cuda/include/crt/math_functions.hpp",
+ "cuda/include/crt/mma.h",
+ "cuda/include/crt/mma.hpp",
+ "cuda/include/crt/nvfunctional",
+ "cuda/include/crt/sm_70_rt.h",
+ "cuda/include/crt/sm_70_rt.hpp",
+ "cuda/include/crt/storage_class.h",
+ "cuda/include/cuComplex.h",
+ "cuda/include/cublas.h",
+ "cuda/include/cublasXt.h",
+ "cuda/include/cublas_api.h",
+ "cuda/include/cublas_v2.h",
+ "cuda/include/cuda.h",
+ "cuda/include/cudaEGL.h",
+ "cuda/include/cudaGL.h",
+ "cuda/include/cudaProfiler.h",
+ "cuda/include/cudaVDPAU.h",
+ "cuda/include/cuda_device_runtime_api.h",
+ "cuda/include/cuda_fp16.h",
+ "cuda/include/cuda_fp16.hpp",
+ "cuda/include/cuda_gl_interop.h",
+ "cuda/include/cuda_occupancy.h",
+ "cuda/include/cuda_profiler_api.h",
+ "cuda/include/cuda_runtime.h",
+ "cuda/include/cuda_runtime_api.h",
+ "cuda/include/cuda_surface_types.h",
+ "cuda/include/cuda_texture_types.h",
+ "cuda/include/cuda_vdpau_interop.h",
+ "cuda/include/cudalibxt.h",
+ "cuda/include/cufft.h",
+ "cuda/include/cufftXt.h",
+ "cuda/include/cufftw.h",
+ "cuda/include/curand.h",
+ "cuda/include/curand_discrete.h",
+ "cuda/include/curand_discrete2.h",
+ "cuda/include/curand_globals.h",
+ "cuda/include/curand_kernel.h",
+ "cuda/include/curand_lognormal.h",
+ "cuda/include/curand_mrg32k3a.h",
+ "cuda/include/curand_mtgp32.h",
+ "cuda/include/curand_mtgp32_host.h",
+ "cuda/include/curand_mtgp32_kernel.h",
+ "cuda/include/curand_mtgp32dc_p_11213.h",
+ "cuda/include/curand_normal.h",
+ "cuda/include/curand_normal_static.h",
+ "cuda/include/curand_philox4x32_x.h",
+ "cuda/include/curand_poisson.h",
+ "cuda/include/curand_precalc.h",
+ "cuda/include/curand_uniform.h",
+ "cuda/include/cusolverDn.h",
+ "cuda/include/cusolverRf.h",
+ "cuda/include/cusolverSp.h",
+ "cuda/include/cusolverSp_LOWLEVEL_PREVIEW.h",
+ "cuda/include/cusolver_common.h",
+ "cuda/include/cusparse.h",
+ "cuda/include/cusparse_v2.h",
+ "cuda/include/device_atomic_functions.h",
+ "cuda/include/device_atomic_functions.hpp",
+ "cuda/include/device_double_functions.h",
+ "cuda/include/device_double_functions.hpp",
+ "cuda/include/device_functions.h",
+ "cuda/include/device_functions.hpp",
+ "cuda/include/device_functions_decls.h",
+ "cuda/include/device_launch_parameters.h",
+ "cuda/include/device_types.h",
+ "cuda/include/driver_functions.h",
+ "cuda/include/driver_types.h",
+ "cuda/include/dynlink_cuda.h",
+ "cuda/include/dynlink_cuda_cuda.h",
+ "cuda/include/dynlink_cuviddec.h",
+ "cuda/include/dynlink_nvcuvid.h",
+ "cuda/include/fatBinaryCtl.h",
+ "cuda/include/fatbinary.h",
+ "cuda/include/host_config.h",
+ "cuda/include/host_defines.h",
+ "cuda/include/library_types.h",
+ "cuda/include/math_constants.h",
+ "cuda/include/math_functions.h",
+ "cuda/include/math_functions.hpp",
+ "cuda/include/math_functions_dbl_ptx3.h",
+ "cuda/include/math_functions_dbl_ptx3.hpp",
+ "cuda/include/mma.h",
+ "cuda/include/npp.h",
+ "cuda/include/nppcore.h",
+ "cuda/include/nppdefs.h",
+ "cuda/include/nppi.h",
+ "cuda/include/nppi_arithmetic_and_logical_operations.h",
+ "cuda/include/nppi_color_conversion.h",
+ "cuda/include/nppi_compression_functions.h",
+ "cuda/include/nppi_computer_vision.h",
+ "cuda/include/nppi_data_exchange_and_initialization.h",
+ "cuda/include/nppi_filtering_functions.h",
+ "cuda/include/nppi_geometry_transforms.h",
+ "cuda/include/nppi_linear_transforms.h",
+ "cuda/include/nppi_morphological_operations.h",
+ "cuda/include/nppi_statistics_functions.h",
+ "cuda/include/nppi_support_functions.h",
+ "cuda/include/nppi_threshold_and_compare_operations.h",
+ "cuda/include/npps.h",
+ "cuda/include/npps_arithmetic_and_logical_operations.h",
+ "cuda/include/npps_conversion_functions.h",
+ "cuda/include/npps_filtering_functions.h",
+ "cuda/include/npps_initialization.h",
+ "cuda/include/npps_statistics_functions.h",
+ "cuda/include/npps_support_functions.h",
+ "cuda/include/nppversion.h",
+ "cuda/include/nvToolsExt.h",
+ "cuda/include/nvToolsExtCuda.h",
+ "cuda/include/nvToolsExtCudaRt.h",
+ "cuda/include/nvToolsExtMeta.h",
+ "cuda/include/nvToolsExtSync.h",
+ "cuda/include/nvblas.h",
+ "cuda/include/nvfunctional",
+ "cuda/include/nvgraph.h",
+ "cuda/include/nvml.h",
+ "cuda/include/nvrtc.h",
+ "cuda/include/sm_20_atomic_functions.h",
+ "cuda/include/sm_20_atomic_functions.hpp",
+ "cuda/include/sm_20_intrinsics.h",
+ "cuda/include/sm_20_intrinsics.hpp",
+ "cuda/include/sm_30_intrinsics.h",
+ "cuda/include/sm_30_intrinsics.hpp",
+ "cuda/include/sm_32_atomic_functions.h",
+ "cuda/include/sm_32_atomic_functions.hpp",
+ "cuda/include/sm_32_intrinsics.h",
+ "cuda/include/sm_32_intrinsics.hpp",
+ "cuda/include/sm_35_atomic_functions.h",
+ "cuda/include/sm_35_intrinsics.h",
+ "cuda/include/sm_60_atomic_functions.h",
+ "cuda/include/sm_60_atomic_functions.hpp",
+ "cuda/include/sm_61_intrinsics.h",
+ "cuda/include/sm_61_intrinsics.hpp",
+ "cuda/include/sobol_direction_vectors.h",
+ "cuda/include/surface_functions.h",
+ "cuda/include/surface_functions.hpp",
+ "cuda/include/surface_indirect_functions.h",
+ "cuda/include/surface_indirect_functions.hpp",
+ "cuda/include/surface_types.h",
+ "cuda/include/texture_fetch_functions.h",
+ "cuda/include/texture_fetch_functions.hpp",
+ "cuda/include/texture_indirect_functions.h",
+ "cuda/include/texture_indirect_functions.hpp",
+ "cuda/include/texture_types.h",
+ "cuda/include/thrust/adjacent_difference.h",
+ "cuda/include/thrust/advance.h",
+ "cuda/include/thrust/binary_search.h",
+ "cuda/include/thrust/complex.h",
+ "cuda/include/thrust/copy.h",
+ "cuda/include/thrust/count.h",
+ "cuda/include/thrust/detail/adjacent_difference.inl",
+ "cuda/include/thrust/detail/advance.inl",
+ "cuda/include/thrust/detail/allocator/allocator_traits.h",
+ "cuda/include/thrust/detail/allocator/allocator_traits.inl",
+ "cuda/include/thrust/detail/allocator/copy_construct_range.h",
+ "cuda/include/thrust/detail/allocator/copy_construct_range.inl",
+ "cuda/include/thrust/detail/allocator/default_construct_range.h",
+ "cuda/include/thrust/detail/allocator/default_construct_range.inl",
+ "cuda/include/thrust/detail/allocator/destroy_range.h",
+ "cuda/include/thrust/detail/allocator/destroy_range.inl",
+ "cuda/include/thrust/detail/allocator/fill_construct_range.h",
+ "cuda/include/thrust/detail/allocator/fill_construct_range.inl",
+ "cuda/include/thrust/detail/allocator/malloc_allocator.h",
+ "cuda/include/thrust/detail/allocator/malloc_allocator.inl",
+ "cuda/include/thrust/detail/allocator/no_throw_allocator.h",
+ "cuda/include/thrust/detail/allocator/tagged_allocator.h",
+ "cuda/include/thrust/detail/allocator/tagged_allocator.inl",
+ "cuda/include/thrust/detail/allocator/temporary_allocator.h",
+ "cuda/include/thrust/detail/allocator/temporary_allocator.inl",
+ "cuda/include/thrust/detail/binary_search.inl",
+ "cuda/include/thrust/detail/complex/arithmetic.h",
+ "cuda/include/thrust/detail/complex/c99math.h",
+ "cuda/include/thrust/detail/complex/catrig.h",
+ "cuda/include/thrust/detail/complex/catrigf.h",
+ "cuda/include/thrust/detail/complex/ccosh.h",
+ "cuda/include/thrust/detail/complex/ccoshf.h",
+ "cuda/include/thrust/detail/complex/cexp.h",
+ "cuda/include/thrust/detail/complex/cexpf.h",
+ "cuda/include/thrust/detail/complex/clog.h",
+ "cuda/include/thrust/detail/complex/clogf.h",
+ "cuda/include/thrust/detail/complex/complex.inl",
+ "cuda/include/thrust/detail/complex/cpow.h",
+ "cuda/include/thrust/detail/complex/cpowf.h",
+ "cuda/include/thrust/detail/complex/cproj.h",
+ "cuda/include/thrust/detail/complex/csinh.h",
+ "cuda/include/thrust/detail/complex/csinhf.h",
+ "cuda/include/thrust/detail/complex/csqrt.h",
+ "cuda/include/thrust/detail/complex/csqrtf.h",
+ "cuda/include/thrust/detail/complex/ctanh.h",
+ "cuda/include/thrust/detail/complex/ctanhf.h",
+ "cuda/include/thrust/detail/complex/math_private.h",
+ "cuda/include/thrust/detail/complex/stream.h",
+ "cuda/include/thrust/detail/config.h",
+ "cuda/include/thrust/detail/config/compiler.h",
+ "cuda/include/thrust/detail/config/compiler_fence.h",
+ "cuda/include/thrust/detail/config/config.h",
+ "cuda/include/thrust/detail/config/debug.h",
+ "cuda/include/thrust/detail/config/device_system.h",
+ "cuda/include/thrust/detail/config/exec_check_disable.h",
+ "cuda/include/thrust/detail/config/forceinline.h",
+ "cuda/include/thrust/detail/config/global_workarounds.h",
+ "cuda/include/thrust/detail/config/host_device.h",
+ "cuda/include/thrust/detail/config/host_system.h",
+ "cuda/include/thrust/detail/config/simple_defines.h",
+ "cuda/include/thrust/detail/contiguous_storage.h",
+ "cuda/include/thrust/detail/contiguous_storage.inl",
+ "cuda/include/thrust/detail/copy.h",
+ "cuda/include/thrust/detail/copy.inl",
+ "cuda/include/thrust/detail/copy_if.h",
+ "cuda/include/thrust/detail/copy_if.inl",
+ "cuda/include/thrust/detail/count.inl",
+ "cuda/include/thrust/detail/cstdint.h",
+ "cuda/include/thrust/detail/device_delete.inl",
+ "cuda/include/thrust/detail/device_free.inl",
+ "cuda/include/thrust/detail/device_malloc.inl",
+ "cuda/include/thrust/detail/device_new.inl",
+ "cuda/include/thrust/detail/device_ptr.inl",
+ "cuda/include/thrust/detail/device_reference.inl",
+ "cuda/include/thrust/detail/device_vector.inl",
+ "cuda/include/thrust/detail/dispatch/is_trivial_copy.h",
+ "cuda/include/thrust/detail/distance.inl",
+ "cuda/include/thrust/detail/equal.inl",
+ "cuda/include/thrust/detail/execute_with_allocator.h",
+ "cuda/include/thrust/detail/execution_policy.h",
+ "cuda/include/thrust/detail/extrema.inl",
+ "cuda/include/thrust/detail/fill.inl",
+ "cuda/include/thrust/detail/find.inl",
+ "cuda/include/thrust/detail/for_each.inl",
+ "cuda/include/thrust/detail/function.h",
+ "cuda/include/thrust/detail/functional.inl",
+ "cuda/include/thrust/detail/functional/actor.h",
+ "cuda/include/thrust/detail/functional/actor.inl",
+ "cuda/include/thrust/detail/functional/argument.h",
+ "cuda/include/thrust/detail/functional/composite.h",
+ "cuda/include/thrust/detail/functional/operators.h",
+ "cuda/include/thrust/detail/functional/operators/arithmetic_operators.h",
+ "cuda/include/thrust/detail/functional/operators/assignment_operator.h",
+ "cuda/include/thrust/detail/functional/operators/bitwise_operators.h",
+ "cuda/include/thrust/detail/functional/operators/compound_assignment_operators.h",
+ "cuda/include/thrust/detail/functional/operators/logical_operators.h",
+ "cuda/include/thrust/detail/functional/operators/operator_adaptors.h",
+ "cuda/include/thrust/detail/functional/operators/relational_operators.h",
+ "cuda/include/thrust/detail/functional/placeholder.h",
+ "cuda/include/thrust/detail/functional/value.h",
+ "cuda/include/thrust/detail/gather.inl",
+ "cuda/include/thrust/detail/generate.inl",
+ "cuda/include/thrust/detail/get_iterator_value.h",
+ "cuda/include/thrust/detail/host_vector.inl",
+ "cuda/include/thrust/detail/inner_product.inl",
+ "cuda/include/thrust/detail/integer_math.h",
+ "cuda/include/thrust/detail/integer_traits.h",
+ "cuda/include/thrust/detail/internal_functional.h",
+ "cuda/include/thrust/detail/logical.inl",
+ "cuda/include/thrust/detail/malloc_and_free.h",
+ "cuda/include/thrust/detail/merge.inl",
+ "cuda/include/thrust/detail/minmax.h",
+ "cuda/include/thrust/detail/mismatch.inl",
+ "cuda/include/thrust/detail/mpl/math.h",
+ "cuda/include/thrust/detail/numeric_traits.h",
+ "cuda/include/thrust/detail/overlapped_copy.h",
+ "cuda/include/thrust/detail/pair.inl",
+ "cuda/include/thrust/detail/partition.inl",
+ "cuda/include/thrust/detail/pointer.h",
+ "cuda/include/thrust/detail/pointer.inl",
+ "cuda/include/thrust/detail/range/head_flags.h",
+ "cuda/include/thrust/detail/range/tail_flags.h",
+ "cuda/include/thrust/detail/raw_pointer_cast.h",
+ "cuda/include/thrust/detail/raw_reference_cast.h",
+ "cuda/include/thrust/detail/reduce.inl",
+ "cuda/include/thrust/detail/reference.h",
+ "cuda/include/thrust/detail/reference.inl",
+ "cuda/include/thrust/detail/reference_forward_declaration.h",
+ "cuda/include/thrust/detail/remove.inl",
+ "cuda/include/thrust/detail/replace.inl",
+ "cuda/include/thrust/detail/reverse.inl",
+ "cuda/include/thrust/detail/scan.inl",
+ "cuda/include/thrust/detail/scatter.inl",
+ "cuda/include/thrust/detail/seq.h",
+ "cuda/include/thrust/detail/sequence.inl",
+ "cuda/include/thrust/detail/set_operations.inl",
+ "cuda/include/thrust/detail/sort.inl",
+ "cuda/include/thrust/detail/static_assert.h",
+ "cuda/include/thrust/detail/static_map.h",
+ "cuda/include/thrust/detail/swap.h",
+ "cuda/include/thrust/detail/swap.inl",
+ "cuda/include/thrust/detail/swap_ranges.inl",
+ "cuda/include/thrust/detail/tabulate.inl",
+ "cuda/include/thrust/detail/temporary_array.h",
+ "cuda/include/thrust/detail/temporary_array.inl",
+ "cuda/include/thrust/detail/temporary_buffer.h",
+ "cuda/include/thrust/detail/transform.inl",
+ "cuda/include/thrust/detail/transform_reduce.inl",
+ "cuda/include/thrust/detail/transform_scan.inl",
+ "cuda/include/thrust/detail/trivial_sequence.h",
+ "cuda/include/thrust/detail/tuple.inl",
+ "cuda/include/thrust/detail/tuple_meta_transform.h",
+ "cuda/include/thrust/detail/tuple_transform.h",
+ "cuda/include/thrust/detail/type_traits.h",
+ "cuda/include/thrust/detail/type_traits/algorithm/intermediate_type_from_function_and_iterators.h",
+ "cuda/include/thrust/detail/type_traits/function_traits.h",
+ "cuda/include/thrust/detail/type_traits/has_member_function.h",
+ "cuda/include/thrust/detail/type_traits/has_nested_type.h",
+ "cuda/include/thrust/detail/type_traits/has_trivial_assign.h",
+ "cuda/include/thrust/detail/type_traits/is_call_possible.h",
+ "cuda/include/thrust/detail/type_traits/is_metafunction_defined.h",
+ "cuda/include/thrust/detail/type_traits/iterator/is_discard_iterator.h",
+ "cuda/include/thrust/detail/type_traits/iterator/is_output_iterator.h",
+ "cuda/include/thrust/detail/type_traits/minimum_type.h",
+ "cuda/include/thrust/detail/type_traits/pointer_traits.h",
+ "cuda/include/thrust/detail/type_traits/result_of_adaptable_function.h",
+ "cuda/include/thrust/detail/uninitialized_copy.inl",
+ "cuda/include/thrust/detail/uninitialized_fill.inl",
+ "cuda/include/thrust/detail/unique.inl",
+ "cuda/include/thrust/detail/use_default.h",
+ "cuda/include/thrust/detail/util/align.h",
+ "cuda/include/thrust/detail/util/blocking.h",
+ "cuda/include/thrust/detail/vector_base.h",
+ "cuda/include/thrust/detail/vector_base.inl",
+ "cuda/include/thrust/device_allocator.h",
+ "cuda/include/thrust/device_delete.h",
+ "cuda/include/thrust/device_free.h",
+ "cuda/include/thrust/device_malloc.h",
+ "cuda/include/thrust/device_malloc_allocator.h",
+ "cuda/include/thrust/device_new.h",
+ "cuda/include/thrust/device_new_allocator.h",
+ "cuda/include/thrust/device_ptr.h",
+ "cuda/include/thrust/device_reference.h",
+ "cuda/include/thrust/device_vector.h",
+ "cuda/include/thrust/distance.h",
+ "cuda/include/thrust/equal.h",
+ "cuda/include/thrust/execution_policy.h",
+ "cuda/include/thrust/extrema.h",
+ "cuda/include/thrust/fill.h",
+ "cuda/include/thrust/find.h",
+ "cuda/include/thrust/for_each.h",
+ "cuda/include/thrust/functional.h",
+ "cuda/include/thrust/gather.h",
+ "cuda/include/thrust/generate.h",
+ "cuda/include/thrust/host_vector.h",
+ "cuda/include/thrust/inner_product.h",
+ "cuda/include/thrust/iterator/constant_iterator.h",
+ "cuda/include/thrust/iterator/counting_iterator.h",
+ "cuda/include/thrust/iterator/detail/any_assign.h",
+ "cuda/include/thrust/iterator/detail/any_system_tag.h",
+ "cuda/include/thrust/iterator/detail/constant_iterator_base.h",
+ "cuda/include/thrust/iterator/detail/counting_iterator.inl",
+ "cuda/include/thrust/iterator/detail/device_system_tag.h",
+ "cuda/include/thrust/iterator/detail/discard_iterator_base.h",
+ "cuda/include/thrust/iterator/detail/distance_from_result.h",
+ "cuda/include/thrust/iterator/detail/host_system_tag.h",
+ "cuda/include/thrust/iterator/detail/is_iterator_category.h",
+ "cuda/include/thrust/iterator/detail/is_trivial_iterator.h",
+ "cuda/include/thrust/iterator/detail/iterator_adaptor_base.h",
+ "cuda/include/thrust/iterator/detail/iterator_category_to_system.h",
+ "cuda/include/thrust/iterator/detail/iterator_category_to_traversal.h",
+ "cuda/include/thrust/iterator/detail/iterator_category_with_system_and_traversal.h",
+ "cuda/include/thrust/iterator/detail/iterator_facade_category.h",
+ "cuda/include/thrust/iterator/detail/iterator_traits.inl",
+ "cuda/include/thrust/iterator/detail/iterator_traversal_tags.h",
+ "cuda/include/thrust/iterator/detail/join_iterator.h",
+ "cuda/include/thrust/iterator/detail/minimum_category.h",
+ "cuda/include/thrust/iterator/detail/minimum_system.h",
+ "cuda/include/thrust/iterator/detail/normal_iterator.h",
+ "cuda/include/thrust/iterator/detail/permutation_iterator_base.h",
+ "cuda/include/thrust/iterator/detail/retag.h",
+ "cuda/include/thrust/iterator/detail/reverse_iterator.inl",
+ "cuda/include/thrust/iterator/detail/reverse_iterator_base.h",
+ "cuda/include/thrust/iterator/detail/tagged_iterator.h",
+ "cuda/include/thrust/iterator/detail/transform_iterator.inl",
+ "cuda/include/thrust/iterator/detail/transform_output_iterator.inl",
+ "cuda/include/thrust/iterator/detail/tuple_of_iterator_references.h",
+ "cuda/include/thrust/iterator/detail/universal_categories.h",
+ "cuda/include/thrust/iterator/detail/zip_iterator.inl",
+ "cuda/include/thrust/iterator/detail/zip_iterator_base.h",
+ "cuda/include/thrust/iterator/discard_iterator.h",
+ "cuda/include/thrust/iterator/iterator_adaptor.h",
+ "cuda/include/thrust/iterator/iterator_categories.h",
+ "cuda/include/thrust/iterator/iterator_facade.h",
+ "cuda/include/thrust/iterator/iterator_traits.h",
+ "cuda/include/thrust/iterator/permutation_iterator.h",
+ "cuda/include/thrust/iterator/retag.h",
+ "cuda/include/thrust/iterator/reverse_iterator.h",
+ "cuda/include/thrust/iterator/transform_iterator.h",
+ "cuda/include/thrust/iterator/transform_output_iterator.h",
+ "cuda/include/thrust/iterator/zip_iterator.h",
+ "cuda/include/thrust/logical.h",
+ "cuda/include/thrust/memory.h",
+ "cuda/include/thrust/merge.h",
+ "cuda/include/thrust/mismatch.h",
+ "cuda/include/thrust/pair.h",
+ "cuda/include/thrust/partition.h",
+ "cuda/include/thrust/random.h",
+ "cuda/include/thrust/random/detail/discard_block_engine.inl",
+ "cuda/include/thrust/random/detail/linear_congruential_engine.inl",
+ "cuda/include/thrust/random/detail/linear_congruential_engine_discard.h",
+ "cuda/include/thrust/random/detail/linear_feedback_shift_engine.inl",
+ "cuda/include/thrust/random/detail/linear_feedback_shift_engine_wordmask.h",
+ "cuda/include/thrust/random/detail/mod.h",
+ "cuda/include/thrust/random/detail/normal_distribution.inl",
+ "cuda/include/thrust/random/detail/normal_distribution_base.h",
+ "cuda/include/thrust/random/detail/random_core_access.h",
+ "cuda/include/thrust/random/detail/subtract_with_carry_engine.inl",
+ "cuda/include/thrust/random/detail/uniform_int_distribution.inl",
+ "cuda/include/thrust/random/detail/uniform_real_distribution.inl",
+ "cuda/include/thrust/random/detail/xor_combine_engine.inl",
+ "cuda/include/thrust/random/detail/xor_combine_engine_max.h",
+ "cuda/include/thrust/random/discard_block_engine.h",
+ "cuda/include/thrust/random/linear_congruential_engine.h",
+ "cuda/include/thrust/random/linear_feedback_shift_engine.h",
+ "cuda/include/thrust/random/normal_distribution.h",
+ "cuda/include/thrust/random/subtract_with_carry_engine.h",
+ "cuda/include/thrust/random/uniform_int_distribution.h",
+ "cuda/include/thrust/random/uniform_real_distribution.h",
+ "cuda/include/thrust/random/xor_combine_engine.h",
+ "cuda/include/thrust/reduce.h",
+ "cuda/include/thrust/remove.h",
+ "cuda/include/thrust/replace.h",
+ "cuda/include/thrust/reverse.h",
+ "cuda/include/thrust/scan.h",
+ "cuda/include/thrust/scatter.h",
+ "cuda/include/thrust/sequence.h",
+ "cuda/include/thrust/set_operations.h",
+ "cuda/include/thrust/sort.h",
+ "cuda/include/thrust/swap.h",
+ "cuda/include/thrust/system/cpp/detail/adjacent_difference.h",
+ "cuda/include/thrust/system/cpp/detail/assign_value.h",
+ "cuda/include/thrust/system/cpp/detail/binary_search.h",
+ "cuda/include/thrust/system/cpp/detail/copy.h",
+ "cuda/include/thrust/system/cpp/detail/copy_if.h",
+ "cuda/include/thrust/system/cpp/detail/count.h",
+ "cuda/include/thrust/system/cpp/detail/equal.h",
+ "cuda/include/thrust/system/cpp/detail/execution_policy.h",
+ "cuda/include/thrust/system/cpp/detail/extrema.h",
+ "cuda/include/thrust/system/cpp/detail/fill.h",
+ "cuda/include/thrust/system/cpp/detail/find.h",
+ "cuda/include/thrust/system/cpp/detail/for_each.h",
+ "cuda/include/thrust/system/cpp/detail/gather.h",
+ "cuda/include/thrust/system/cpp/detail/generate.h",
+ "cuda/include/thrust/system/cpp/detail/get_value.h",
+ "cuda/include/thrust/system/cpp/detail/inner_product.h",
+ "cuda/include/thrust/system/cpp/detail/iter_swap.h",
+ "cuda/include/thrust/system/cpp/detail/logical.h",
+ "cuda/include/thrust/system/cpp/detail/malloc_and_free.h",
+ "cuda/include/thrust/system/cpp/detail/memory.inl",
+ "cuda/include/thrust/system/cpp/detail/merge.h",
+ "cuda/include/thrust/system/cpp/detail/mismatch.h",
+ "cuda/include/thrust/system/cpp/detail/par.h",
+ "cuda/include/thrust/system/cpp/detail/partition.h",
+ "cuda/include/thrust/system/cpp/detail/reduce.h",
+ "cuda/include/thrust/system/cpp/detail/reduce_by_key.h",
+ "cuda/include/thrust/system/cpp/detail/remove.h",
+ "cuda/include/thrust/system/cpp/detail/replace.h",
+ "cuda/include/thrust/system/cpp/detail/reverse.h",
+ "cuda/include/thrust/system/cpp/detail/scan.h",
+ "cuda/include/thrust/system/cpp/detail/scan_by_key.h",
+ "cuda/include/thrust/system/cpp/detail/scatter.h",
+ "cuda/include/thrust/system/cpp/detail/sequence.h",
+ "cuda/include/thrust/system/cpp/detail/set_operations.h",
+ "cuda/include/thrust/system/cpp/detail/sort.h",
+ "cuda/include/thrust/system/cpp/detail/swap_ranges.h",
+ "cuda/include/thrust/system/cpp/detail/tabulate.h",
+ "cuda/include/thrust/system/cpp/detail/temporary_buffer.h",
+ "cuda/include/thrust/system/cpp/detail/transform.h",
+ "cuda/include/thrust/system/cpp/detail/transform_reduce.h",
+ "cuda/include/thrust/system/cpp/detail/transform_scan.h",
+ "cuda/include/thrust/system/cpp/detail/uninitialized_copy.h",
+ "cuda/include/thrust/system/cpp/detail/uninitialized_fill.h",
+ "cuda/include/thrust/system/cpp/detail/unique.h",
+ "cuda/include/thrust/system/cpp/detail/unique_by_key.h",
+ "cuda/include/thrust/system/cpp/detail/vector.inl",
+ "cuda/include/thrust/system/cpp/execution_policy.h",
+ "cuda/include/thrust/system/cpp/memory.h",
+ "cuda/include/thrust/system/cpp/vector.h",
+ "cuda/include/thrust/system/cuda/config.h",
+ "cuda/include/thrust/system/cuda/detail/adjacent_difference.h",
+ "cuda/include/thrust/system/cuda/detail/assign_value.h",
+ "cuda/include/thrust/system/cuda/detail/binary_search.h",
+ "cuda/include/thrust/system/cuda/detail/copy.h",
+ "cuda/include/thrust/system/cuda/detail/copy_if.h",
+ "cuda/include/thrust/system/cuda/detail/core/agent_launcher.h",
+ "cuda/include/thrust/system/cuda/detail/core/alignment.h",
+ "cuda/include/thrust/system/cuda/detail/core/triple_chevron_launch.h",
+ "cuda/include/thrust/system/cuda/detail/core/util.h",
+ "cuda/include/thrust/system/cuda/detail/count.h",
+ "cuda/include/thrust/system/cuda/detail/cross_system.h",
+ "cuda/include/thrust/system/cuda/detail/cub/agent/agent_histogram.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/agent/agent_radix_sort_downsweep.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/agent/agent_radix_sort_upsweep.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/agent/agent_reduce.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/agent/agent_reduce_by_key.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/agent/agent_rle.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/agent/agent_scan.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/agent/agent_segment_fixup.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/agent/agent_select_if.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/agent/agent_spmv_csrt.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/agent/agent_spmv_orig.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/agent/agent_spmv_row_based.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/agent/single_pass_scan_operators.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/block_adjacent_difference.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/block_discontinuity.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/block_exchange.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/block_histogram.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/block_load.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/block_radix_rank.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/block_radix_sort.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/block_raking_layout.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/block_reduce.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/block_scan.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/block_shuffle.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/block_store.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_atomic.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_sort.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking_commutative_only.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_warp_reductions.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_raking.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans2.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans3.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/cub.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/device_histogram.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/device_partition.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/device_radix_sort.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/device_reduce.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/device_run_length_encode.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/device_scan.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/device_segmented_radix_sort.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/device_segmented_reduce.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/device_select.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/device_spmv.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_histogram.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_radix_sort.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_reduce.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_reduce_by_key.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_rle.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_scan.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_select_if.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_spmv_csrt.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_spmv_orig.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_spmv_row_based.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/grid/grid_barrier.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/grid/grid_even_share.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/grid/grid_mapping.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/grid/grid_queue.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/host/mutex.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/iterator/arg_index_input_iterator.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/iterator/cache_modified_input_iterator.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/iterator/cache_modified_output_iterator.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/iterator/constant_input_iterator.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/iterator/counting_input_iterator.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/iterator/discard_output_iterator.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/iterator/tex_obj_input_iterator.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/iterator/tex_ref_input_iterator.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/iterator/transform_input_iterator.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/thread/thread_load.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/thread/thread_operators.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/thread/thread_reduce.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/thread/thread_scan.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/thread/thread_search.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/thread/thread_store.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/util_allocator.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/util_arch.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/util_debug.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/util_device.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/util_macro.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/util_namespace.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/util_ptx.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/util_type.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_shfl.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_smem.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_shfl.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_smem.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/warp/warp_reduce.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/warp/warp_scan.cuh",
+ "cuda/include/thrust/system/cuda/detail/equal.h",
+ "cuda/include/thrust/system/cuda/detail/error.inl",
+ "cuda/include/thrust/system/cuda/detail/execution_policy.h",
+ "cuda/include/thrust/system/cuda/detail/extrema.h",
+ "cuda/include/thrust/system/cuda/detail/fill.h",
+ "cuda/include/thrust/system/cuda/detail/find.h",
+ "cuda/include/thrust/system/cuda/detail/for_each.h",
+ "cuda/include/thrust/system/cuda/detail/gather.h",
+ "cuda/include/thrust/system/cuda/detail/generate.h",
+ "cuda/include/thrust/system/cuda/detail/get_value.h",
+ "cuda/include/thrust/system/cuda/detail/guarded_cuda_runtime_api.h",
+ "cuda/include/thrust/system/cuda/detail/guarded_driver_types.h",
+ "cuda/include/thrust/system/cuda/detail/inner_product.h",
+ "cuda/include/thrust/system/cuda/detail/internal/copy_cross_system.h",
+ "cuda/include/thrust/system/cuda/detail/internal/copy_device_to_device.h",
+ "cuda/include/thrust/system/cuda/detail/iter_swap.h",
+ "cuda/include/thrust/system/cuda/detail/logical.h",
+ "cuda/include/thrust/system/cuda/detail/malloc_and_free.h",
+ "cuda/include/thrust/system/cuda/detail/memory.inl",
+ "cuda/include/thrust/system/cuda/detail/memory_buffer.h",
+ "cuda/include/thrust/system/cuda/detail/merge.h",
+ "cuda/include/thrust/system/cuda/detail/mismatch.h",
+ "cuda/include/thrust/system/cuda/detail/par.h",
+ "cuda/include/thrust/system/cuda/detail/par_to_seq.h",
+ "cuda/include/thrust/system/cuda/detail/parallel_for.h",
+ "cuda/include/thrust/system/cuda/detail/partition.h",
+ "cuda/include/thrust/system/cuda/detail/reduce.h",
+ "cuda/include/thrust/system/cuda/detail/reduce_by_key.h",
+ "cuda/include/thrust/system/cuda/detail/remove.h",
+ "cuda/include/thrust/system/cuda/detail/replace.h",
+ "cuda/include/thrust/system/cuda/detail/reverse.h",
+ "cuda/include/thrust/system/cuda/detail/scan.h",
+ "cuda/include/thrust/system/cuda/detail/scan_by_key.h",
+ "cuda/include/thrust/system/cuda/detail/scatter.h",
+ "cuda/include/thrust/system/cuda/detail/sequence.h",
+ "cuda/include/thrust/system/cuda/detail/set_operations.h",
+ "cuda/include/thrust/system/cuda/detail/sort.h",
+ "cuda/include/thrust/system/cuda/detail/swap_ranges.h",
+ "cuda/include/thrust/system/cuda/detail/tabulate.h",
+ "cuda/include/thrust/system/cuda/detail/temporary_buffer.h",
+ "cuda/include/thrust/system/cuda/detail/terminate.h",
+ "cuda/include/thrust/system/cuda/detail/transform.h",
+ "cuda/include/thrust/system/cuda/detail/transform_reduce.h",
+ "cuda/include/thrust/system/cuda/detail/transform_scan.h",
+ "cuda/include/thrust/system/cuda/detail/uninitialized_copy.h",
+ "cuda/include/thrust/system/cuda/detail/uninitialized_fill.h",
+ "cuda/include/thrust/system/cuda/detail/unique.h",
+ "cuda/include/thrust/system/cuda/detail/unique_by_key.h",
+ "cuda/include/thrust/system/cuda/detail/util.h",
+ "cuda/include/thrust/system/cuda/detail/vector.inl",
+ "cuda/include/thrust/system/cuda/error.h",
+ "cuda/include/thrust/system/cuda/execution_policy.h",
+ "cuda/include/thrust/system/cuda/experimental/pinned_allocator.h",
+ "cuda/include/thrust/system/cuda/memory.h",
+ "cuda/include/thrust/system/cuda/vector.h",
+ "cuda/include/thrust/system/detail/adl/adjacent_difference.h",
+ "cuda/include/thrust/system/detail/adl/assign_value.h",
+ "cuda/include/thrust/system/detail/adl/binary_search.h",
+ "cuda/include/thrust/system/detail/adl/copy.h",
+ "cuda/include/thrust/system/detail/adl/copy_if.h",
+ "cuda/include/thrust/system/detail/adl/count.h",
+ "cuda/include/thrust/system/detail/adl/equal.h",
+ "cuda/include/thrust/system/detail/adl/extrema.h",
+ "cuda/include/thrust/system/detail/adl/fill.h",
+ "cuda/include/thrust/system/detail/adl/find.h",
+ "cuda/include/thrust/system/detail/adl/for_each.h",
+ "cuda/include/thrust/system/detail/adl/gather.h",
+ "cuda/include/thrust/system/detail/adl/generate.h",
+ "cuda/include/thrust/system/detail/adl/get_value.h",
+ "cuda/include/thrust/system/detail/adl/inner_product.h",
+ "cuda/include/thrust/system/detail/adl/iter_swap.h",
+ "cuda/include/thrust/system/detail/adl/logical.h",
+ "cuda/include/thrust/system/detail/adl/malloc_and_free.h",
+ "cuda/include/thrust/system/detail/adl/merge.h",
+ "cuda/include/thrust/system/detail/adl/mismatch.h",
+ "cuda/include/thrust/system/detail/adl/partition.h",
+ "cuda/include/thrust/system/detail/adl/reduce.h",
+ "cuda/include/thrust/system/detail/adl/reduce_by_key.h",
+ "cuda/include/thrust/system/detail/adl/remove.h",
+ "cuda/include/thrust/system/detail/adl/replace.h",
+ "cuda/include/thrust/system/detail/adl/reverse.h",
+ "cuda/include/thrust/system/detail/adl/scan.h",
+ "cuda/include/thrust/system/detail/adl/scan_by_key.h",
+ "cuda/include/thrust/system/detail/adl/scatter.h",
+ "cuda/include/thrust/system/detail/adl/sequence.h",
+ "cuda/include/thrust/system/detail/adl/set_operations.h",
+ "cuda/include/thrust/system/detail/adl/sort.h",
+ "cuda/include/thrust/system/detail/adl/swap_ranges.h",
+ "cuda/include/thrust/system/detail/adl/tabulate.h",
+ "cuda/include/thrust/system/detail/adl/temporary_buffer.h",
+ "cuda/include/thrust/system/detail/adl/transform.h",
+ "cuda/include/thrust/system/detail/adl/transform_reduce.h",
+ "cuda/include/thrust/system/detail/adl/transform_scan.h",
+ "cuda/include/thrust/system/detail/adl/uninitialized_copy.h",
+ "cuda/include/thrust/system/detail/adl/uninitialized_fill.h",
+ "cuda/include/thrust/system/detail/adl/unique.h",
+ "cuda/include/thrust/system/detail/adl/unique_by_key.h",
+ "cuda/include/thrust/system/detail/bad_alloc.h",
+ "cuda/include/thrust/system/detail/errno.h",
+ "cuda/include/thrust/system/detail/error_category.inl",
+ "cuda/include/thrust/system/detail/error_code.inl",
+ "cuda/include/thrust/system/detail/error_condition.inl",
+ "cuda/include/thrust/system/detail/generic/adjacent_difference.h",
+ "cuda/include/thrust/system/detail/generic/adjacent_difference.inl",
+ "cuda/include/thrust/system/detail/generic/advance.h",
+ "cuda/include/thrust/system/detail/generic/advance.inl",
+ "cuda/include/thrust/system/detail/generic/binary_search.h",
+ "cuda/include/thrust/system/detail/generic/binary_search.inl",
+ "cuda/include/thrust/system/detail/generic/copy.h",
+ "cuda/include/thrust/system/detail/generic/copy.inl",
+ "cuda/include/thrust/system/detail/generic/copy_if.h",
+ "cuda/include/thrust/system/detail/generic/copy_if.inl",
+ "cuda/include/thrust/system/detail/generic/count.h",
+ "cuda/include/thrust/system/detail/generic/count.inl",
+ "cuda/include/thrust/system/detail/generic/distance.h",
+ "cuda/include/thrust/system/detail/generic/distance.inl",
+ "cuda/include/thrust/system/detail/generic/equal.h",
+ "cuda/include/thrust/system/detail/generic/equal.inl",
+ "cuda/include/thrust/system/detail/generic/extrema.h",
+ "cuda/include/thrust/system/detail/generic/extrema.inl",
+ "cuda/include/thrust/system/detail/generic/fill.h",
+ "cuda/include/thrust/system/detail/generic/find.h",
+ "cuda/include/thrust/system/detail/generic/find.inl",
+ "cuda/include/thrust/system/detail/generic/for_each.h",
+ "cuda/include/thrust/system/detail/generic/gather.h",
+ "cuda/include/thrust/system/detail/generic/gather.inl",
+ "cuda/include/thrust/system/detail/generic/generate.h",
+ "cuda/include/thrust/system/detail/generic/generate.inl",
+ "cuda/include/thrust/system/detail/generic/inner_product.h",
+ "cuda/include/thrust/system/detail/generic/inner_product.inl",
+ "cuda/include/thrust/system/detail/generic/logical.h",
+ "cuda/include/thrust/system/detail/generic/memory.h",
+ "cuda/include/thrust/system/detail/generic/memory.inl",
+ "cuda/include/thrust/system/detail/generic/merge.h",
+ "cuda/include/thrust/system/detail/generic/merge.inl",
+ "cuda/include/thrust/system/detail/generic/mismatch.h",
+ "cuda/include/thrust/system/detail/generic/mismatch.inl",
+ "cuda/include/thrust/system/detail/generic/partition.h",
+ "cuda/include/thrust/system/detail/generic/partition.inl",
+ "cuda/include/thrust/system/detail/generic/reduce.h",
+ "cuda/include/thrust/system/detail/generic/reduce.inl",
+ "cuda/include/thrust/system/detail/generic/reduce_by_key.h",
+ "cuda/include/thrust/system/detail/generic/reduce_by_key.inl",
+ "cuda/include/thrust/system/detail/generic/remove.h",
+ "cuda/include/thrust/system/detail/generic/remove.inl",
+ "cuda/include/thrust/system/detail/generic/replace.h",
+ "cuda/include/thrust/system/detail/generic/replace.inl",
+ "cuda/include/thrust/system/detail/generic/reverse.h",
+ "cuda/include/thrust/system/detail/generic/reverse.inl",
+ "cuda/include/thrust/system/detail/generic/scalar/binary_search.h",
+ "cuda/include/thrust/system/detail/generic/scalar/binary_search.inl",
+ "cuda/include/thrust/system/detail/generic/scan.h",
+ "cuda/include/thrust/system/detail/generic/scan.inl",
+ "cuda/include/thrust/system/detail/generic/scan_by_key.h",
+ "cuda/include/thrust/system/detail/generic/scan_by_key.inl",
+ "cuda/include/thrust/system/detail/generic/scatter.h",
+ "cuda/include/thrust/system/detail/generic/scatter.inl",
+ "cuda/include/thrust/system/detail/generic/select_system.h",
+ "cuda/include/thrust/system/detail/generic/sequence.h",
+ "cuda/include/thrust/system/detail/generic/sequence.inl",
+ "cuda/include/thrust/system/detail/generic/set_operations.h",
+ "cuda/include/thrust/system/detail/generic/set_operations.inl",
+ "cuda/include/thrust/system/detail/generic/sort.h",
+ "cuda/include/thrust/system/detail/generic/sort.inl",
+ "cuda/include/thrust/system/detail/generic/swap_ranges.h",
+ "cuda/include/thrust/system/detail/generic/swap_ranges.inl",
+ "cuda/include/thrust/system/detail/generic/tabulate.h",
+ "cuda/include/thrust/system/detail/generic/tabulate.inl",
+ "cuda/include/thrust/system/detail/generic/tag.h",
+ "cuda/include/thrust/system/detail/generic/temporary_buffer.h",
+ "cuda/include/thrust/system/detail/generic/temporary_buffer.inl",
+ "cuda/include/thrust/system/detail/generic/transform.h",
+ "cuda/include/thrust/system/detail/generic/transform.inl",
+ "cuda/include/thrust/system/detail/generic/transform_reduce.h",
+ "cuda/include/thrust/system/detail/generic/transform_reduce.inl",
+ "cuda/include/thrust/system/detail/generic/transform_scan.h",
+ "cuda/include/thrust/system/detail/generic/transform_scan.inl",
+ "cuda/include/thrust/system/detail/generic/type_traits.h",
+ "cuda/include/thrust/system/detail/generic/uninitialized_copy.h",
+ "cuda/include/thrust/system/detail/generic/uninitialized_copy.inl",
+ "cuda/include/thrust/system/detail/generic/uninitialized_fill.h",
+ "cuda/include/thrust/system/detail/generic/uninitialized_fill.inl",
+ "cuda/include/thrust/system/detail/generic/unique.h",
+ "cuda/include/thrust/system/detail/generic/unique.inl",
+ "cuda/include/thrust/system/detail/generic/unique_by_key.h",
+ "cuda/include/thrust/system/detail/generic/unique_by_key.inl",
+ "cuda/include/thrust/system/detail/internal/decompose.h",
+ "cuda/include/thrust/system/detail/sequential/adjacent_difference.h",
+ "cuda/include/thrust/system/detail/sequential/assign_value.h",
+ "cuda/include/thrust/system/detail/sequential/binary_search.h",
+ "cuda/include/thrust/system/detail/sequential/copy.h",
+ "cuda/include/thrust/system/detail/sequential/copy.inl",
+ "cuda/include/thrust/system/detail/sequential/copy_backward.h",
+ "cuda/include/thrust/system/detail/sequential/copy_if.h",
+ "cuda/include/thrust/system/detail/sequential/count.h",
+ "cuda/include/thrust/system/detail/sequential/equal.h",
+ "cuda/include/thrust/system/detail/sequential/execution_policy.h",
+ "cuda/include/thrust/system/detail/sequential/extrema.h",
+ "cuda/include/thrust/system/detail/sequential/fill.h",
+ "cuda/include/thrust/system/detail/sequential/find.h",
+ "cuda/include/thrust/system/detail/sequential/for_each.h",
+ "cuda/include/thrust/system/detail/sequential/gather.h",
+ "cuda/include/thrust/system/detail/sequential/general_copy.h",
+ "cuda/include/thrust/system/detail/sequential/generate.h",
+ "cuda/include/thrust/system/detail/sequential/get_value.h",
+ "cuda/include/thrust/system/detail/sequential/inner_product.h",
+ "cuda/include/thrust/system/detail/sequential/insertion_sort.h",
+ "cuda/include/thrust/system/detail/sequential/iter_swap.h",
+ "cuda/include/thrust/system/detail/sequential/logical.h",
+ "cuda/include/thrust/system/detail/sequential/malloc_and_free.h",
+ "cuda/include/thrust/system/detail/sequential/merge.h",
+ "cuda/include/thrust/system/detail/sequential/merge.inl",
+ "cuda/include/thrust/system/detail/sequential/mismatch.h",
+ "cuda/include/thrust/system/detail/sequential/partition.h",
+ "cuda/include/thrust/system/detail/sequential/reduce.h",
+ "cuda/include/thrust/system/detail/sequential/reduce_by_key.h",
+ "cuda/include/thrust/system/detail/sequential/remove.h",
+ "cuda/include/thrust/system/detail/sequential/replace.h",
+ "cuda/include/thrust/system/detail/sequential/reverse.h",
+ "cuda/include/thrust/system/detail/sequential/scan.h",
+ "cuda/include/thrust/system/detail/sequential/scan_by_key.h",
+ "cuda/include/thrust/system/detail/sequential/scatter.h",
+ "cuda/include/thrust/system/detail/sequential/sequence.h",
+ "cuda/include/thrust/system/detail/sequential/set_operations.h",
+ "cuda/include/thrust/system/detail/sequential/sort.h",
+ "cuda/include/thrust/system/detail/sequential/sort.inl",
+ "cuda/include/thrust/system/detail/sequential/stable_merge_sort.h",
+ "cuda/include/thrust/system/detail/sequential/stable_merge_sort.inl",
+ "cuda/include/thrust/system/detail/sequential/stable_primitive_sort.h",
+ "cuda/include/thrust/system/detail/sequential/stable_primitive_sort.inl",
+ "cuda/include/thrust/system/detail/sequential/stable_radix_sort.h",
+ "cuda/include/thrust/system/detail/sequential/stable_radix_sort.inl",
+ "cuda/include/thrust/system/detail/sequential/swap_ranges.h",
+ "cuda/include/thrust/system/detail/sequential/tabulate.h",
+ "cuda/include/thrust/system/detail/sequential/temporary_buffer.h",
+ "cuda/include/thrust/system/detail/sequential/transform.h",
+ "cuda/include/thrust/system/detail/sequential/transform_reduce.h",
+ "cuda/include/thrust/system/detail/sequential/transform_scan.h",
+ "cuda/include/thrust/system/detail/sequential/trivial_copy.h",
+ "cuda/include/thrust/system/detail/sequential/uninitialized_copy.h",
+ "cuda/include/thrust/system/detail/sequential/uninitialized_fill.h",
+ "cuda/include/thrust/system/detail/sequential/unique.h",
+ "cuda/include/thrust/system/detail/sequential/unique_by_key.h",
+ "cuda/include/thrust/system/detail/system_error.inl",
+ "cuda/include/thrust/system/error_code.h",
+ "cuda/include/thrust/system/omp/detail/adjacent_difference.h",
+ "cuda/include/thrust/system/omp/detail/assign_value.h",
+ "cuda/include/thrust/system/omp/detail/binary_search.h",
+ "cuda/include/thrust/system/omp/detail/copy.h",
+ "cuda/include/thrust/system/omp/detail/copy.inl",
+ "cuda/include/thrust/system/omp/detail/copy_if.h",
+ "cuda/include/thrust/system/omp/detail/copy_if.inl",
+ "cuda/include/thrust/system/omp/detail/count.h",
+ "cuda/include/thrust/system/omp/detail/default_decomposition.h",
+ "cuda/include/thrust/system/omp/detail/default_decomposition.inl",
+ "cuda/include/thrust/system/omp/detail/equal.h",
+ "cuda/include/thrust/system/omp/detail/execution_policy.h",
+ "cuda/include/thrust/system/omp/detail/extrema.h",
+ "cuda/include/thrust/system/omp/detail/fill.h",
+ "cuda/include/thrust/system/omp/detail/find.h",
+ "cuda/include/thrust/system/omp/detail/for_each.h",
+ "cuda/include/thrust/system/omp/detail/for_each.inl",
+ "cuda/include/thrust/system/omp/detail/gather.h",
+ "cuda/include/thrust/system/omp/detail/generate.h",
+ "cuda/include/thrust/system/omp/detail/get_value.h",
+ "cuda/include/thrust/system/omp/detail/inner_product.h",
+ "cuda/include/thrust/system/omp/detail/iter_swap.h",
+ "cuda/include/thrust/system/omp/detail/logical.h",
+ "cuda/include/thrust/system/omp/detail/malloc_and_free.h",
+ "cuda/include/thrust/system/omp/detail/memory.inl",
+ "cuda/include/thrust/system/omp/detail/merge.h",
+ "cuda/include/thrust/system/omp/detail/mismatch.h",
+ "cuda/include/thrust/system/omp/detail/par.h",
+ "cuda/include/thrust/system/omp/detail/partition.h",
+ "cuda/include/thrust/system/omp/detail/partition.inl",
+ "cuda/include/thrust/system/omp/detail/reduce.h",
+ "cuda/include/thrust/system/omp/detail/reduce.inl",
+ "cuda/include/thrust/system/omp/detail/reduce_by_key.h",
+ "cuda/include/thrust/system/omp/detail/reduce_by_key.inl",
+ "cuda/include/thrust/system/omp/detail/reduce_intervals.h",
+ "cuda/include/thrust/system/omp/detail/reduce_intervals.inl",
+ "cuda/include/thrust/system/omp/detail/remove.h",
+ "cuda/include/thrust/system/omp/detail/remove.inl",
+ "cuda/include/thrust/system/omp/detail/replace.h",
+ "cuda/include/thrust/system/omp/detail/reverse.h",
+ "cuda/include/thrust/system/omp/detail/scan.h",
+ "cuda/include/thrust/system/omp/detail/scan_by_key.h",
+ "cuda/include/thrust/system/omp/detail/scatter.h",
+ "cuda/include/thrust/system/omp/detail/sequence.h",
+ "cuda/include/thrust/system/omp/detail/set_operations.h",
+ "cuda/include/thrust/system/omp/detail/sort.h",
+ "cuda/include/thrust/system/omp/detail/sort.inl",
+ "cuda/include/thrust/system/omp/detail/swap_ranges.h",
+ "cuda/include/thrust/system/omp/detail/tabulate.h",
+ "cuda/include/thrust/system/omp/detail/temporary_buffer.h",
+ "cuda/include/thrust/system/omp/detail/transform.h",
+ "cuda/include/thrust/system/omp/detail/transform_reduce.h",
+ "cuda/include/thrust/system/omp/detail/transform_scan.h",
+ "cuda/include/thrust/system/omp/detail/uninitialized_copy.h",
+ "cuda/include/thrust/system/omp/detail/uninitialized_fill.h",
+ "cuda/include/thrust/system/omp/detail/unique.h",
+ "cuda/include/thrust/system/omp/detail/unique.inl",
+ "cuda/include/thrust/system/omp/detail/unique_by_key.h",
+ "cuda/include/thrust/system/omp/detail/unique_by_key.inl",
+ "cuda/include/thrust/system/omp/detail/vector.inl",
+ "cuda/include/thrust/system/omp/execution_policy.h",
+ "cuda/include/thrust/system/omp/memory.h",
+ "cuda/include/thrust/system/omp/vector.h",
+ "cuda/include/thrust/system/system_error.h",
+ "cuda/include/thrust/system/tbb/detail/adjacent_difference.h",
+ "cuda/include/thrust/system/tbb/detail/assign_value.h",
+ "cuda/include/thrust/system/tbb/detail/binary_search.h",
+ "cuda/include/thrust/system/tbb/detail/copy.h",
+ "cuda/include/thrust/system/tbb/detail/copy.inl",
+ "cuda/include/thrust/system/tbb/detail/copy_if.h",
+ "cuda/include/thrust/system/tbb/detail/copy_if.inl",
+ "cuda/include/thrust/system/tbb/detail/count.h",
+ "cuda/include/thrust/system/tbb/detail/equal.h",
+ "cuda/include/thrust/system/tbb/detail/execution_policy.h",
+ "cuda/include/thrust/system/tbb/detail/extrema.h",
+ "cuda/include/thrust/system/tbb/detail/fill.h",
+ "cuda/include/thrust/system/tbb/detail/find.h",
+ "cuda/include/thrust/system/tbb/detail/for_each.h",
+ "cuda/include/thrust/system/tbb/detail/for_each.inl",
+ "cuda/include/thrust/system/tbb/detail/gather.h",
+ "cuda/include/thrust/system/tbb/detail/generate.h",
+ "cuda/include/thrust/system/tbb/detail/get_value.h",
+ "cuda/include/thrust/system/tbb/detail/inner_product.h",
+ "cuda/include/thrust/system/tbb/detail/iter_swap.h",
+ "cuda/include/thrust/system/tbb/detail/logical.h",
+ "cuda/include/thrust/system/tbb/detail/malloc_and_free.h",
+ "cuda/include/thrust/system/tbb/detail/memory.inl",
+ "cuda/include/thrust/system/tbb/detail/merge.h",
+ "cuda/include/thrust/system/tbb/detail/merge.inl",
+ "cuda/include/thrust/system/tbb/detail/mismatch.h",
+ "cuda/include/thrust/system/tbb/detail/par.h",
+ "cuda/include/thrust/system/tbb/detail/partition.h",
+ "cuda/include/thrust/system/tbb/detail/partition.inl",
+ "cuda/include/thrust/system/tbb/detail/reduce.h",
+ "cuda/include/thrust/system/tbb/detail/reduce.inl",
+ "cuda/include/thrust/system/tbb/detail/reduce_by_key.h",
+ "cuda/include/thrust/system/tbb/detail/reduce_by_key.inl",
+ "cuda/include/thrust/system/tbb/detail/reduce_intervals.h",
+ "cuda/include/thrust/system/tbb/detail/remove.h",
+ "cuda/include/thrust/system/tbb/detail/remove.inl",
+ "cuda/include/thrust/system/tbb/detail/replace.h",
+ "cuda/include/thrust/system/tbb/detail/reverse.h",
+ "cuda/include/thrust/system/tbb/detail/scan.h",
+ "cuda/include/thrust/system/tbb/detail/scan.inl",
+ "cuda/include/thrust/system/tbb/detail/scan_by_key.h",
+ "cuda/include/thrust/system/tbb/detail/scatter.h",
+ "cuda/include/thrust/system/tbb/detail/sequence.h",
+ "cuda/include/thrust/system/tbb/detail/set_operations.h",
+ "cuda/include/thrust/system/tbb/detail/sort.h",
+ "cuda/include/thrust/system/tbb/detail/sort.inl",
+ "cuda/include/thrust/system/tbb/detail/swap_ranges.h",
+ "cuda/include/thrust/system/tbb/detail/tabulate.h",
+ "cuda/include/thrust/system/tbb/detail/temporary_buffer.h",
+ "cuda/include/thrust/system/tbb/detail/transform.h",
+ "cuda/include/thrust/system/tbb/detail/transform_reduce.h",
+ "cuda/include/thrust/system/tbb/detail/transform_scan.h",
+ "cuda/include/thrust/system/tbb/detail/uninitialized_copy.h",
+ "cuda/include/thrust/system/tbb/detail/uninitialized_fill.h",
+ "cuda/include/thrust/system/tbb/detail/unique.h",
+ "cuda/include/thrust/system/tbb/detail/unique.inl",
+ "cuda/include/thrust/system/tbb/detail/unique_by_key.h",
+ "cuda/include/thrust/system/tbb/detail/unique_by_key.inl",
+ "cuda/include/thrust/system/tbb/detail/vector.inl",
+ "cuda/include/thrust/system/tbb/execution_policy.h",
+ "cuda/include/thrust/system/tbb/memory.h",
+ "cuda/include/thrust/system/tbb/vector.h",
+ "cuda/include/thrust/system_error.h",
+ "cuda/include/thrust/tabulate.h",
+ "cuda/include/thrust/transform.h",
+ "cuda/include/thrust/transform_reduce.h",
+ "cuda/include/thrust/transform_scan.h",
+ "cuda/include/thrust/tuple.h",
+ "cuda/include/thrust/uninitialized_copy.h",
+ "cuda/include/thrust/uninitialized_fill.h",
+ "cuda/include/thrust/unique.h",
+ "cuda/include/thrust/version.h",
+ "cuda/include/vector_functions.h",
+ "cuda/include/vector_functions.hpp",
+ "cuda/include/vector_types.h",
+ ],
+ cmd = """
+if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp "/usr/local/cuda-9.0/include/CL/cl.h" "$(@D)/cuda/include/CL/cl.h" && cp "/usr/local/cuda-9.0/include/CL/cl.hpp" "$(@D)/cuda/include/CL/cl.hpp" && cp "/usr/local/cuda-9.0/include/CL/cl_egl.h" "$(@D)/cuda/include/CL/cl_egl.h" && cp "/usr/local/cuda-9.0/include/CL/cl_ext.h" "$(@D)/cuda/include/CL/cl_ext.h" && cp "/usr/local/cuda-9.0/include/CL/cl_gl.h" "$(@D)/cuda/include/CL/cl_gl.h" && cp "/usr/local/cuda-9.0/include/CL/cl_gl_ext.h" "$(@D)/cuda/include/CL/cl_gl_ext.h" && cp "/usr/local/cuda-9.0/include/CL/cl_platform.h" "$(@D)/cuda/include/CL/cl_platform.h" && cp "/usr/local/cuda-9.0/include/CL/opencl.h" "$(@D)/cuda/include/CL/opencl.h" && cp "/usr/local/cuda-9.0/include/builtin_types.h" "$(@D)/cuda/include/builtin_types.h" && cp "/usr/local/cuda-9.0/include/channel_descriptor.h" "$(@D)/cuda/include/channel_descriptor.h" && cp "/usr/local/cuda-9.0/include/common_functions.h" "$(@D)/cuda/include/common_functions.h" && cp "/usr/local/cuda-9.0/include/cooperative_groups.h" "$(@D)/cuda/include/cooperative_groups.h" && cp "/usr/local/cuda-9.0/include/cooperative_groups_helpers.h" "$(@D)/cuda/include/cooperative_groups_helpers.h" && cp "/usr/local/cuda-9.0/include/crt/common_functions.h" "$(@D)/cuda/include/crt/common_functions.h" && cp "/usr/local/cuda-9.0/include/crt/device_double_functions.h" "$(@D)/cuda/include/crt/device_double_functions.h" && cp "/usr/local/cuda-9.0/include/crt/device_double_functions.hpp" "$(@D)/cuda/include/crt/device_double_functions.hpp" && cp "/usr/local/cuda-9.0/include/crt/device_functions.h" "$(@D)/cuda/include/crt/device_functions.h" && cp "/usr/local/cuda-9.0/include/crt/device_functions.hpp" "$(@D)/cuda/include/crt/device_functions.hpp" && cp "/usr/local/cuda-9.0/include/crt/func_macro.h" "$(@D)/cuda/include/crt/func_macro.h" && cp "/usr/local/cuda-9.0/include/crt/host_config.h" "$(@D)/cuda/include/crt/host_config.h" && cp "/usr/local/cuda-9.0/include/crt/host_defines.h" "$(@D)/cuda/include/crt/host_defines.h" && cp "/usr/local/cuda-9.0/include/crt/host_runtime.h" "$(@D)/cuda/include/crt/host_runtime.h" && cp "/usr/local/cuda-9.0/include/crt/math_functions.h" "$(@D)/cuda/include/crt/math_functions.h" && cp "/usr/local/cuda-9.0/include/crt/math_functions.hpp" "$(@D)/cuda/include/crt/math_functions.hpp" && cp "/usr/local/cuda-9.0/include/crt/mma.h" "$(@D)/cuda/include/crt/mma.h" && cp "/usr/local/cuda-9.0/include/crt/mma.hpp" "$(@D)/cuda/include/crt/mma.hpp" && cp "/usr/local/cuda-9.0/include/crt/nvfunctional" "$(@D)/cuda/include/crt/nvfunctional" && cp "/usr/local/cuda-9.0/include/crt/sm_70_rt.h" "$(@D)/cuda/include/crt/sm_70_rt.h" && cp "/usr/local/cuda-9.0/include/crt/sm_70_rt.hpp" "$(@D)/cuda/include/crt/sm_70_rt.hpp" && cp "/usr/local/cuda-9.0/include/crt/storage_class.h" "$(@D)/cuda/include/crt/storage_class.h" && cp "/usr/local/cuda-9.0/include/cuComplex.h" "$(@D)/cuda/include/cuComplex.h" && cp "/usr/local/cuda-9.0/include/cublas.h" "$(@D)/cuda/include/cublas.h" && cp "/usr/local/cuda-9.0/include/cublasXt.h" "$(@D)/cuda/include/cublasXt.h" && cp "/usr/local/cuda-9.0/include/cublas_api.h" "$(@D)/cuda/include/cublas_api.h" && cp "/usr/local/cuda-9.0/include/cublas_v2.h" "$(@D)/cuda/include/cublas_v2.h" && cp "/usr/local/cuda-9.0/include/cuda.h" "$(@D)/cuda/include/cuda.h" && cp "/usr/local/cuda-9.0/include/cudaEGL.h" 
"$(@D)/cuda/include/cudaEGL.h" && cp "/usr/local/cuda-9.0/include/cudaGL.h" "$(@D)/cuda/include/cudaGL.h" && cp "/usr/local/cuda-9.0/include/cudaProfiler.h" "$(@D)/cuda/include/cudaProfiler.h" && cp "/usr/local/cuda-9.0/include/cudaVDPAU.h" "$(@D)/cuda/include/cudaVDPAU.h" && cp "/usr/local/cuda-9.0/include/cuda_device_runtime_api.h" "$(@D)/cuda/include/cuda_device_runtime_api.h" && cp "/usr/local/cuda-9.0/include/cuda_fp16.h" "$(@D)/cuda/include/cuda_fp16.h" && cp "/usr/local/cuda-9.0/include/cuda_fp16.hpp" "$(@D)/cuda/include/cuda_fp16.hpp" && cp "/usr/local/cuda-9.0/include/cuda_gl_interop.h" "$(@D)/cuda/include/cuda_gl_interop.h" && cp "/usr/local/cuda-9.0/include/cuda_occupancy.h" "$(@D)/cuda/include/cuda_occupancy.h" && cp "/usr/local/cuda-9.0/include/cuda_profiler_api.h" "$(@D)/cuda/include/cuda_profiler_api.h" && cp "/usr/local/cuda-9.0/include/cuda_runtime.h" "$(@D)/cuda/include/cuda_runtime.h" && cp "/usr/local/cuda-9.0/include/cuda_runtime_api.h" "$(@D)/cuda/include/cuda_runtime_api.h" && cp "/usr/local/cuda-9.0/include/cuda_surface_types.h" "$(@D)/cuda/include/cuda_surface_types.h" && cp "/usr/local/cuda-9.0/include/cuda_texture_types.h" "$(@D)/cuda/include/cuda_texture_types.h" && cp "/usr/local/cuda-9.0/include/cuda_vdpau_interop.h" "$(@D)/cuda/include/cuda_vdpau_interop.h" && cp "/usr/local/cuda-9.0/include/cudalibxt.h" "$(@D)/cuda/include/cudalibxt.h" && cp "/usr/local/cuda-9.0/include/cufft.h" "$(@D)/cuda/include/cufft.h" && cp "/usr/local/cuda-9.0/include/cufftXt.h" "$(@D)/cuda/include/cufftXt.h" && cp "/usr/local/cuda-9.0/include/cufftw.h" "$(@D)/cuda/include/cufftw.h" && cp "/usr/local/cuda-9.0/include/curand.h" "$(@D)/cuda/include/curand.h" && cp "/usr/local/cuda-9.0/include/curand_discrete.h" "$(@D)/cuda/include/curand_discrete.h" && cp "/usr/local/cuda-9.0/include/curand_discrete2.h" "$(@D)/cuda/include/curand_discrete2.h" && cp "/usr/local/cuda-9.0/include/curand_globals.h" "$(@D)/cuda/include/curand_globals.h" && cp "/usr/local/cuda-9.0/include/curand_kernel.h" "$(@D)/cuda/include/curand_kernel.h" && cp "/usr/local/cuda-9.0/include/curand_lognormal.h" "$(@D)/cuda/include/curand_lognormal.h" && cp "/usr/local/cuda-9.0/include/curand_mrg32k3a.h" "$(@D)/cuda/include/curand_mrg32k3a.h" && cp "/usr/local/cuda-9.0/include/curand_mtgp32.h" "$(@D)/cuda/include/curand_mtgp32.h" && cp "/usr/local/cuda-9.0/include/curand_mtgp32_host.h" "$(@D)/cuda/include/curand_mtgp32_host.h" && cp "/usr/local/cuda-9.0/include/curand_mtgp32_kernel.h" "$(@D)/cuda/include/curand_mtgp32_kernel.h" && cp "/usr/local/cuda-9.0/include/curand_mtgp32dc_p_11213.h" "$(@D)/cuda/include/curand_mtgp32dc_p_11213.h" && cp "/usr/local/cuda-9.0/include/curand_normal.h" "$(@D)/cuda/include/curand_normal.h" && cp "/usr/local/cuda-9.0/include/curand_normal_static.h" "$(@D)/cuda/include/curand_normal_static.h" && cp "/usr/local/cuda-9.0/include/curand_philox4x32_x.h" "$(@D)/cuda/include/curand_philox4x32_x.h" && cp "/usr/local/cuda-9.0/include/curand_poisson.h" "$(@D)/cuda/include/curand_poisson.h" && cp "/usr/local/cuda-9.0/include/curand_precalc.h" "$(@D)/cuda/include/curand_precalc.h" && cp "/usr/local/cuda-9.0/include/curand_uniform.h" "$(@D)/cuda/include/curand_uniform.h" && cp "/usr/local/cuda-9.0/include/cusolverDn.h" "$(@D)/cuda/include/cusolverDn.h" && cp "/usr/local/cuda-9.0/include/cusolverRf.h" "$(@D)/cuda/include/cusolverRf.h" && cp "/usr/local/cuda-9.0/include/cusolverSp.h" "$(@D)/cuda/include/cusolverSp.h" && cp "/usr/local/cuda-9.0/include/cusolverSp_LOWLEVEL_PREVIEW.h" 
"$(@D)/cuda/include/cusolverSp_LOWLEVEL_PREVIEW.h" && cp "/usr/local/cuda-9.0/include/cusolver_common.h" "$(@D)/cuda/include/cusolver_common.h" && cp "/usr/local/cuda-9.0/include/cusparse.h" "$(@D)/cuda/include/cusparse.h" && cp "/usr/local/cuda-9.0/include/cusparse_v2.h" "$(@D)/cuda/include/cusparse_v2.h" && cp "/usr/local/cuda-9.0/include/device_atomic_functions.h" "$(@D)/cuda/include/device_atomic_functions.h" && cp "/usr/local/cuda-9.0/include/device_atomic_functions.hpp" "$(@D)/cuda/include/device_atomic_functions.hpp" && cp "/usr/local/cuda-9.0/include/device_double_functions.h" "$(@D)/cuda/include/device_double_functions.h" && cp "/usr/local/cuda-9.0/include/device_double_functions.hpp" "$(@D)/cuda/include/device_double_functions.hpp" && cp "/usr/local/cuda-9.0/include/device_functions.h" "$(@D)/cuda/include/device_functions.h" && cp "/usr/local/cuda-9.0/include/device_functions.hpp" "$(@D)/cuda/include/device_functions.hpp" && cp "/usr/local/cuda-9.0/include/device_functions_decls.h" "$(@D)/cuda/include/device_functions_decls.h" && cp "/usr/local/cuda-9.0/include/device_launch_parameters.h" "$(@D)/cuda/include/device_launch_parameters.h" && cp "/usr/local/cuda-9.0/include/device_types.h" "$(@D)/cuda/include/device_types.h" && cp "/usr/local/cuda-9.0/include/driver_functions.h" "$(@D)/cuda/include/driver_functions.h" && cp "/usr/local/cuda-9.0/include/driver_types.h" "$(@D)/cuda/include/driver_types.h" && cp "/usr/local/cuda-9.0/include/dynlink_cuda.h" "$(@D)/cuda/include/dynlink_cuda.h" && cp "/usr/local/cuda-9.0/include/dynlink_cuda_cuda.h" "$(@D)/cuda/include/dynlink_cuda_cuda.h" && cp "/usr/local/cuda-9.0/include/dynlink_cuviddec.h" "$(@D)/cuda/include/dynlink_cuviddec.h" && cp "/usr/local/cuda-9.0/include/dynlink_nvcuvid.h" "$(@D)/cuda/include/dynlink_nvcuvid.h" && cp "/usr/local/cuda-9.0/include/fatBinaryCtl.h" "$(@D)/cuda/include/fatBinaryCtl.h" && cp "/usr/local/cuda-9.0/include/fatbinary.h" "$(@D)/cuda/include/fatbinary.h" && cp "/usr/local/cuda-9.0/include/host_config.h" "$(@D)/cuda/include/host_config.h" && cp "/usr/local/cuda-9.0/include/host_defines.h" "$(@D)/cuda/include/host_defines.h" && cp "/usr/local/cuda-9.0/include/library_types.h" "$(@D)/cuda/include/library_types.h" && cp "/usr/local/cuda-9.0/include/math_constants.h" "$(@D)/cuda/include/math_constants.h" && cp "/usr/local/cuda-9.0/include/math_functions.h" "$(@D)/cuda/include/math_functions.h" && cp "/usr/local/cuda-9.0/include/math_functions.hpp" "$(@D)/cuda/include/math_functions.hpp" && cp "/usr/local/cuda-9.0/include/math_functions_dbl_ptx3.h" "$(@D)/cuda/include/math_functions_dbl_ptx3.h" && cp "/usr/local/cuda-9.0/include/math_functions_dbl_ptx3.hpp" "$(@D)/cuda/include/math_functions_dbl_ptx3.hpp" && cp "/usr/local/cuda-9.0/include/mma.h" "$(@D)/cuda/include/mma.h" && cp "/usr/local/cuda-9.0/include/npp.h" "$(@D)/cuda/include/npp.h" && cp "/usr/local/cuda-9.0/include/nppcore.h" "$(@D)/cuda/include/nppcore.h" && cp "/usr/local/cuda-9.0/include/nppdefs.h" "$(@D)/cuda/include/nppdefs.h" && cp "/usr/local/cuda-9.0/include/nppi.h" "$(@D)/cuda/include/nppi.h" && cp "/usr/local/cuda-9.0/include/nppi_arithmetic_and_logical_operations.h" "$(@D)/cuda/include/nppi_arithmetic_and_logical_operations.h" && cp "/usr/local/cuda-9.0/include/nppi_color_conversion.h" "$(@D)/cuda/include/nppi_color_conversion.h" && cp "/usr/local/cuda-9.0/include/nppi_compression_functions.h" "$(@D)/cuda/include/nppi_compression_functions.h" && cp "/usr/local/cuda-9.0/include/nppi_computer_vision.h" 
"$(@D)/cuda/include/nppi_computer_vision.h" && cp "/usr/local/cuda-9.0/include/nppi_data_exchange_and_initialization.h" "$(@D)/cuda/include/nppi_data_exchange_and_initialization.h" && cp "/usr/local/cuda-9.0/include/nppi_filtering_functions.h" "$(@D)/cuda/include/nppi_filtering_functions.h" && cp "/usr/local/cuda-9.0/include/nppi_geometry_transforms.h" "$(@D)/cuda/include/nppi_geometry_transforms.h" && cp "/usr/local/cuda-9.0/include/nppi_linear_transforms.h" "$(@D)/cuda/include/nppi_linear_transforms.h" && cp "/usr/local/cuda-9.0/include/nppi_morphological_operations.h" "$(@D)/cuda/include/nppi_morphological_operations.h" && cp "/usr/local/cuda-9.0/include/nppi_statistics_functions.h" "$(@D)/cuda/include/nppi_statistics_functions.h" && cp "/usr/local/cuda-9.0/include/nppi_support_functions.h" "$(@D)/cuda/include/nppi_support_functions.h" && cp "/usr/local/cuda-9.0/include/nppi_threshold_and_compare_operations.h" "$(@D)/cuda/include/nppi_threshold_and_compare_operations.h" && cp "/usr/local/cuda-9.0/include/npps.h" "$(@D)/cuda/include/npps.h" && cp "/usr/local/cuda-9.0/include/npps_arithmetic_and_logical_operations.h" "$(@D)/cuda/include/npps_arithmetic_and_logical_operations.h" && cp "/usr/local/cuda-9.0/include/npps_conversion_functions.h" "$(@D)/cuda/include/npps_conversion_functions.h" && cp "/usr/local/cuda-9.0/include/npps_filtering_functions.h" "$(@D)/cuda/include/npps_filtering_functions.h" && cp "/usr/local/cuda-9.0/include/npps_initialization.h" "$(@D)/cuda/include/npps_initialization.h" && cp "/usr/local/cuda-9.0/include/npps_statistics_functions.h" "$(@D)/cuda/include/npps_statistics_functions.h" && cp "/usr/local/cuda-9.0/include/npps_support_functions.h" "$(@D)/cuda/include/npps_support_functions.h" && cp "/usr/local/cuda-9.0/include/nppversion.h" "$(@D)/cuda/include/nppversion.h" && cp "/usr/local/cuda-9.0/include/nvToolsExt.h" "$(@D)/cuda/include/nvToolsExt.h" && cp "/usr/local/cuda-9.0/include/nvToolsExtCuda.h" "$(@D)/cuda/include/nvToolsExtCuda.h" && cp "/usr/local/cuda-9.0/include/nvToolsExtCudaRt.h" "$(@D)/cuda/include/nvToolsExtCudaRt.h" && cp "/usr/local/cuda-9.0/include/nvToolsExtMeta.h" "$(@D)/cuda/include/nvToolsExtMeta.h" && cp "/usr/local/cuda-9.0/include/nvToolsExtSync.h" "$(@D)/cuda/include/nvToolsExtSync.h" && cp "/usr/local/cuda-9.0/include/nvblas.h" "$(@D)/cuda/include/nvblas.h" && cp "/usr/local/cuda-9.0/include/nvfunctional" "$(@D)/cuda/include/nvfunctional" && cp "/usr/local/cuda-9.0/include/nvgraph.h" "$(@D)/cuda/include/nvgraph.h" && cp "/usr/local/cuda-9.0/include/nvml.h" "$(@D)/cuda/include/nvml.h" && cp "/usr/local/cuda-9.0/include/nvrtc.h" "$(@D)/cuda/include/nvrtc.h" && cp "/usr/local/cuda-9.0/include/sm_20_atomic_functions.h" "$(@D)/cuda/include/sm_20_atomic_functions.h" && cp "/usr/local/cuda-9.0/include/sm_20_atomic_functions.hpp" "$(@D)/cuda/include/sm_20_atomic_functions.hpp" && cp "/usr/local/cuda-9.0/include/sm_20_intrinsics.h" "$(@D)/cuda/include/sm_20_intrinsics.h" && cp "/usr/local/cuda-9.0/include/sm_20_intrinsics.hpp" "$(@D)/cuda/include/sm_20_intrinsics.hpp" && cp "/usr/local/cuda-9.0/include/sm_30_intrinsics.h" "$(@D)/cuda/include/sm_30_intrinsics.h" && cp "/usr/local/cuda-9.0/include/sm_30_intrinsics.hpp" "$(@D)/cuda/include/sm_30_intrinsics.hpp" && cp "/usr/local/cuda-9.0/include/sm_32_atomic_functions.h" "$(@D)/cuda/include/sm_32_atomic_functions.h" && cp "/usr/local/cuda-9.0/include/sm_32_atomic_functions.hpp" "$(@D)/cuda/include/sm_32_atomic_functions.hpp" && cp "/usr/local/cuda-9.0/include/sm_32_intrinsics.h" 
"$(@D)/cuda/include/sm_32_intrinsics.h" && cp "/usr/local/cuda-9.0/include/sm_32_intrinsics.hpp" "$(@D)/cuda/include/sm_32_intrinsics.hpp" && cp "/usr/local/cuda-9.0/include/sm_35_atomic_functions.h" "$(@D)/cuda/include/sm_35_atomic_functions.h" && cp "/usr/local/cuda-9.0/include/sm_35_intrinsics.h" "$(@D)/cuda/include/sm_35_intrinsics.h" && cp "/usr/local/cuda-9.0/include/sm_60_atomic_functions.h" "$(@D)/cuda/include/sm_60_atomic_functions.h" && cp "/usr/local/cuda-9.0/include/sm_60_atomic_functions.hpp" "$(@D)/cuda/include/sm_60_atomic_functions.hpp" && cp "/usr/local/cuda-9.0/include/sm_61_intrinsics.h" "$(@D)/cuda/include/sm_61_intrinsics.h" && cp "/usr/local/cuda-9.0/include/sm_61_intrinsics.hpp" "$(@D)/cuda/include/sm_61_intrinsics.hpp" && cp "/usr/local/cuda-9.0/include/sobol_direction_vectors.h" "$(@D)/cuda/include/sobol_direction_vectors.h" && cp "/usr/local/cuda-9.0/include/surface_functions.h" "$(@D)/cuda/include/surface_functions.h" && cp "/usr/local/cuda-9.0/include/surface_functions.hpp" "$(@D)/cuda/include/surface_functions.hpp" && cp "/usr/local/cuda-9.0/include/surface_indirect_functions.h" "$(@D)/cuda/include/surface_indirect_functions.h" && cp "/usr/local/cuda-9.0/include/surface_indirect_functions.hpp" "$(@D)/cuda/include/surface_indirect_functions.hpp" && cp "/usr/local/cuda-9.0/include/surface_types.h" "$(@D)/cuda/include/surface_types.h" && cp "/usr/local/cuda-9.0/include/texture_fetch_functions.h" "$(@D)/cuda/include/texture_fetch_functions.h" && cp "/usr/local/cuda-9.0/include/texture_fetch_functions.hpp" "$(@D)/cuda/include/texture_fetch_functions.hpp" && cp "/usr/local/cuda-9.0/include/texture_indirect_functions.h" "$(@D)/cuda/include/texture_indirect_functions.h" && cp "/usr/local/cuda-9.0/include/texture_indirect_functions.hpp" "$(@D)/cuda/include/texture_indirect_functions.hpp" && cp "/usr/local/cuda-9.0/include/texture_types.h" "$(@D)/cuda/include/texture_types.h" && cp "/usr/local/cuda-9.0/include/thrust/adjacent_difference.h" "$(@D)/cuda/include/thrust/adjacent_difference.h" && cp "/usr/local/cuda-9.0/include/thrust/advance.h" "$(@D)/cuda/include/thrust/advance.h" && cp "/usr/local/cuda-9.0/include/thrust/binary_search.h" "$(@D)/cuda/include/thrust/binary_search.h" && cp "/usr/local/cuda-9.0/include/thrust/complex.h" "$(@D)/cuda/include/thrust/complex.h" && cp "/usr/local/cuda-9.0/include/thrust/copy.h" "$(@D)/cuda/include/thrust/copy.h" && cp "/usr/local/cuda-9.0/include/thrust/count.h" "$(@D)/cuda/include/thrust/count.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/adjacent_difference.inl" "$(@D)/cuda/include/thrust/detail/adjacent_difference.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/advance.inl" "$(@D)/cuda/include/thrust/detail/advance.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/allocator/allocator_traits.h" "$(@D)/cuda/include/thrust/detail/allocator/allocator_traits.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/allocator/allocator_traits.inl" "$(@D)/cuda/include/thrust/detail/allocator/allocator_traits.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/allocator/copy_construct_range.h" "$(@D)/cuda/include/thrust/detail/allocator/copy_construct_range.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/allocator/copy_construct_range.inl" "$(@D)/cuda/include/thrust/detail/allocator/copy_construct_range.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/allocator/default_construct_range.h" "$(@D)/cuda/include/thrust/detail/allocator/default_construct_range.h" && cp 
"/usr/local/cuda-9.0/include/thrust/detail/allocator/default_construct_range.inl" "$(@D)/cuda/include/thrust/detail/allocator/default_construct_range.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/allocator/destroy_range.h" "$(@D)/cuda/include/thrust/detail/allocator/destroy_range.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/allocator/destroy_range.inl" "$(@D)/cuda/include/thrust/detail/allocator/destroy_range.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/allocator/fill_construct_range.h" "$(@D)/cuda/include/thrust/detail/allocator/fill_construct_range.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/allocator/fill_construct_range.inl" "$(@D)/cuda/include/thrust/detail/allocator/fill_construct_range.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/allocator/malloc_allocator.h" "$(@D)/cuda/include/thrust/detail/allocator/malloc_allocator.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/allocator/malloc_allocator.inl" "$(@D)/cuda/include/thrust/detail/allocator/malloc_allocator.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/allocator/no_throw_allocator.h" "$(@D)/cuda/include/thrust/detail/allocator/no_throw_allocator.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/allocator/tagged_allocator.h" "$(@D)/cuda/include/thrust/detail/allocator/tagged_allocator.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/allocator/tagged_allocator.inl" "$(@D)/cuda/include/thrust/detail/allocator/tagged_allocator.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/allocator/temporary_allocator.h" "$(@D)/cuda/include/thrust/detail/allocator/temporary_allocator.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/allocator/temporary_allocator.inl" "$(@D)/cuda/include/thrust/detail/allocator/temporary_allocator.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/binary_search.inl" "$(@D)/cuda/include/thrust/detail/binary_search.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/arithmetic.h" "$(@D)/cuda/include/thrust/detail/complex/arithmetic.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/c99math.h" "$(@D)/cuda/include/thrust/detail/complex/c99math.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/catrig.h" "$(@D)/cuda/include/thrust/detail/complex/catrig.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/catrigf.h" "$(@D)/cuda/include/thrust/detail/complex/catrigf.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/ccosh.h" "$(@D)/cuda/include/thrust/detail/complex/ccosh.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/ccoshf.h" "$(@D)/cuda/include/thrust/detail/complex/ccoshf.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/cexp.h" "$(@D)/cuda/include/thrust/detail/complex/cexp.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/cexpf.h" "$(@D)/cuda/include/thrust/detail/complex/cexpf.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/clog.h" "$(@D)/cuda/include/thrust/detail/complex/clog.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/clogf.h" "$(@D)/cuda/include/thrust/detail/complex/clogf.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/complex.inl" "$(@D)/cuda/include/thrust/detail/complex/complex.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/cpow.h" "$(@D)/cuda/include/thrust/detail/complex/cpow.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/cpowf.h" "$(@D)/cuda/include/thrust/detail/complex/cpowf.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/cproj.h" "$(@D)/cuda/include/thrust/detail/complex/cproj.h" && 
cp "/usr/local/cuda-9.0/include/thrust/detail/complex/csinh.h" "$(@D)/cuda/include/thrust/detail/complex/csinh.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/csinhf.h" "$(@D)/cuda/include/thrust/detail/complex/csinhf.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/csqrt.h" "$(@D)/cuda/include/thrust/detail/complex/csqrt.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/csqrtf.h" "$(@D)/cuda/include/thrust/detail/complex/csqrtf.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/ctanh.h" "$(@D)/cuda/include/thrust/detail/complex/ctanh.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/ctanhf.h" "$(@D)/cuda/include/thrust/detail/complex/ctanhf.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/math_private.h" "$(@D)/cuda/include/thrust/detail/complex/math_private.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/stream.h" "$(@D)/cuda/include/thrust/detail/complex/stream.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/config.h" "$(@D)/cuda/include/thrust/detail/config.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/config/compiler.h" "$(@D)/cuda/include/thrust/detail/config/compiler.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/config/compiler_fence.h" "$(@D)/cuda/include/thrust/detail/config/compiler_fence.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/config/config.h" "$(@D)/cuda/include/thrust/detail/config/config.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/config/debug.h" "$(@D)/cuda/include/thrust/detail/config/debug.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/config/device_system.h" "$(@D)/cuda/include/thrust/detail/config/device_system.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/config/exec_check_disable.h" "$(@D)/cuda/include/thrust/detail/config/exec_check_disable.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/config/forceinline.h" "$(@D)/cuda/include/thrust/detail/config/forceinline.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/config/global_workarounds.h" "$(@D)/cuda/include/thrust/detail/config/global_workarounds.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/config/host_device.h" "$(@D)/cuda/include/thrust/detail/config/host_device.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/config/host_system.h" "$(@D)/cuda/include/thrust/detail/config/host_system.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/config/simple_defines.h" "$(@D)/cuda/include/thrust/detail/config/simple_defines.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/contiguous_storage.h" "$(@D)/cuda/include/thrust/detail/contiguous_storage.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/contiguous_storage.inl" "$(@D)/cuda/include/thrust/detail/contiguous_storage.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/copy.h" "$(@D)/cuda/include/thrust/detail/copy.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/copy.inl" "$(@D)/cuda/include/thrust/detail/copy.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/copy_if.h" "$(@D)/cuda/include/thrust/detail/copy_if.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/copy_if.inl" "$(@D)/cuda/include/thrust/detail/copy_if.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/count.inl" "$(@D)/cuda/include/thrust/detail/count.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/cstdint.h" "$(@D)/cuda/include/thrust/detail/cstdint.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/device_delete.inl" "$(@D)/cuda/include/thrust/detail/device_delete.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/device_free.inl" 
"$(@D)/cuda/include/thrust/detail/device_free.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/device_malloc.inl" "$(@D)/cuda/include/thrust/detail/device_malloc.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/device_new.inl" "$(@D)/cuda/include/thrust/detail/device_new.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/device_ptr.inl" "$(@D)/cuda/include/thrust/detail/device_ptr.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/device_reference.inl" "$(@D)/cuda/include/thrust/detail/device_reference.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/device_vector.inl" "$(@D)/cuda/include/thrust/detail/device_vector.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/dispatch/is_trivial_copy.h" "$(@D)/cuda/include/thrust/detail/dispatch/is_trivial_copy.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/distance.inl" "$(@D)/cuda/include/thrust/detail/distance.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/equal.inl" "$(@D)/cuda/include/thrust/detail/equal.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/execute_with_allocator.h" "$(@D)/cuda/include/thrust/detail/execute_with_allocator.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/execution_policy.h" "$(@D)/cuda/include/thrust/detail/execution_policy.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/extrema.inl" "$(@D)/cuda/include/thrust/detail/extrema.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/fill.inl" "$(@D)/cuda/include/thrust/detail/fill.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/find.inl" "$(@D)/cuda/include/thrust/detail/find.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/for_each.inl" "$(@D)/cuda/include/thrust/detail/for_each.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/function.h" "$(@D)/cuda/include/thrust/detail/function.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/functional.inl" "$(@D)/cuda/include/thrust/detail/functional.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/functional/actor.h" "$(@D)/cuda/include/thrust/detail/functional/actor.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/functional/actor.inl" "$(@D)/cuda/include/thrust/detail/functional/actor.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/functional/argument.h" "$(@D)/cuda/include/thrust/detail/functional/argument.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/functional/composite.h" "$(@D)/cuda/include/thrust/detail/functional/composite.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/functional/operators.h" "$(@D)/cuda/include/thrust/detail/functional/operators.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/functional/operators/arithmetic_operators.h" "$(@D)/cuda/include/thrust/detail/functional/operators/arithmetic_operators.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/functional/operators/assignment_operator.h" "$(@D)/cuda/include/thrust/detail/functional/operators/assignment_operator.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/functional/operators/bitwise_operators.h" "$(@D)/cuda/include/thrust/detail/functional/operators/bitwise_operators.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/functional/operators/compound_assignment_operators.h" "$(@D)/cuda/include/thrust/detail/functional/operators/compound_assignment_operators.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/functional/operators/logical_operators.h" "$(@D)/cuda/include/thrust/detail/functional/operators/logical_operators.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/functional/operators/operator_adaptors.h" 
"$(@D)/cuda/include/thrust/detail/functional/operators/operator_adaptors.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/functional/operators/relational_operators.h" "$(@D)/cuda/include/thrust/detail/functional/operators/relational_operators.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/functional/placeholder.h" "$(@D)/cuda/include/thrust/detail/functional/placeholder.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/functional/value.h" "$(@D)/cuda/include/thrust/detail/functional/value.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/gather.inl" "$(@D)/cuda/include/thrust/detail/gather.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/generate.inl" "$(@D)/cuda/include/thrust/detail/generate.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/get_iterator_value.h" "$(@D)/cuda/include/thrust/detail/get_iterator_value.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/host_vector.inl" "$(@D)/cuda/include/thrust/detail/host_vector.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/inner_product.inl" "$(@D)/cuda/include/thrust/detail/inner_product.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/integer_math.h" "$(@D)/cuda/include/thrust/detail/integer_math.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/integer_traits.h" "$(@D)/cuda/include/thrust/detail/integer_traits.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/internal_functional.h" "$(@D)/cuda/include/thrust/detail/internal_functional.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/logical.inl" "$(@D)/cuda/include/thrust/detail/logical.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/malloc_and_free.h" "$(@D)/cuda/include/thrust/detail/malloc_and_free.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/merge.inl" "$(@D)/cuda/include/thrust/detail/merge.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/minmax.h" "$(@D)/cuda/include/thrust/detail/minmax.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/mismatch.inl" "$(@D)/cuda/include/thrust/detail/mismatch.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/mpl/math.h" "$(@D)/cuda/include/thrust/detail/mpl/math.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/numeric_traits.h" "$(@D)/cuda/include/thrust/detail/numeric_traits.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/overlapped_copy.h" "$(@D)/cuda/include/thrust/detail/overlapped_copy.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/pair.inl" "$(@D)/cuda/include/thrust/detail/pair.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/partition.inl" "$(@D)/cuda/include/thrust/detail/partition.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/pointer.h" "$(@D)/cuda/include/thrust/detail/pointer.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/pointer.inl" "$(@D)/cuda/include/thrust/detail/pointer.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/range/head_flags.h" "$(@D)/cuda/include/thrust/detail/range/head_flags.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/range/tail_flags.h" "$(@D)/cuda/include/thrust/detail/range/tail_flags.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/raw_pointer_cast.h" "$(@D)/cuda/include/thrust/detail/raw_pointer_cast.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/raw_reference_cast.h" "$(@D)/cuda/include/thrust/detail/raw_reference_cast.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/reduce.inl" "$(@D)/cuda/include/thrust/detail/reduce.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/reference.h" "$(@D)/cuda/include/thrust/detail/reference.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/reference.inl" 
"$(@D)/cuda/include/thrust/detail/reference.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/reference_forward_declaration.h" "$(@D)/cuda/include/thrust/detail/reference_forward_declaration.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/remove.inl" "$(@D)/cuda/include/thrust/detail/remove.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/replace.inl" "$(@D)/cuda/include/thrust/detail/replace.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/reverse.inl" "$(@D)/cuda/include/thrust/detail/reverse.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/scan.inl" "$(@D)/cuda/include/thrust/detail/scan.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/scatter.inl" "$(@D)/cuda/include/thrust/detail/scatter.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/seq.h" "$(@D)/cuda/include/thrust/detail/seq.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/sequence.inl" "$(@D)/cuda/include/thrust/detail/sequence.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/set_operations.inl" "$(@D)/cuda/include/thrust/detail/set_operations.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/sort.inl" "$(@D)/cuda/include/thrust/detail/sort.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/static_assert.h" "$(@D)/cuda/include/thrust/detail/static_assert.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/static_map.h" "$(@D)/cuda/include/thrust/detail/static_map.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/swap.h" "$(@D)/cuda/include/thrust/detail/swap.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/swap.inl" "$(@D)/cuda/include/thrust/detail/swap.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/swap_ranges.inl" "$(@D)/cuda/include/thrust/detail/swap_ranges.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/tabulate.inl" "$(@D)/cuda/include/thrust/detail/tabulate.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/temporary_array.h" "$(@D)/cuda/include/thrust/detail/temporary_array.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/temporary_array.inl" "$(@D)/cuda/include/thrust/detail/temporary_array.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/temporary_buffer.h" "$(@D)/cuda/include/thrust/detail/temporary_buffer.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/transform.inl" "$(@D)/cuda/include/thrust/detail/transform.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/transform_reduce.inl" "$(@D)/cuda/include/thrust/detail/transform_reduce.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/transform_scan.inl" "$(@D)/cuda/include/thrust/detail/transform_scan.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/trivial_sequence.h" "$(@D)/cuda/include/thrust/detail/trivial_sequence.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/tuple.inl" "$(@D)/cuda/include/thrust/detail/tuple.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/tuple_meta_transform.h" "$(@D)/cuda/include/thrust/detail/tuple_meta_transform.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/tuple_transform.h" "$(@D)/cuda/include/thrust/detail/tuple_transform.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/type_traits.h" "$(@D)/cuda/include/thrust/detail/type_traits.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/type_traits/algorithm/intermediate_type_from_function_and_iterators.h" "$(@D)/cuda/include/thrust/detail/type_traits/algorithm/intermediate_type_from_function_and_iterators.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/type_traits/function_traits.h" "$(@D)/cuda/include/thrust/detail/type_traits/function_traits.h" && cp 
"/usr/local/cuda-9.0/include/thrust/detail/type_traits/has_member_function.h" "$(@D)/cuda/include/thrust/detail/type_traits/has_member_function.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/type_traits/has_nested_type.h" "$(@D)/cuda/include/thrust/detail/type_traits/has_nested_type.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/type_traits/has_trivial_assign.h" "$(@D)/cuda/include/thrust/detail/type_traits/has_trivial_assign.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/type_traits/is_call_possible.h" "$(@D)/cuda/include/thrust/detail/type_traits/is_call_possible.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/type_traits/is_metafunction_defined.h" "$(@D)/cuda/include/thrust/detail/type_traits/is_metafunction_defined.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/type_traits/iterator/is_discard_iterator.h" "$(@D)/cuda/include/thrust/detail/type_traits/iterator/is_discard_iterator.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/type_traits/iterator/is_output_iterator.h" "$(@D)/cuda/include/thrust/detail/type_traits/iterator/is_output_iterator.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/type_traits/minimum_type.h" "$(@D)/cuda/include/thrust/detail/type_traits/minimum_type.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/type_traits/pointer_traits.h" "$(@D)/cuda/include/thrust/detail/type_traits/pointer_traits.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/type_traits/result_of_adaptable_function.h" "$(@D)/cuda/include/thrust/detail/type_traits/result_of_adaptable_function.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/uninitialized_copy.inl" "$(@D)/cuda/include/thrust/detail/uninitialized_copy.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/uninitialized_fill.inl" "$(@D)/cuda/include/thrust/detail/uninitialized_fill.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/unique.inl" "$(@D)/cuda/include/thrust/detail/unique.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/use_default.h" "$(@D)/cuda/include/thrust/detail/use_default.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/util/align.h" "$(@D)/cuda/include/thrust/detail/util/align.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/util/blocking.h" "$(@D)/cuda/include/thrust/detail/util/blocking.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/vector_base.h" "$(@D)/cuda/include/thrust/detail/vector_base.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/vector_base.inl" "$(@D)/cuda/include/thrust/detail/vector_base.inl" && cp "/usr/local/cuda-9.0/include/thrust/device_allocator.h" "$(@D)/cuda/include/thrust/device_allocator.h" && cp "/usr/local/cuda-9.0/include/thrust/device_delete.h" "$(@D)/cuda/include/thrust/device_delete.h" && cp "/usr/local/cuda-9.0/include/thrust/device_free.h" "$(@D)/cuda/include/thrust/device_free.h" && cp "/usr/local/cuda-9.0/include/thrust/device_malloc.h" "$(@D)/cuda/include/thrust/device_malloc.h" && cp "/usr/local/cuda-9.0/include/thrust/device_malloc_allocator.h" "$(@D)/cuda/include/thrust/device_malloc_allocator.h" && cp "/usr/local/cuda-9.0/include/thrust/device_new.h" "$(@D)/cuda/include/thrust/device_new.h" && cp "/usr/local/cuda-9.0/include/thrust/device_new_allocator.h" "$(@D)/cuda/include/thrust/device_new_allocator.h" && cp "/usr/local/cuda-9.0/include/thrust/device_ptr.h" "$(@D)/cuda/include/thrust/device_ptr.h" && cp "/usr/local/cuda-9.0/include/thrust/device_reference.h" "$(@D)/cuda/include/thrust/device_reference.h" && cp "/usr/local/cuda-9.0/include/thrust/device_vector.h" "$(@D)/cuda/include/thrust/device_vector.h" && cp 
"/usr/local/cuda-9.0/include/thrust/distance.h" "$(@D)/cuda/include/thrust/distance.h" && cp "/usr/local/cuda-9.0/include/thrust/equal.h" "$(@D)/cuda/include/thrust/equal.h" && cp "/usr/local/cuda-9.0/include/thrust/execution_policy.h" "$(@D)/cuda/include/thrust/execution_policy.h" && cp "/usr/local/cuda-9.0/include/thrust/extrema.h" "$(@D)/cuda/include/thrust/extrema.h" && cp "/usr/local/cuda-9.0/include/thrust/fill.h" "$(@D)/cuda/include/thrust/fill.h" && cp "/usr/local/cuda-9.0/include/thrust/find.h" "$(@D)/cuda/include/thrust/find.h" && cp "/usr/local/cuda-9.0/include/thrust/for_each.h" "$(@D)/cuda/include/thrust/for_each.h" && cp "/usr/local/cuda-9.0/include/thrust/functional.h" "$(@D)/cuda/include/thrust/functional.h" && cp "/usr/local/cuda-9.0/include/thrust/gather.h" "$(@D)/cuda/include/thrust/gather.h" && cp "/usr/local/cuda-9.0/include/thrust/generate.h" "$(@D)/cuda/include/thrust/generate.h" && cp "/usr/local/cuda-9.0/include/thrust/host_vector.h" "$(@D)/cuda/include/thrust/host_vector.h" && cp "/usr/local/cuda-9.0/include/thrust/inner_product.h" "$(@D)/cuda/include/thrust/inner_product.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/constant_iterator.h" "$(@D)/cuda/include/thrust/iterator/constant_iterator.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/counting_iterator.h" "$(@D)/cuda/include/thrust/iterator/counting_iterator.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/any_assign.h" "$(@D)/cuda/include/thrust/iterator/detail/any_assign.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/any_system_tag.h" "$(@D)/cuda/include/thrust/iterator/detail/any_system_tag.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/constant_iterator_base.h" "$(@D)/cuda/include/thrust/iterator/detail/constant_iterator_base.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/counting_iterator.inl" "$(@D)/cuda/include/thrust/iterator/detail/counting_iterator.inl" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/device_system_tag.h" "$(@D)/cuda/include/thrust/iterator/detail/device_system_tag.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/discard_iterator_base.h" "$(@D)/cuda/include/thrust/iterator/detail/discard_iterator_base.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/distance_from_result.h" "$(@D)/cuda/include/thrust/iterator/detail/distance_from_result.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/host_system_tag.h" "$(@D)/cuda/include/thrust/iterator/detail/host_system_tag.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/is_iterator_category.h" "$(@D)/cuda/include/thrust/iterator/detail/is_iterator_category.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/is_trivial_iterator.h" "$(@D)/cuda/include/thrust/iterator/detail/is_trivial_iterator.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/iterator_adaptor_base.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_adaptor_base.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/iterator_category_to_system.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_category_to_system.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/iterator_category_to_traversal.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_category_to_traversal.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/iterator_category_with_system_and_traversal.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_category_with_system_and_traversal.h" && cp 
"/usr/local/cuda-9.0/include/thrust/iterator/detail/iterator_facade_category.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_facade_category.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/iterator_traits.inl" "$(@D)/cuda/include/thrust/iterator/detail/iterator_traits.inl" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/iterator_traversal_tags.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_traversal_tags.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/join_iterator.h" "$(@D)/cuda/include/thrust/iterator/detail/join_iterator.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/minimum_category.h" "$(@D)/cuda/include/thrust/iterator/detail/minimum_category.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/minimum_system.h" "$(@D)/cuda/include/thrust/iterator/detail/minimum_system.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/normal_iterator.h" "$(@D)/cuda/include/thrust/iterator/detail/normal_iterator.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/permutation_iterator_base.h" "$(@D)/cuda/include/thrust/iterator/detail/permutation_iterator_base.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/retag.h" "$(@D)/cuda/include/thrust/iterator/detail/retag.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/reverse_iterator.inl" "$(@D)/cuda/include/thrust/iterator/detail/reverse_iterator.inl" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/reverse_iterator_base.h" "$(@D)/cuda/include/thrust/iterator/detail/reverse_iterator_base.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/tagged_iterator.h" "$(@D)/cuda/include/thrust/iterator/detail/tagged_iterator.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/transform_iterator.inl" "$(@D)/cuda/include/thrust/iterator/detail/transform_iterator.inl" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/transform_output_iterator.inl" "$(@D)/cuda/include/thrust/iterator/detail/transform_output_iterator.inl" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/tuple_of_iterator_references.h" "$(@D)/cuda/include/thrust/iterator/detail/tuple_of_iterator_references.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/universal_categories.h" "$(@D)/cuda/include/thrust/iterator/detail/universal_categories.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/zip_iterator.inl" "$(@D)/cuda/include/thrust/iterator/detail/zip_iterator.inl" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/zip_iterator_base.h" "$(@D)/cuda/include/thrust/iterator/detail/zip_iterator_base.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/discard_iterator.h" "$(@D)/cuda/include/thrust/iterator/discard_iterator.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/iterator_adaptor.h" "$(@D)/cuda/include/thrust/iterator/iterator_adaptor.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/iterator_categories.h" "$(@D)/cuda/include/thrust/iterator/iterator_categories.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/iterator_facade.h" "$(@D)/cuda/include/thrust/iterator/iterator_facade.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/iterator_traits.h" "$(@D)/cuda/include/thrust/iterator/iterator_traits.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/permutation_iterator.h" "$(@D)/cuda/include/thrust/iterator/permutation_iterator.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/retag.h" "$(@D)/cuda/include/thrust/iterator/retag.h" && cp 
"/usr/local/cuda-9.0/include/thrust/iterator/reverse_iterator.h" "$(@D)/cuda/include/thrust/iterator/reverse_iterator.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/transform_iterator.h" "$(@D)/cuda/include/thrust/iterator/transform_iterator.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/transform_output_iterator.h" "$(@D)/cuda/include/thrust/iterator/transform_output_iterator.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/zip_iterator.h" "$(@D)/cuda/include/thrust/iterator/zip_iterator.h" && cp "/usr/local/cuda-9.0/include/thrust/logical.h" "$(@D)/cuda/include/thrust/logical.h" && cp "/usr/local/cuda-9.0/include/thrust/memory.h" "$(@D)/cuda/include/thrust/memory.h" && cp "/usr/local/cuda-9.0/include/thrust/merge.h" "$(@D)/cuda/include/thrust/merge.h" && cp "/usr/local/cuda-9.0/include/thrust/mismatch.h" "$(@D)/cuda/include/thrust/mismatch.h" && cp "/usr/local/cuda-9.0/include/thrust/pair.h" "$(@D)/cuda/include/thrust/pair.h" && cp "/usr/local/cuda-9.0/include/thrust/partition.h" "$(@D)/cuda/include/thrust/partition.h" && cp "/usr/local/cuda-9.0/include/thrust/random.h" "$(@D)/cuda/include/thrust/random.h" && cp "/usr/local/cuda-9.0/include/thrust/random/detail/discard_block_engine.inl" "$(@D)/cuda/include/thrust/random/detail/discard_block_engine.inl" && cp "/usr/local/cuda-9.0/include/thrust/random/detail/linear_congruential_engine.inl" "$(@D)/cuda/include/thrust/random/detail/linear_congruential_engine.inl" && cp "/usr/local/cuda-9.0/include/thrust/random/detail/linear_congruential_engine_discard.h" "$(@D)/cuda/include/thrust/random/detail/linear_congruential_engine_discard.h" && cp "/usr/local/cuda-9.0/include/thrust/random/detail/linear_feedback_shift_engine.inl" "$(@D)/cuda/include/thrust/random/detail/linear_feedback_shift_engine.inl" && cp "/usr/local/cuda-9.0/include/thrust/random/detail/linear_feedback_shift_engine_wordmask.h" "$(@D)/cuda/include/thrust/random/detail/linear_feedback_shift_engine_wordmask.h" && cp "/usr/local/cuda-9.0/include/thrust/random/detail/mod.h" "$(@D)/cuda/include/thrust/random/detail/mod.h" && cp "/usr/local/cuda-9.0/include/thrust/random/detail/normal_distribution.inl" "$(@D)/cuda/include/thrust/random/detail/normal_distribution.inl" && cp "/usr/local/cuda-9.0/include/thrust/random/detail/normal_distribution_base.h" "$(@D)/cuda/include/thrust/random/detail/normal_distribution_base.h" && cp "/usr/local/cuda-9.0/include/thrust/random/detail/random_core_access.h" "$(@D)/cuda/include/thrust/random/detail/random_core_access.h" && cp "/usr/local/cuda-9.0/include/thrust/random/detail/subtract_with_carry_engine.inl" "$(@D)/cuda/include/thrust/random/detail/subtract_with_carry_engine.inl" && cp "/usr/local/cuda-9.0/include/thrust/random/detail/uniform_int_distribution.inl" "$(@D)/cuda/include/thrust/random/detail/uniform_int_distribution.inl" && cp "/usr/local/cuda-9.0/include/thrust/random/detail/uniform_real_distribution.inl" "$(@D)/cuda/include/thrust/random/detail/uniform_real_distribution.inl" && cp "/usr/local/cuda-9.0/include/thrust/random/detail/xor_combine_engine.inl" "$(@D)/cuda/include/thrust/random/detail/xor_combine_engine.inl" && cp "/usr/local/cuda-9.0/include/thrust/random/detail/xor_combine_engine_max.h" "$(@D)/cuda/include/thrust/random/detail/xor_combine_engine_max.h" && cp "/usr/local/cuda-9.0/include/thrust/random/discard_block_engine.h" "$(@D)/cuda/include/thrust/random/discard_block_engine.h" && cp "/usr/local/cuda-9.0/include/thrust/random/linear_congruential_engine.h" 
"$(@D)/cuda/include/thrust/random/linear_congruential_engine.h" && cp "/usr/local/cuda-9.0/include/thrust/random/linear_feedback_shift_engine.h" "$(@D)/cuda/include/thrust/random/linear_feedback_shift_engine.h" && cp "/usr/local/cuda-9.0/include/thrust/random/normal_distribution.h" "$(@D)/cuda/include/thrust/random/normal_distribution.h" && cp "/usr/local/cuda-9.0/include/thrust/random/subtract_with_carry_engine.h" "$(@D)/cuda/include/thrust/random/subtract_with_carry_engine.h" && cp "/usr/local/cuda-9.0/include/thrust/random/uniform_int_distribution.h" "$(@D)/cuda/include/thrust/random/uniform_int_distribution.h" && cp "/usr/local/cuda-9.0/include/thrust/random/uniform_real_distribution.h" "$(@D)/cuda/include/thrust/random/uniform_real_distribution.h" && cp "/usr/local/cuda-9.0/include/thrust/random/xor_combine_engine.h" "$(@D)/cuda/include/thrust/random/xor_combine_engine.h" && cp "/usr/local/cuda-9.0/include/thrust/reduce.h" "$(@D)/cuda/include/thrust/reduce.h" && cp "/usr/local/cuda-9.0/include/thrust/remove.h" "$(@D)/cuda/include/thrust/remove.h" && cp "/usr/local/cuda-9.0/include/thrust/replace.h" "$(@D)/cuda/include/thrust/replace.h" && cp "/usr/local/cuda-9.0/include/thrust/reverse.h" "$(@D)/cuda/include/thrust/reverse.h" && cp "/usr/local/cuda-9.0/include/thrust/scan.h" "$(@D)/cuda/include/thrust/scan.h" && cp "/usr/local/cuda-9.0/include/thrust/scatter.h" "$(@D)/cuda/include/thrust/scatter.h" && cp "/usr/local/cuda-9.0/include/thrust/sequence.h" "$(@D)/cuda/include/thrust/sequence.h" && cp "/usr/local/cuda-9.0/include/thrust/set_operations.h" "$(@D)/cuda/include/thrust/set_operations.h" && cp "/usr/local/cuda-9.0/include/thrust/sort.h" "$(@D)/cuda/include/thrust/sort.h" && cp "/usr/local/cuda-9.0/include/thrust/swap.h" "$(@D)/cuda/include/thrust/swap.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/cpp/detail/adjacent_difference.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/assign_value.h" "$(@D)/cuda/include/thrust/system/cpp/detail/assign_value.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/binary_search.h" "$(@D)/cuda/include/thrust/system/cpp/detail/binary_search.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/copy.h" "$(@D)/cuda/include/thrust/system/cpp/detail/copy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/copy_if.h" "$(@D)/cuda/include/thrust/system/cpp/detail/copy_if.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/count.h" "$(@D)/cuda/include/thrust/system/cpp/detail/count.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/equal.h" "$(@D)/cuda/include/thrust/system/cpp/detail/equal.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/execution_policy.h" "$(@D)/cuda/include/thrust/system/cpp/detail/execution_policy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/extrema.h" "$(@D)/cuda/include/thrust/system/cpp/detail/extrema.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/fill.h" "$(@D)/cuda/include/thrust/system/cpp/detail/fill.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/find.h" "$(@D)/cuda/include/thrust/system/cpp/detail/find.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/for_each.h" "$(@D)/cuda/include/thrust/system/cpp/detail/for_each.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/gather.h" "$(@D)/cuda/include/thrust/system/cpp/detail/gather.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/generate.h" 
"$(@D)/cuda/include/thrust/system/cpp/detail/generate.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/get_value.h" "$(@D)/cuda/include/thrust/system/cpp/detail/get_value.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/inner_product.h" "$(@D)/cuda/include/thrust/system/cpp/detail/inner_product.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/iter_swap.h" "$(@D)/cuda/include/thrust/system/cpp/detail/iter_swap.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/logical.h" "$(@D)/cuda/include/thrust/system/cpp/detail/logical.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/malloc_and_free.h" "$(@D)/cuda/include/thrust/system/cpp/detail/malloc_and_free.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/memory.inl" "$(@D)/cuda/include/thrust/system/cpp/detail/memory.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/merge.h" "$(@D)/cuda/include/thrust/system/cpp/detail/merge.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/mismatch.h" "$(@D)/cuda/include/thrust/system/cpp/detail/mismatch.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/par.h" "$(@D)/cuda/include/thrust/system/cpp/detail/par.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/partition.h" "$(@D)/cuda/include/thrust/system/cpp/detail/partition.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/reduce.h" "$(@D)/cuda/include/thrust/system/cpp/detail/reduce.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/cpp/detail/reduce_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/remove.h" "$(@D)/cuda/include/thrust/system/cpp/detail/remove.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/replace.h" "$(@D)/cuda/include/thrust/system/cpp/detail/replace.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/reverse.h" "$(@D)/cuda/include/thrust/system/cpp/detail/reverse.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/scan.h" "$(@D)/cuda/include/thrust/system/cpp/detail/scan.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/scan_by_key.h" "$(@D)/cuda/include/thrust/system/cpp/detail/scan_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/scatter.h" "$(@D)/cuda/include/thrust/system/cpp/detail/scatter.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/sequence.h" "$(@D)/cuda/include/thrust/system/cpp/detail/sequence.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/set_operations.h" "$(@D)/cuda/include/thrust/system/cpp/detail/set_operations.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/sort.h" "$(@D)/cuda/include/thrust/system/cpp/detail/sort.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/swap_ranges.h" "$(@D)/cuda/include/thrust/system/cpp/detail/swap_ranges.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/tabulate.h" "$(@D)/cuda/include/thrust/system/cpp/detail/tabulate.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/cpp/detail/temporary_buffer.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/transform.h" "$(@D)/cuda/include/thrust/system/cpp/detail/transform.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/transform_reduce.h" "$(@D)/cuda/include/thrust/system/cpp/detail/transform_reduce.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/transform_scan.h" 
"$(@D)/cuda/include/thrust/system/cpp/detail/transform_scan.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/cpp/detail/uninitialized_copy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/cpp/detail/uninitialized_fill.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/unique.h" "$(@D)/cuda/include/thrust/system/cpp/detail/unique.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/unique_by_key.h" "$(@D)/cuda/include/thrust/system/cpp/detail/unique_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/vector.inl" "$(@D)/cuda/include/thrust/system/cpp/detail/vector.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/execution_policy.h" "$(@D)/cuda/include/thrust/system/cpp/execution_policy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/memory.h" "$(@D)/cuda/include/thrust/system/cpp/memory.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/vector.h" "$(@D)/cuda/include/thrust/system/cpp/vector.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/config.h" "$(@D)/cuda/include/thrust/system/cuda/config.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/cuda/detail/adjacent_difference.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/assign_value.h" "$(@D)/cuda/include/thrust/system/cuda/detail/assign_value.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/binary_search.h" "$(@D)/cuda/include/thrust/system/cuda/detail/binary_search.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/copy.h" "$(@D)/cuda/include/thrust/system/cuda/detail/copy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/copy_if.h" "$(@D)/cuda/include/thrust/system/cuda/detail/copy_if.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/core/agent_launcher.h" "$(@D)/cuda/include/thrust/system/cuda/detail/core/agent_launcher.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/core/alignment.h" "$(@D)/cuda/include/thrust/system/cuda/detail/core/alignment.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/core/triple_chevron_launch.h" "$(@D)/cuda/include/thrust/system/cuda/detail/core/triple_chevron_launch.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/core/util.h" "$(@D)/cuda/include/thrust/system/cuda/detail/core/util.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/count.h" "$(@D)/cuda/include/thrust/system/cuda/detail/count.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cross_system.h" "$(@D)/cuda/include/thrust/system/cuda/detail/cross_system.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/agent/agent_histogram.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/agent_histogram.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/agent/agent_radix_sort_downsweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/agent_radix_sort_downsweep.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/agent/agent_radix_sort_upsweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/agent_radix_sort_upsweep.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/agent/agent_reduce.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/agent_reduce.cuh" && cp 
"/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/agent/agent_reduce_by_key.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/agent_reduce_by_key.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/agent/agent_rle.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/agent_rle.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/agent/agent_scan.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/agent_scan.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/agent/agent_segment_fixup.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/agent_segment_fixup.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/agent/agent_select_if.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/agent_select_if.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/agent/agent_spmv_csrt.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/agent_spmv_csrt.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/agent/agent_spmv_orig.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/agent_spmv_orig.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/agent/agent_spmv_row_based.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/agent_spmv_row_based.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/agent/single_pass_scan_operators.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/single_pass_scan_operators.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/block_adjacent_difference.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_adjacent_difference.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/block_discontinuity.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_discontinuity.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/block_exchange.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_exchange.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/block_histogram.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_histogram.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/block_load.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_load.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/block_radix_rank.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_radix_rank.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/block_radix_sort.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_radix_sort.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/block_raking_layout.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_raking_layout.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/block_reduce.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_reduce.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/block_scan.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_scan.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/block_shuffle.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_shuffle.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/block_store.cuh" 
"$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_store.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_atomic.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_atomic.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_sort.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_sort.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking_commutative_only.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking_commutative_only.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_warp_reductions.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_warp_reductions.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_raking.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_raking.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans2.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans2.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans3.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans3.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/cub.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/cub.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/device_histogram.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_histogram.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/device_partition.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_partition.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/device_radix_sort.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_radix_sort.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/device_reduce.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_reduce.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/device_run_length_encode.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_run_length_encode.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/device_scan.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_scan.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/device_segmented_radix_sort.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_segmented_radix_sort.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/device_segmented_reduce.cuh" 
"$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_segmented_reduce.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/device_select.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_select.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/device_spmv.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_spmv.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_histogram.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_histogram.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_radix_sort.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_radix_sort.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_reduce.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_reduce.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_reduce_by_key.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_reduce_by_key.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_rle.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_rle.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_scan.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_scan.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_select_if.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_select_if.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_spmv_csrt.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_spmv_csrt.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_spmv_orig.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_spmv_orig.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_spmv_row_based.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_spmv_row_based.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/grid/grid_barrier.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/grid/grid_barrier.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/grid/grid_even_share.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/grid/grid_even_share.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/grid/grid_mapping.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/grid/grid_mapping.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/grid/grid_queue.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/grid/grid_queue.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/host/mutex.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/host/mutex.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/iterator/arg_index_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/arg_index_input_iterator.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/iterator/cache_modified_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/cache_modified_input_iterator.cuh" && cp 
"/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/iterator/cache_modified_output_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/cache_modified_output_iterator.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/iterator/constant_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/constant_input_iterator.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/iterator/counting_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/counting_input_iterator.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/iterator/discard_output_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/discard_output_iterator.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/iterator/tex_obj_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/tex_obj_input_iterator.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/iterator/tex_ref_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/tex_ref_input_iterator.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/iterator/transform_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/transform_input_iterator.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/thread/thread_load.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/thread/thread_load.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/thread/thread_operators.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/thread/thread_operators.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/thread/thread_reduce.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/thread/thread_reduce.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/thread/thread_scan.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/thread/thread_scan.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/thread/thread_search.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/thread/thread_search.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/thread/thread_store.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/thread/thread_store.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/util_allocator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_allocator.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/util_arch.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_arch.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/util_debug.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_debug.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/util_device.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_device.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/util_macro.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_macro.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/util_namespace.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_namespace.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/util_ptx.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_ptx.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/util_type.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_type.cuh" && cp 
"/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_shfl.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_shfl.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_smem.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_smem.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_shfl.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_shfl.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_smem.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_smem.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/warp/warp_reduce.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/warp_reduce.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/warp/warp_scan.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/warp_scan.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/equal.h" "$(@D)/cuda/include/thrust/system/cuda/detail/equal.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/error.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/error.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/execution_policy.h" "$(@D)/cuda/include/thrust/system/cuda/detail/execution_policy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/extrema.h" "$(@D)/cuda/include/thrust/system/cuda/detail/extrema.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/fill.h" "$(@D)/cuda/include/thrust/system/cuda/detail/fill.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/find.h" "$(@D)/cuda/include/thrust/system/cuda/detail/find.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/for_each.h" "$(@D)/cuda/include/thrust/system/cuda/detail/for_each.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/gather.h" "$(@D)/cuda/include/thrust/system/cuda/detail/gather.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/generate.h" "$(@D)/cuda/include/thrust/system/cuda/detail/generate.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/get_value.h" "$(@D)/cuda/include/thrust/system/cuda/detail/get_value.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/guarded_cuda_runtime_api.h" "$(@D)/cuda/include/thrust/system/cuda/detail/guarded_cuda_runtime_api.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/guarded_driver_types.h" "$(@D)/cuda/include/thrust/system/cuda/detail/guarded_driver_types.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/inner_product.h" "$(@D)/cuda/include/thrust/system/cuda/detail/inner_product.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/internal/copy_cross_system.h" "$(@D)/cuda/include/thrust/system/cuda/detail/internal/copy_cross_system.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/internal/copy_device_to_device.h" "$(@D)/cuda/include/thrust/system/cuda/detail/internal/copy_device_to_device.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/iter_swap.h" "$(@D)/cuda/include/thrust/system/cuda/detail/iter_swap.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/logical.h" "$(@D)/cuda/include/thrust/system/cuda/detail/logical.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/malloc_and_free.h" 
"$(@D)/cuda/include/thrust/system/cuda/detail/malloc_and_free.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/memory.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/memory.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/memory_buffer.h" "$(@D)/cuda/include/thrust/system/cuda/detail/memory_buffer.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/merge.h" "$(@D)/cuda/include/thrust/system/cuda/detail/merge.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/mismatch.h" "$(@D)/cuda/include/thrust/system/cuda/detail/mismatch.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/par.h" "$(@D)/cuda/include/thrust/system/cuda/detail/par.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/par_to_seq.h" "$(@D)/cuda/include/thrust/system/cuda/detail/par_to_seq.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/parallel_for.h" "$(@D)/cuda/include/thrust/system/cuda/detail/parallel_for.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/partition.h" "$(@D)/cuda/include/thrust/system/cuda/detail/partition.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/reduce.h" "$(@D)/cuda/include/thrust/system/cuda/detail/reduce.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/cuda/detail/reduce_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/remove.h" "$(@D)/cuda/include/thrust/system/cuda/detail/remove.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/replace.h" "$(@D)/cuda/include/thrust/system/cuda/detail/replace.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/reverse.h" "$(@D)/cuda/include/thrust/system/cuda/detail/reverse.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/scan.h" "$(@D)/cuda/include/thrust/system/cuda/detail/scan.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/scan_by_key.h" "$(@D)/cuda/include/thrust/system/cuda/detail/scan_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/scatter.h" "$(@D)/cuda/include/thrust/system/cuda/detail/scatter.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/sequence.h" "$(@D)/cuda/include/thrust/system/cuda/detail/sequence.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/set_operations.h" "$(@D)/cuda/include/thrust/system/cuda/detail/set_operations.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/sort.h" "$(@D)/cuda/include/thrust/system/cuda/detail/sort.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/swap_ranges.h" "$(@D)/cuda/include/thrust/system/cuda/detail/swap_ranges.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/tabulate.h" "$(@D)/cuda/include/thrust/system/cuda/detail/tabulate.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/cuda/detail/temporary_buffer.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/terminate.h" "$(@D)/cuda/include/thrust/system/cuda/detail/terminate.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/transform.h" "$(@D)/cuda/include/thrust/system/cuda/detail/transform.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/transform_reduce.h" "$(@D)/cuda/include/thrust/system/cuda/detail/transform_reduce.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/transform_scan.h" "$(@D)/cuda/include/thrust/system/cuda/detail/transform_scan.h" && cp 
"/usr/local/cuda-9.0/include/thrust/system/cuda/detail/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/cuda/detail/uninitialized_copy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/cuda/detail/uninitialized_fill.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/unique.h" "$(@D)/cuda/include/thrust/system/cuda/detail/unique.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/unique_by_key.h" "$(@D)/cuda/include/thrust/system/cuda/detail/unique_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/util.h" "$(@D)/cuda/include/thrust/system/cuda/detail/util.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/vector.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/vector.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/error.h" "$(@D)/cuda/include/thrust/system/cuda/error.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/execution_policy.h" "$(@D)/cuda/include/thrust/system/cuda/execution_policy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/experimental/pinned_allocator.h" "$(@D)/cuda/include/thrust/system/cuda/experimental/pinned_allocator.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/memory.h" "$(@D)/cuda/include/thrust/system/cuda/memory.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/vector.h" "$(@D)/cuda/include/thrust/system/cuda/vector.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/detail/adl/adjacent_difference.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/assign_value.h" "$(@D)/cuda/include/thrust/system/detail/adl/assign_value.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/binary_search.h" "$(@D)/cuda/include/thrust/system/detail/adl/binary_search.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/copy.h" "$(@D)/cuda/include/thrust/system/detail/adl/copy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/copy_if.h" "$(@D)/cuda/include/thrust/system/detail/adl/copy_if.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/count.h" "$(@D)/cuda/include/thrust/system/detail/adl/count.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/equal.h" "$(@D)/cuda/include/thrust/system/detail/adl/equal.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/extrema.h" "$(@D)/cuda/include/thrust/system/detail/adl/extrema.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/fill.h" "$(@D)/cuda/include/thrust/system/detail/adl/fill.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/find.h" "$(@D)/cuda/include/thrust/system/detail/adl/find.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/for_each.h" "$(@D)/cuda/include/thrust/system/detail/adl/for_each.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/gather.h" "$(@D)/cuda/include/thrust/system/detail/adl/gather.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/generate.h" "$(@D)/cuda/include/thrust/system/detail/adl/generate.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/get_value.h" "$(@D)/cuda/include/thrust/system/detail/adl/get_value.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/inner_product.h" "$(@D)/cuda/include/thrust/system/detail/adl/inner_product.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/iter_swap.h" "$(@D)/cuda/include/thrust/system/detail/adl/iter_swap.h" && cp 
"/usr/local/cuda-9.0/include/thrust/system/detail/adl/logical.h" "$(@D)/cuda/include/thrust/system/detail/adl/logical.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/malloc_and_free.h" "$(@D)/cuda/include/thrust/system/detail/adl/malloc_and_free.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/merge.h" "$(@D)/cuda/include/thrust/system/detail/adl/merge.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/mismatch.h" "$(@D)/cuda/include/thrust/system/detail/adl/mismatch.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/partition.h" "$(@D)/cuda/include/thrust/system/detail/adl/partition.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/reduce.h" "$(@D)/cuda/include/thrust/system/detail/adl/reduce.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/detail/adl/reduce_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/remove.h" "$(@D)/cuda/include/thrust/system/detail/adl/remove.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/replace.h" "$(@D)/cuda/include/thrust/system/detail/adl/replace.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/reverse.h" "$(@D)/cuda/include/thrust/system/detail/adl/reverse.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/scan.h" "$(@D)/cuda/include/thrust/system/detail/adl/scan.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/scan_by_key.h" "$(@D)/cuda/include/thrust/system/detail/adl/scan_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/scatter.h" "$(@D)/cuda/include/thrust/system/detail/adl/scatter.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/sequence.h" "$(@D)/cuda/include/thrust/system/detail/adl/sequence.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/set_operations.h" "$(@D)/cuda/include/thrust/system/detail/adl/set_operations.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/sort.h" "$(@D)/cuda/include/thrust/system/detail/adl/sort.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/swap_ranges.h" "$(@D)/cuda/include/thrust/system/detail/adl/swap_ranges.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/tabulate.h" "$(@D)/cuda/include/thrust/system/detail/adl/tabulate.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/detail/adl/temporary_buffer.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/transform.h" "$(@D)/cuda/include/thrust/system/detail/adl/transform.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/transform_reduce.h" "$(@D)/cuda/include/thrust/system/detail/adl/transform_reduce.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/transform_scan.h" "$(@D)/cuda/include/thrust/system/detail/adl/transform_scan.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/detail/adl/uninitialized_copy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/detail/adl/uninitialized_fill.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/unique.h" "$(@D)/cuda/include/thrust/system/detail/adl/unique.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/unique_by_key.h" "$(@D)/cuda/include/thrust/system/detail/adl/unique_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/bad_alloc.h" 
"$(@D)/cuda/include/thrust/system/detail/bad_alloc.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/errno.h" "$(@D)/cuda/include/thrust/system/detail/errno.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/error_category.inl" "$(@D)/cuda/include/thrust/system/detail/error_category.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/error_code.inl" "$(@D)/cuda/include/thrust/system/detail/error_code.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/error_condition.inl" "$(@D)/cuda/include/thrust/system/detail/error_condition.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/detail/generic/adjacent_difference.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/adjacent_difference.inl" "$(@D)/cuda/include/thrust/system/detail/generic/adjacent_difference.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/advance.h" "$(@D)/cuda/include/thrust/system/detail/generic/advance.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/advance.inl" "$(@D)/cuda/include/thrust/system/detail/generic/advance.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/binary_search.h" "$(@D)/cuda/include/thrust/system/detail/generic/binary_search.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/binary_search.inl" "$(@D)/cuda/include/thrust/system/detail/generic/binary_search.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/copy.h" "$(@D)/cuda/include/thrust/system/detail/generic/copy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/copy.inl" "$(@D)/cuda/include/thrust/system/detail/generic/copy.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/copy_if.h" "$(@D)/cuda/include/thrust/system/detail/generic/copy_if.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/copy_if.inl" "$(@D)/cuda/include/thrust/system/detail/generic/copy_if.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/count.h" "$(@D)/cuda/include/thrust/system/detail/generic/count.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/count.inl" "$(@D)/cuda/include/thrust/system/detail/generic/count.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/distance.h" "$(@D)/cuda/include/thrust/system/detail/generic/distance.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/distance.inl" "$(@D)/cuda/include/thrust/system/detail/generic/distance.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/equal.h" "$(@D)/cuda/include/thrust/system/detail/generic/equal.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/equal.inl" "$(@D)/cuda/include/thrust/system/detail/generic/equal.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/extrema.h" "$(@D)/cuda/include/thrust/system/detail/generic/extrema.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/extrema.inl" "$(@D)/cuda/include/thrust/system/detail/generic/extrema.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/fill.h" "$(@D)/cuda/include/thrust/system/detail/generic/fill.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/find.h" "$(@D)/cuda/include/thrust/system/detail/generic/find.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/find.inl" "$(@D)/cuda/include/thrust/system/detail/generic/find.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/for_each.h" 
"$(@D)/cuda/include/thrust/system/detail/generic/for_each.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/gather.h" "$(@D)/cuda/include/thrust/system/detail/generic/gather.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/gather.inl" "$(@D)/cuda/include/thrust/system/detail/generic/gather.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/generate.h" "$(@D)/cuda/include/thrust/system/detail/generic/generate.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/generate.inl" "$(@D)/cuda/include/thrust/system/detail/generic/generate.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/inner_product.h" "$(@D)/cuda/include/thrust/system/detail/generic/inner_product.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/inner_product.inl" "$(@D)/cuda/include/thrust/system/detail/generic/inner_product.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/logical.h" "$(@D)/cuda/include/thrust/system/detail/generic/logical.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/memory.h" "$(@D)/cuda/include/thrust/system/detail/generic/memory.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/memory.inl" "$(@D)/cuda/include/thrust/system/detail/generic/memory.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/merge.h" "$(@D)/cuda/include/thrust/system/detail/generic/merge.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/merge.inl" "$(@D)/cuda/include/thrust/system/detail/generic/merge.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/mismatch.h" "$(@D)/cuda/include/thrust/system/detail/generic/mismatch.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/mismatch.inl" "$(@D)/cuda/include/thrust/system/detail/generic/mismatch.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/partition.h" "$(@D)/cuda/include/thrust/system/detail/generic/partition.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/partition.inl" "$(@D)/cuda/include/thrust/system/detail/generic/partition.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/reduce.h" "$(@D)/cuda/include/thrust/system/detail/generic/reduce.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/reduce.inl" "$(@D)/cuda/include/thrust/system/detail/generic/reduce.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/detail/generic/reduce_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/reduce_by_key.inl" "$(@D)/cuda/include/thrust/system/detail/generic/reduce_by_key.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/remove.h" "$(@D)/cuda/include/thrust/system/detail/generic/remove.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/remove.inl" "$(@D)/cuda/include/thrust/system/detail/generic/remove.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/replace.h" "$(@D)/cuda/include/thrust/system/detail/generic/replace.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/replace.inl" "$(@D)/cuda/include/thrust/system/detail/generic/replace.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/reverse.h" "$(@D)/cuda/include/thrust/system/detail/generic/reverse.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/reverse.inl" "$(@D)/cuda/include/thrust/system/detail/generic/reverse.inl" && cp 
"/usr/local/cuda-9.0/include/thrust/system/detail/generic/scalar/binary_search.h" "$(@D)/cuda/include/thrust/system/detail/generic/scalar/binary_search.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/scalar/binary_search.inl" "$(@D)/cuda/include/thrust/system/detail/generic/scalar/binary_search.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/scan.h" "$(@D)/cuda/include/thrust/system/detail/generic/scan.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/scan.inl" "$(@D)/cuda/include/thrust/system/detail/generic/scan.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/scan_by_key.h" "$(@D)/cuda/include/thrust/system/detail/generic/scan_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/scan_by_key.inl" "$(@D)/cuda/include/thrust/system/detail/generic/scan_by_key.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/scatter.h" "$(@D)/cuda/include/thrust/system/detail/generic/scatter.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/scatter.inl" "$(@D)/cuda/include/thrust/system/detail/generic/scatter.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/select_system.h" "$(@D)/cuda/include/thrust/system/detail/generic/select_system.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/sequence.h" "$(@D)/cuda/include/thrust/system/detail/generic/sequence.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/sequence.inl" "$(@D)/cuda/include/thrust/system/detail/generic/sequence.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/set_operations.h" "$(@D)/cuda/include/thrust/system/detail/generic/set_operations.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/set_operations.inl" "$(@D)/cuda/include/thrust/system/detail/generic/set_operations.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/sort.h" "$(@D)/cuda/include/thrust/system/detail/generic/sort.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/sort.inl" "$(@D)/cuda/include/thrust/system/detail/generic/sort.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/swap_ranges.h" "$(@D)/cuda/include/thrust/system/detail/generic/swap_ranges.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/swap_ranges.inl" "$(@D)/cuda/include/thrust/system/detail/generic/swap_ranges.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/tabulate.h" "$(@D)/cuda/include/thrust/system/detail/generic/tabulate.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/tabulate.inl" "$(@D)/cuda/include/thrust/system/detail/generic/tabulate.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/tag.h" "$(@D)/cuda/include/thrust/system/detail/generic/tag.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/detail/generic/temporary_buffer.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/temporary_buffer.inl" "$(@D)/cuda/include/thrust/system/detail/generic/temporary_buffer.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/transform.h" "$(@D)/cuda/include/thrust/system/detail/generic/transform.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/transform.inl" "$(@D)/cuda/include/thrust/system/detail/generic/transform.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/transform_reduce.h" 
"$(@D)/cuda/include/thrust/system/detail/generic/transform_reduce.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/transform_reduce.inl" "$(@D)/cuda/include/thrust/system/detail/generic/transform_reduce.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/transform_scan.h" "$(@D)/cuda/include/thrust/system/detail/generic/transform_scan.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/transform_scan.inl" "$(@D)/cuda/include/thrust/system/detail/generic/transform_scan.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/type_traits.h" "$(@D)/cuda/include/thrust/system/detail/generic/type_traits.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/detail/generic/uninitialized_copy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/uninitialized_copy.inl" "$(@D)/cuda/include/thrust/system/detail/generic/uninitialized_copy.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/detail/generic/uninitialized_fill.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/uninitialized_fill.inl" "$(@D)/cuda/include/thrust/system/detail/generic/uninitialized_fill.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/unique.h" "$(@D)/cuda/include/thrust/system/detail/generic/unique.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/unique.inl" "$(@D)/cuda/include/thrust/system/detail/generic/unique.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/unique_by_key.h" "$(@D)/cuda/include/thrust/system/detail/generic/unique_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/unique_by_key.inl" "$(@D)/cuda/include/thrust/system/detail/generic/unique_by_key.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/internal/decompose.h" "$(@D)/cuda/include/thrust/system/detail/internal/decompose.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/detail/sequential/adjacent_difference.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/assign_value.h" "$(@D)/cuda/include/thrust/system/detail/sequential/assign_value.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/binary_search.h" "$(@D)/cuda/include/thrust/system/detail/sequential/binary_search.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/copy.h" "$(@D)/cuda/include/thrust/system/detail/sequential/copy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/copy.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/copy.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/copy_backward.h" "$(@D)/cuda/include/thrust/system/detail/sequential/copy_backward.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/copy_if.h" "$(@D)/cuda/include/thrust/system/detail/sequential/copy_if.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/count.h" "$(@D)/cuda/include/thrust/system/detail/sequential/count.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/equal.h" "$(@D)/cuda/include/thrust/system/detail/sequential/equal.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/execution_policy.h" "$(@D)/cuda/include/thrust/system/detail/sequential/execution_policy.h" && cp 
"/usr/local/cuda-9.0/include/thrust/system/detail/sequential/extrema.h" "$(@D)/cuda/include/thrust/system/detail/sequential/extrema.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/fill.h" "$(@D)/cuda/include/thrust/system/detail/sequential/fill.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/find.h" "$(@D)/cuda/include/thrust/system/detail/sequential/find.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/for_each.h" "$(@D)/cuda/include/thrust/system/detail/sequential/for_each.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/gather.h" "$(@D)/cuda/include/thrust/system/detail/sequential/gather.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/general_copy.h" "$(@D)/cuda/include/thrust/system/detail/sequential/general_copy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/generate.h" "$(@D)/cuda/include/thrust/system/detail/sequential/generate.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/get_value.h" "$(@D)/cuda/include/thrust/system/detail/sequential/get_value.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/inner_product.h" "$(@D)/cuda/include/thrust/system/detail/sequential/inner_product.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/insertion_sort.h" "$(@D)/cuda/include/thrust/system/detail/sequential/insertion_sort.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/iter_swap.h" "$(@D)/cuda/include/thrust/system/detail/sequential/iter_swap.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/logical.h" "$(@D)/cuda/include/thrust/system/detail/sequential/logical.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/malloc_and_free.h" "$(@D)/cuda/include/thrust/system/detail/sequential/malloc_and_free.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/merge.h" "$(@D)/cuda/include/thrust/system/detail/sequential/merge.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/merge.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/merge.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/mismatch.h" "$(@D)/cuda/include/thrust/system/detail/sequential/mismatch.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/partition.h" "$(@D)/cuda/include/thrust/system/detail/sequential/partition.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/reduce.h" "$(@D)/cuda/include/thrust/system/detail/sequential/reduce.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/detail/sequential/reduce_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/remove.h" "$(@D)/cuda/include/thrust/system/detail/sequential/remove.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/replace.h" "$(@D)/cuda/include/thrust/system/detail/sequential/replace.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/reverse.h" "$(@D)/cuda/include/thrust/system/detail/sequential/reverse.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/scan.h" "$(@D)/cuda/include/thrust/system/detail/sequential/scan.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/scan_by_key.h" "$(@D)/cuda/include/thrust/system/detail/sequential/scan_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/scatter.h" 
"$(@D)/cuda/include/thrust/system/detail/sequential/scatter.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/sequence.h" "$(@D)/cuda/include/thrust/system/detail/sequential/sequence.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/set_operations.h" "$(@D)/cuda/include/thrust/system/detail/sequential/set_operations.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/sort.h" "$(@D)/cuda/include/thrust/system/detail/sequential/sort.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/sort.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/sort.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/stable_merge_sort.h" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_merge_sort.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/stable_merge_sort.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_merge_sort.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/stable_primitive_sort.h" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_primitive_sort.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/stable_primitive_sort.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_primitive_sort.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/stable_radix_sort.h" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_radix_sort.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/stable_radix_sort.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_radix_sort.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/swap_ranges.h" "$(@D)/cuda/include/thrust/system/detail/sequential/swap_ranges.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/tabulate.h" "$(@D)/cuda/include/thrust/system/detail/sequential/tabulate.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/detail/sequential/temporary_buffer.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/transform.h" "$(@D)/cuda/include/thrust/system/detail/sequential/transform.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/transform_reduce.h" "$(@D)/cuda/include/thrust/system/detail/sequential/transform_reduce.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/transform_scan.h" "$(@D)/cuda/include/thrust/system/detail/sequential/transform_scan.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/trivial_copy.h" "$(@D)/cuda/include/thrust/system/detail/sequential/trivial_copy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/detail/sequential/uninitialized_copy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/detail/sequential/uninitialized_fill.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/unique.h" "$(@D)/cuda/include/thrust/system/detail/sequential/unique.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/unique_by_key.h" "$(@D)/cuda/include/thrust/system/detail/sequential/unique_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/system_error.inl" "$(@D)/cuda/include/thrust/system/detail/system_error.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/error_code.h" "$(@D)/cuda/include/thrust/system/error_code.h" && 
cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/omp/detail/adjacent_difference.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/assign_value.h" "$(@D)/cuda/include/thrust/system/omp/detail/assign_value.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/binary_search.h" "$(@D)/cuda/include/thrust/system/omp/detail/binary_search.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/copy.h" "$(@D)/cuda/include/thrust/system/omp/detail/copy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/copy.inl" "$(@D)/cuda/include/thrust/system/omp/detail/copy.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/copy_if.h" "$(@D)/cuda/include/thrust/system/omp/detail/copy_if.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/copy_if.inl" "$(@D)/cuda/include/thrust/system/omp/detail/copy_if.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/count.h" "$(@D)/cuda/include/thrust/system/omp/detail/count.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/default_decomposition.h" "$(@D)/cuda/include/thrust/system/omp/detail/default_decomposition.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/default_decomposition.inl" "$(@D)/cuda/include/thrust/system/omp/detail/default_decomposition.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/equal.h" "$(@D)/cuda/include/thrust/system/omp/detail/equal.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/execution_policy.h" "$(@D)/cuda/include/thrust/system/omp/detail/execution_policy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/extrema.h" "$(@D)/cuda/include/thrust/system/omp/detail/extrema.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/fill.h" "$(@D)/cuda/include/thrust/system/omp/detail/fill.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/find.h" "$(@D)/cuda/include/thrust/system/omp/detail/find.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/for_each.h" "$(@D)/cuda/include/thrust/system/omp/detail/for_each.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/for_each.inl" "$(@D)/cuda/include/thrust/system/omp/detail/for_each.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/gather.h" "$(@D)/cuda/include/thrust/system/omp/detail/gather.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/generate.h" "$(@D)/cuda/include/thrust/system/omp/detail/generate.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/get_value.h" "$(@D)/cuda/include/thrust/system/omp/detail/get_value.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/inner_product.h" "$(@D)/cuda/include/thrust/system/omp/detail/inner_product.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/iter_swap.h" "$(@D)/cuda/include/thrust/system/omp/detail/iter_swap.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/logical.h" "$(@D)/cuda/include/thrust/system/omp/detail/logical.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/malloc_and_free.h" "$(@D)/cuda/include/thrust/system/omp/detail/malloc_and_free.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/memory.inl" "$(@D)/cuda/include/thrust/system/omp/detail/memory.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/merge.h" "$(@D)/cuda/include/thrust/system/omp/detail/merge.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/mismatch.h" 
"$(@D)/cuda/include/thrust/system/omp/detail/mismatch.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/par.h" "$(@D)/cuda/include/thrust/system/omp/detail/par.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/partition.h" "$(@D)/cuda/include/thrust/system/omp/detail/partition.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/partition.inl" "$(@D)/cuda/include/thrust/system/omp/detail/partition.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/reduce.h" "$(@D)/cuda/include/thrust/system/omp/detail/reduce.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/reduce.inl" "$(@D)/cuda/include/thrust/system/omp/detail/reduce.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/omp/detail/reduce_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/reduce_by_key.inl" "$(@D)/cuda/include/thrust/system/omp/detail/reduce_by_key.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/reduce_intervals.h" "$(@D)/cuda/include/thrust/system/omp/detail/reduce_intervals.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/reduce_intervals.inl" "$(@D)/cuda/include/thrust/system/omp/detail/reduce_intervals.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/remove.h" "$(@D)/cuda/include/thrust/system/omp/detail/remove.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/remove.inl" "$(@D)/cuda/include/thrust/system/omp/detail/remove.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/replace.h" "$(@D)/cuda/include/thrust/system/omp/detail/replace.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/reverse.h" "$(@D)/cuda/include/thrust/system/omp/detail/reverse.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/scan.h" "$(@D)/cuda/include/thrust/system/omp/detail/scan.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/scan_by_key.h" "$(@D)/cuda/include/thrust/system/omp/detail/scan_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/scatter.h" "$(@D)/cuda/include/thrust/system/omp/detail/scatter.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/sequence.h" "$(@D)/cuda/include/thrust/system/omp/detail/sequence.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/set_operations.h" "$(@D)/cuda/include/thrust/system/omp/detail/set_operations.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/sort.h" "$(@D)/cuda/include/thrust/system/omp/detail/sort.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/sort.inl" "$(@D)/cuda/include/thrust/system/omp/detail/sort.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/swap_ranges.h" "$(@D)/cuda/include/thrust/system/omp/detail/swap_ranges.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/tabulate.h" "$(@D)/cuda/include/thrust/system/omp/detail/tabulate.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/omp/detail/temporary_buffer.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/transform.h" "$(@D)/cuda/include/thrust/system/omp/detail/transform.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/transform_reduce.h" "$(@D)/cuda/include/thrust/system/omp/detail/transform_reduce.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/transform_scan.h" "$(@D)/cuda/include/thrust/system/omp/detail/transform_scan.h" && cp 
"/usr/local/cuda-9.0/include/thrust/system/omp/detail/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/omp/detail/uninitialized_copy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/omp/detail/uninitialized_fill.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/unique.h" "$(@D)/cuda/include/thrust/system/omp/detail/unique.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/unique.inl" "$(@D)/cuda/include/thrust/system/omp/detail/unique.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/unique_by_key.h" "$(@D)/cuda/include/thrust/system/omp/detail/unique_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/unique_by_key.inl" "$(@D)/cuda/include/thrust/system/omp/detail/unique_by_key.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/vector.inl" "$(@D)/cuda/include/thrust/system/omp/detail/vector.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/execution_policy.h" "$(@D)/cuda/include/thrust/system/omp/execution_policy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/memory.h" "$(@D)/cuda/include/thrust/system/omp/memory.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/vector.h" "$(@D)/cuda/include/thrust/system/omp/vector.h" && cp "/usr/local/cuda-9.0/include/thrust/system/system_error.h" "$(@D)/cuda/include/thrust/system/system_error.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/tbb/detail/adjacent_difference.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/assign_value.h" "$(@D)/cuda/include/thrust/system/tbb/detail/assign_value.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/binary_search.h" "$(@D)/cuda/include/thrust/system/tbb/detail/binary_search.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/copy.h" "$(@D)/cuda/include/thrust/system/tbb/detail/copy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/copy.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/copy.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/copy_if.h" "$(@D)/cuda/include/thrust/system/tbb/detail/copy_if.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/copy_if.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/copy_if.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/count.h" "$(@D)/cuda/include/thrust/system/tbb/detail/count.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/equal.h" "$(@D)/cuda/include/thrust/system/tbb/detail/equal.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/execution_policy.h" "$(@D)/cuda/include/thrust/system/tbb/detail/execution_policy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/extrema.h" "$(@D)/cuda/include/thrust/system/tbb/detail/extrema.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/fill.h" "$(@D)/cuda/include/thrust/system/tbb/detail/fill.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/find.h" "$(@D)/cuda/include/thrust/system/tbb/detail/find.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/for_each.h" "$(@D)/cuda/include/thrust/system/tbb/detail/for_each.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/for_each.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/for_each.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/gather.h" "$(@D)/cuda/include/thrust/system/tbb/detail/gather.h" && cp 
"/usr/local/cuda-9.0/include/thrust/system/tbb/detail/generate.h" "$(@D)/cuda/include/thrust/system/tbb/detail/generate.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/get_value.h" "$(@D)/cuda/include/thrust/system/tbb/detail/get_value.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/inner_product.h" "$(@D)/cuda/include/thrust/system/tbb/detail/inner_product.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/iter_swap.h" "$(@D)/cuda/include/thrust/system/tbb/detail/iter_swap.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/logical.h" "$(@D)/cuda/include/thrust/system/tbb/detail/logical.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/malloc_and_free.h" "$(@D)/cuda/include/thrust/system/tbb/detail/malloc_and_free.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/memory.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/memory.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/merge.h" "$(@D)/cuda/include/thrust/system/tbb/detail/merge.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/merge.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/merge.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/mismatch.h" "$(@D)/cuda/include/thrust/system/tbb/detail/mismatch.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/par.h" "$(@D)/cuda/include/thrust/system/tbb/detail/par.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/partition.h" "$(@D)/cuda/include/thrust/system/tbb/detail/partition.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/partition.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/partition.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/reduce.h" "$(@D)/cuda/include/thrust/system/tbb/detail/reduce.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/reduce.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/reduce.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/tbb/detail/reduce_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/reduce_by_key.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/reduce_by_key.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/reduce_intervals.h" "$(@D)/cuda/include/thrust/system/tbb/detail/reduce_intervals.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/remove.h" "$(@D)/cuda/include/thrust/system/tbb/detail/remove.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/remove.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/remove.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/replace.h" "$(@D)/cuda/include/thrust/system/tbb/detail/replace.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/reverse.h" "$(@D)/cuda/include/thrust/system/tbb/detail/reverse.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/scan.h" "$(@D)/cuda/include/thrust/system/tbb/detail/scan.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/scan.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/scan.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/scan_by_key.h" "$(@D)/cuda/include/thrust/system/tbb/detail/scan_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/scatter.h" "$(@D)/cuda/include/thrust/system/tbb/detail/scatter.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/sequence.h" "$(@D)/cuda/include/thrust/system/tbb/detail/sequence.h" && cp 
"/usr/local/cuda-9.0/include/thrust/system/tbb/detail/set_operations.h" "$(@D)/cuda/include/thrust/system/tbb/detail/set_operations.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/sort.h" "$(@D)/cuda/include/thrust/system/tbb/detail/sort.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/sort.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/sort.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/swap_ranges.h" "$(@D)/cuda/include/thrust/system/tbb/detail/swap_ranges.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/tabulate.h" "$(@D)/cuda/include/thrust/system/tbb/detail/tabulate.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/tbb/detail/temporary_buffer.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/transform.h" "$(@D)/cuda/include/thrust/system/tbb/detail/transform.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/transform_reduce.h" "$(@D)/cuda/include/thrust/system/tbb/detail/transform_reduce.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/transform_scan.h" "$(@D)/cuda/include/thrust/system/tbb/detail/transform_scan.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/tbb/detail/uninitialized_copy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/tbb/detail/uninitialized_fill.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/unique.h" "$(@D)/cuda/include/thrust/system/tbb/detail/unique.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/unique.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/unique.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/unique_by_key.h" "$(@D)/cuda/include/thrust/system/tbb/detail/unique_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/unique_by_key.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/unique_by_key.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/vector.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/vector.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/execution_policy.h" "$(@D)/cuda/include/thrust/system/tbb/execution_policy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/memory.h" "$(@D)/cuda/include/thrust/system/tbb/memory.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/vector.h" "$(@D)/cuda/include/thrust/system/tbb/vector.h" && cp "/usr/local/cuda-9.0/include/thrust/system_error.h" "$(@D)/cuda/include/thrust/system_error.h" && cp "/usr/local/cuda-9.0/include/thrust/tabulate.h" "$(@D)/cuda/include/thrust/tabulate.h" && cp "/usr/local/cuda-9.0/include/thrust/transform.h" "$(@D)/cuda/include/thrust/transform.h" && cp "/usr/local/cuda-9.0/include/thrust/transform_reduce.h" "$(@D)/cuda/include/thrust/transform_reduce.h" && cp "/usr/local/cuda-9.0/include/thrust/transform_scan.h" "$(@D)/cuda/include/thrust/transform_scan.h" && cp "/usr/local/cuda-9.0/include/thrust/tuple.h" "$(@D)/cuda/include/thrust/tuple.h" && cp "/usr/local/cuda-9.0/include/thrust/uninitialized_copy.h" "$(@D)/cuda/include/thrust/uninitialized_copy.h" && cp "/usr/local/cuda-9.0/include/thrust/uninitialized_fill.h" "$(@D)/cuda/include/thrust/uninitialized_fill.h" && cp "/usr/local/cuda-9.0/include/thrust/unique.h" "$(@D)/cuda/include/thrust/unique.h" && cp "/usr/local/cuda-9.0/include/thrust/version.h" "$(@D)/cuda/include/thrust/version.h" && cp 
"/usr/local/cuda-9.0/include/vector_functions.h" "$(@D)/cuda/include/vector_functions.h" && cp "/usr/local/cuda-9.0/include/vector_functions.hpp" "$(@D)/cuda/include/vector_functions.hpp" && cp "/usr/local/cuda-9.0/include/vector_types.h" "$(@D)/cuda/include/vector_types.h"
+ """,
+)
+
+genrule(
+ name = "cuda-nvvm",
+ outs = [
+ "cuda/nvvm/libdevice/libdevice.10.bc",
+ ],
+ cmd = """
+if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp "/usr/local/cuda-9.0/nvvm/libdevice/libdevice.10.bc" "$(@D)//libdevice.10.bc"
+ """,
+)
+
+genrule(
+ name = "cuda-extras",
+ outs = [
+ "cuda/extras/CUPTI/include/GL/gl.h",
+ "cuda/extras/CUPTI/include/GL/glew.h",
+ "cuda/extras/CUPTI/include/GL/glext.h",
+ "cuda/extras/CUPTI/include/GL/glu.h",
+ "cuda/extras/CUPTI/include/GL/glut.h",
+ "cuda/extras/CUPTI/include/GL/glx.h",
+ "cuda/extras/CUPTI/include/GL/glxext.h",
+ "cuda/extras/CUPTI/include/GL/wglew.h",
+ "cuda/extras/CUPTI/include/GL/wglext.h",
+ "cuda/extras/CUPTI/include/cuda_stdint.h",
+ "cuda/extras/CUPTI/include/cupti.h",
+ "cuda/extras/CUPTI/include/cupti_activity.h",
+ "cuda/extras/CUPTI/include/cupti_callbacks.h",
+ "cuda/extras/CUPTI/include/cupti_driver_cbid.h",
+ "cuda/extras/CUPTI/include/cupti_events.h",
+ "cuda/extras/CUPTI/include/cupti_metrics.h",
+ "cuda/extras/CUPTI/include/cupti_nvtx_cbid.h",
+ "cuda/extras/CUPTI/include/cupti_result.h",
+ "cuda/extras/CUPTI/include/cupti_runtime_cbid.h",
+ "cuda/extras/CUPTI/include/cupti_version.h",
+ "cuda/extras/CUPTI/include/generated_cudaGL_meta.h",
+ "cuda/extras/CUPTI/include/generated_cudaVDPAU_meta.h",
+ "cuda/extras/CUPTI/include/generated_cuda_gl_interop_meta.h",
+ "cuda/extras/CUPTI/include/generated_cuda_meta.h",
+ "cuda/extras/CUPTI/include/generated_cuda_runtime_api_meta.h",
+ "cuda/extras/CUPTI/include/generated_cuda_vdpau_interop_meta.h",
+ "cuda/extras/CUPTI/include/generated_nvtx_meta.h",
+ "cuda/extras/CUPTI/include/openacc/cupti_openacc.h",
+ ],
+ cmd = """
+if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp "/usr/local/cuda-9.0/extras/CUPTI/include/GL/gl.h" "$(@D)/cuda/extras/CUPTI/include/GL/gl.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/GL/glew.h" "$(@D)/cuda/extras/CUPTI/include/GL/glew.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/GL/glext.h" "$(@D)/cuda/extras/CUPTI/include/GL/glext.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/GL/glu.h" "$(@D)/cuda/extras/CUPTI/include/GL/glu.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/GL/glut.h" "$(@D)/cuda/extras/CUPTI/include/GL/glut.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/GL/glx.h" "$(@D)/cuda/extras/CUPTI/include/GL/glx.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/GL/glxext.h" "$(@D)/cuda/extras/CUPTI/include/GL/glxext.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/GL/wglew.h" "$(@D)/cuda/extras/CUPTI/include/GL/wglew.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/GL/wglext.h" "$(@D)/cuda/extras/CUPTI/include/GL/wglext.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/cuda_stdint.h" "$(@D)/cuda/extras/CUPTI/include/cuda_stdint.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/cupti.h" "$(@D)/cuda/extras/CUPTI/include/cupti.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/cupti_activity.h" "$(@D)/cuda/extras/CUPTI/include/cupti_activity.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/cupti_callbacks.h" "$(@D)/cuda/extras/CUPTI/include/cupti_callbacks.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/cupti_driver_cbid.h" "$(@D)/cuda/extras/CUPTI/include/cupti_driver_cbid.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/cupti_events.h" "$(@D)/cuda/extras/CUPTI/include/cupti_events.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/cupti_metrics.h" "$(@D)/cuda/extras/CUPTI/include/cupti_metrics.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/cupti_nvtx_cbid.h" "$(@D)/cuda/extras/CUPTI/include/cupti_nvtx_cbid.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/cupti_result.h" "$(@D)/cuda/extras/CUPTI/include/cupti_result.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/cupti_runtime_cbid.h" "$(@D)/cuda/extras/CUPTI/include/cupti_runtime_cbid.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/cupti_version.h" "$(@D)/cuda/extras/CUPTI/include/cupti_version.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/generated_cudaGL_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cudaGL_meta.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/generated_cudaVDPAU_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cudaVDPAU_meta.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/generated_cuda_gl_interop_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cuda_gl_interop_meta.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/generated_cuda_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cuda_meta.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/generated_cuda_runtime_api_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cuda_runtime_api_meta.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/generated_cuda_vdpau_interop_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cuda_vdpau_interop_meta.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/generated_nvtx_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_nvtx_meta.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/openacc/cupti_openacc.h" "$(@D)/cuda/extras/CUPTI/include/openacc/cupti_openacc.h"
+ """,
+)
+
+genrule(
+ name = "cuda-lib",
+ outs = [
+ "cuda/lib/libcuda.so",
+ "cuda/lib/libcudart.so.9.0",
+ "cuda/lib/libcudart_static.a",
+ "cuda/lib/libcublas.so.9.0",
+ "cuda/lib/libcusolver.so.9.0",
+ "cuda/lib/libcurand.so.9.0",
+ "cuda/lib/libcufft.so.9.0",
+ "cuda/lib/libcudnn.so.7",
+ "cuda/lib/libcupti.so.9.0",
+ ],
+ cmd = """
+if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/stubs/libcuda.so" "$(@D)/cuda/lib/libcuda.so" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudart.so.9.0.176" "$(@D)/cuda/lib/libcudart.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudart_static.a" "$(@D)/cuda/lib/libcudart_static.a" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcublas.so.9.0.480" "$(@D)/cuda/lib/libcublas.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcusolver.so.9.0.176" "$(@D)/cuda/lib/libcusolver.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcurand.so.9.0.176" "$(@D)/cuda/lib/libcurand.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcufft.so.9.0.176" "$(@D)/cuda/lib/libcufft.so.9.0" && cp "/usr/lib/x86_64-linux-gnu/libcudnn.so.7.2.1" "$(@D)/cuda/lib/libcudnn.so.7" && cp "/usr/local/cuda-9.0/extras/CUPTI/lib64/libcupti.so.9.0.176" "$(@D)/cuda/lib/libcupti.so.9.0"
+ """,
+)
+
+genrule(
+ name = "cudnn-include",
+ outs = [
+ "cuda/include/cudnn.h",
+ ],
+ cmd = """
+if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp "/usr/include/cudnn.h" "$(@D)/cudnn.h"
+ """,
+)
diff --git a/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/build_defs.bzl b/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/build_defs.bzl
new file mode 100755
index 0000000000..5c6703aab4
--- /dev/null
+++ b/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/build_defs.bzl
@@ -0,0 +1,33 @@
+# Macros for building CUDA code.
+def if_cuda(if_true, if_false = []):
+ """Shorthand for select()'ing on whether we're building with CUDA.
+
+ Returns a select statement which evaluates to if_true if we're building
+ with CUDA enabled. Otherwise, the select statement evaluates to if_false.
+
+ """
+ return select({
+ "@local_config_cuda//cuda:using_nvcc": if_true,
+ "@local_config_cuda//cuda:using_clang": if_true,
+ "//conditions:default": if_false
+ })
+
+
+def cuda_default_copts():
+ """Default options for all CUDA compilations."""
+ return if_cuda(["-x", "cuda", "-DGOOGLE_CUDA=1"] + [])
+
+
+def cuda_is_configured():
+ """Returns true if CUDA was enabled during the configure process."""
+ return True
+
+def if_cuda_is_configured(x):
+ """Tests if the CUDA was enabled during the configure process.
+
+ Unlike if_cuda(), this does not require that we are building with
+ --config=cuda. Used to allow non-CUDA code to depend on CUDA libraries.
+ """
+ if cuda_is_configured():
+ return x
+ return []
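
For context, a minimal sketch of how the macros defined in this build_defs.bzl are typically consumed from a BUILD file; the target and source file names below are hypothetical, and the load label assumes the file is exposed as @local_config_cuda//cuda:build_defs.bzl:

load("@local_config_cuda//cuda:build_defs.bzl", "cuda_default_copts", "if_cuda", "if_cuda_is_configured")

cc_library(
    name = "gpu_kernels",  # hypothetical target
    # GPU-only sources are compiled only when building with CUDA enabled.
    srcs = ["kernels.cc"] + if_cuda(["kernels_gpu.cu.cc"]),
    # Adds "-x cuda -DGOOGLE_CUDA=1" under CUDA builds, nothing otherwise.
    copts = cuda_default_copts(),
    # Unlike if_cuda(), this only checks that CUDA was configured, so the
    # dependency is allowed even without --config=cuda.
    deps = if_cuda_is_configured(["@local_config_cuda//cuda:cuda_headers"]),
)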
diff --git a/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/cuda/cuda_config.h b/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/cuda/cuda_config.h
new file mode 100755
index 0000000000..5d0d3013a9
--- /dev/null
+++ b/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/cuda/cuda_config.h
@@ -0,0 +1,26 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef CUDA_CUDA_CONFIG_H_
+#define CUDA_CUDA_CONFIG_H_
+
+#define TF_CUDA_CAPABILITIES CudaVersion("3.0")
+
+#define TF_CUDA_VERSION "9.0"
+#define TF_CUDNN_VERSION "7"
+
+#define TF_CUDA_TOOLKIT_PATH "/usr/local/cuda-9.0"
+
+#endif // CUDA_CUDA_CONFIG_H_
diff --git a/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/BUILD b/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/BUILD
new file mode 100755
index 0000000000..a56b4513fb
--- /dev/null
+++ b/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/BUILD
@@ -0,0 +1,73 @@
+licenses(["restricted"])
+
+package(default_visibility = ["//visibility:public"])
+
+cc_toolchain_suite(
+ name = "toolchain",
+ toolchains = {
+ "local|compiler": ":cc-compiler-local",
+ "darwin|compiler": ":cc-compiler-darwin",
+ "x64_windows|msvc-cl": ":cc-compiler-windows",
+ },
+)
+
+cc_toolchain(
+ name = "cc-compiler-local",
+ all_files = ":crosstool_wrapper_driver_is_not_gcc",
+ compiler_files = ":empty",
+ cpu = "local",
+ dwp_files = ":empty",
+ dynamic_runtime_libs = [":empty"],
+ linker_files = ":crosstool_wrapper_driver_is_not_gcc",
+ objcopy_files = ":empty",
+ static_runtime_libs = [":empty"],
+ strip_files = ":empty",
+ # To support linker flags that need to go to the start of command line
+ # we need the toolchain to support parameter files. Parameter files are
+ # last on the command line and contain all shared libraries to link, so all
+ # regular options will be left of them.
+ supports_param_files = 1,
+)
+
+cc_toolchain(
+ name = "cc-compiler-darwin",
+ all_files = ":crosstool_wrapper_driver_is_not_gcc",
+ compiler_files = ":empty",
+ cpu = "darwin",
+ dwp_files = ":empty",
+ dynamic_runtime_libs = [":empty"],
+ linker_files = ":crosstool_wrapper_driver_is_not_gcc",
+ objcopy_files = ":empty",
+ static_runtime_libs = [":empty"],
+ strip_files = ":empty",
+ supports_param_files = 0,
+)
+
+cc_toolchain(
+ name = "cc-compiler-windows",
+ all_files = ":windows_msvc_wrapper_files",
+ compiler_files = ":empty",
+ cpu = "x64_windows",
+ dwp_files = ":empty",
+ dynamic_runtime_libs = [":empty"],
+ linker_files = ":windows_msvc_wrapper_files",
+ objcopy_files = ":empty",
+ static_runtime_libs = [":empty"],
+ strip_files = ":empty",
+ supports_param_files = 1,
+)
+
+filegroup(
+ name = "empty",
+ srcs = [],
+)
+
+filegroup(
+ name = "crosstool_wrapper_driver_is_not_gcc",
+ srcs = ["clang/bin/crosstool_wrapper_driver_is_not_gcc"],
+)
+
+filegroup(
+ name = "windows_msvc_wrapper_files",
+ srcs = glob(["windows/msvc_*"]),
+)
diff --git a/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/CROSSTOOL b/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/CROSSTOOL
new file mode 100755
index 0000000000..a14eceacbb
--- /dev/null
+++ b/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/CROSSTOOL
@@ -0,0 +1,1410 @@
+major_version: "local"
+minor_version: ""
+default_target_cpu: "same_as_host"
+
+default_toolchain {
+ cpu: "k8"
+ toolchain_identifier: "local_linux"
+}
+default_toolchain {
+ cpu: "piii"
+ toolchain_identifier: "local_linux"
+}
+default_toolchain {
+ cpu: "arm"
+ toolchain_identifier: "local_linux"
+}
+default_toolchain {
+ cpu: "darwin"
+ toolchain_identifier: "local_darwin"
+}
+default_toolchain {
+ cpu: "ppc"
+ toolchain_identifier: "local_linux"
+}
+default_toolchain {
+ cpu: "x64_windows"
+ toolchain_identifier: "local_windows"
+}
+
+toolchain {
+ abi_version: "local"
+ abi_libc_version: "local"
+ compiler: "compiler"
+ host_system_name: "local"
+ needsPic: true
+ target_libc: "local"
+ target_cpu: "local"
+ target_system_name: "local"
+ toolchain_identifier: "local_linux"
+
+ feature {
+ name: "c++11"
+ flag_set {
+ action: "c++-compile"
+ flag_group {
+ flag: "-std=c++11"
+ }
+ }
+ }
+
+ feature {
+ name: "stdlib"
+ flag_set {
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "-lstdc++"
+ }
+ }
+ }
+
+ feature {
+ name: "determinism"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ # Make C++ compilation deterministic. Use linkstamping instead of these
+ # compiler symbols.
+ flag: "-Wno-builtin-macro-redefined"
+ flag: "-D__DATE__=\"redacted\""
+ flag: "-D__TIMESTAMP__=\"redacted\""
+ flag: "-D__TIME__=\"redacted\""
+ }
+ }
+ }
+
+ feature {
+ name: "alwayslink"
+ flag_set {
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ action: "c++-link-executable"
+ flag_group {
+ flag: "-Wl,-no-as-needed"
+ }
+ }
+ }
+
+ # This feature will be enabled for builds that support pic by bazel.
+ feature {
+ name: "pic"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ expand_if_all_available: "pic"
+ flag: "-fPIC"
+ }
+ flag_group {
+ expand_if_none_available: "pic"
+ flag: "-fPIE"
+ }
+ }
+ }
+
+ # Security hardening on by default.
+ feature {
+ name: "hardening"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ # Conservative choice; -D_FORTIFY_SOURCE=2 may be unsafe in some cases.
+ # We need to undef it before redefining it as some distributions now
+ # have it enabled by default.
+ flag: "-U_FORTIFY_SOURCE"
+ flag: "-D_FORTIFY_SOURCE=1"
+ flag: "-fstack-protector"
+ }
+ }
+ flag_set {
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "-Wl,-z,relro,-z,now"
+ }
+ }
+ flag_set {
+ action: "c++-link-executable"
+ flag_group {
+ flag: "-pie"
+ flag: "-Wl,-z,relro,-z,now"
+ }
+ }
+ }
+
+ feature {
+ name: "warnings"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ # All warnings are enabled. Maybe enable -Werror as well?
+ flag: "-Wall"
+
+ }
+ }
+ }
+
+ # Keep stack frames for debugging, even in opt mode.
+ feature {
+ name: "frame-pointer"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ flag: "-fno-omit-frame-pointer"
+ }
+ }
+ }
+
+ feature {
+ name: "build-id"
+ flag_set {
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ # Stamp the binary with a unique identifier.
+ flag: "-Wl,--build-id=md5"
+ flag: "-Wl,--hash-style=gnu"
+ }
+ }
+ }
+
+ feature {
+ name: "no-canonical-prefixes"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag:"-no-canonical-prefixes"
+ }
+ }
+ }
+
+ feature {
+ name: "disable-assertions"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ flag: "-DNDEBUG"
+ }
+ }
+ }
+
+ feature {
+ name: "linker-bin-path"
+
+ flag_set {
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "-B/usr/bin"
+ }
+ }
+ }
+
+ feature {
+ name: "common"
+ implies: "stdlib"
+ implies: "c++11"
+ implies: "determinism"
+ implies: "alwayslink"
+ implies: "hardening"
+ implies: "warnings"
+ implies: "frame-pointer"
+ implies: "build-id"
+ implies: "no-canonical-prefixes"
+ implies: "linker-bin-path"
+ }
+
+ feature {
+ name: "opt"
+ implies: "common"
+ implies: "disable-assertions"
+
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ # No debug symbols.
+ # Maybe we should enable https://gcc.gnu.org/wiki/DebugFission for opt
+ # or even generally? However, that can't happen here, as it requires
+ # special handling in Bazel.
+ flag: "-g0"
+
+ # Conservative choice for -O
+ # -O3 can increase binary size and even slow down the resulting binaries.
+ # Profile first and / or use FDO if you need better performance than this.
+ flag: "-O2"
+
+ # Removal of unused code and data at link time (can this increase binary size in some cases?).
+ flag: "-ffunction-sections"
+ flag: "-fdata-sections"
+ }
+ }
+ flag_set {
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ action: "c++-link-executable"
+ flag_group {
+ flag: "-Wl,--gc-sections"
+ }
+ }
+ }
+
+ feature {
+ name: "fastbuild"
+ implies: "common"
+ }
+
+ feature {
+ name: "dbg"
+ implies: "common"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ flag: "-g"
+ }
+ }
+ }
+
+ # Set clang as a C/C++ compiler.
+ tool_path { name: "gcc" path: "clang/bin/crosstool_wrapper_driver_is_not_gcc" }
+
+ # Use the default system toolchain for everything else.
+ tool_path { name: "ar" path: "/usr/bin/ar" }
+ tool_path { name: "compat-ld" path: "/usr/bin/ld" }
+ tool_path { name: "cpp" path: "/usr/bin/cpp" }
+ tool_path { name: "dwp" path: "/usr/bin/dwp" }
+ tool_path { name: "gcov" path: "/usr/bin/gcov" }
+ tool_path { name: "ld" path: "/usr/bin/ld" }
+ tool_path { name: "nm" path: "/usr/bin/nm" }
+ tool_path { name: "objcopy" path: "/usr/bin/objcopy" }
+ tool_path { name: "objdump" path: "/usr/bin/objdump" }
+ tool_path { name: "strip" path: "/usr/bin/strip" }
+
+ # Enable dynamic linking.
+ linking_mode_flags { mode: DYNAMIC }
+
+
+ cxx_builtin_include_directory: "/"
+}
+
+toolchain {
+ abi_version: "local"
+ abi_libc_version: "local"
+ compiler: "compiler"
+ host_system_name: "local"
+ needsPic: true
+ target_libc: "macosx"
+ target_cpu: "darwin"
+ target_system_name: "local"
+ toolchain_identifier: "local_darwin"
+ feature {
+ name: "c++11"
+ flag_set {
+ action: "c++-compile"
+ flag_group {
+ flag: "-std=c++11"
+ }
+ }
+ }
+
+ feature {
+ name: "stdlib"
+ flag_set {
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "-lc++"
+ }
+ }
+ }
+
+ feature {
+ name: "determinism"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ # Make C++ compilation deterministic. Use linkstamping instead of these
+ # compiler symbols.
+ flag: "-Wno-builtin-macro-redefined"
+ flag: "-D__DATE__=\"redacted\""
+ flag: "-D__TIMESTAMP__=\"redacted\""
+ flag: "-D__TIME__=\"redacted\""
+ }
+ }
+ }
+
+ # This feature will be enabled for builds that support pic by bazel.
+ feature {
+ name: "pic"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ expand_if_all_available: "pic"
+ flag: "-fPIC"
+ }
+ flag_group {
+ expand_if_none_available: "pic"
+ flag: "-fPIE"
+ }
+ }
+ }
+
+ # Security hardening on by default.
+ feature {
+ name: "hardening"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ # Conservative choice; -D_FORTIFY_SOURCE=2 may be unsafe in some cases.
+ # We need to undef it before redefining it as some distributions now
+ # have it enabled by default.
+ flag: "-U_FORTIFY_SOURCE"
+ flag: "-D_FORTIFY_SOURCE=1"
+ flag: "-fstack-protector"
+ }
+ }
+ flag_set {
+ action: "c++-link-executable"
+ flag_group {
+ flag: "-pie"
+ }
+ }
+ }
+
+ feature {
+ name: "warnings"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ # All warnings are enabled. Maybe enable -Werror as well?
+ flag: "-Wall"
+
+ }
+ }
+ }
+
+ # Keep stack frames for debugging, even in opt mode.
+ feature {
+ name: "frame-pointer"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ flag: "-fno-omit-frame-pointer"
+ }
+ }
+ }
+
+ feature {
+ name: "no-canonical-prefixes"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag:"-no-canonical-prefixes"
+ }
+ }
+ }
+
+ feature {
+ name: "disable-assertions"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ flag: "-DNDEBUG"
+ }
+ }
+ }
+
+ feature {
+ name: "linker-bin-path"
+
+ flag_set {
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "-B/usr/bin"
+ }
+ }
+ }
+
+ feature {
+ name: "undefined-dynamic"
+ flag_set {
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ action: "c++-link-executable"
+ flag_group {
+ flag: "-undefined"
+ flag: "dynamic_lookup"
+ }
+ }
+ }
+
+ feature {
+ name: "common"
+ implies: "stdlib"
+ implies: "c++11"
+ implies: "determinism"
+ implies: "hardening"
+ implies: "warnings"
+ implies: "frame-pointer"
+ implies: "no-canonical-prefixes"
+ implies: "linker-bin-path"
+ implies: "undefined-dynamic"
+ }
+
+ feature {
+ name: "opt"
+ implies: "common"
+ implies: "disable-assertions"
+
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ # No debug symbols.
+ # Maybe we should enable https://gcc.gnu.org/wiki/DebugFission for opt
+ # or even generally? However, that can't happen here, as it requires
+ # special handling in Bazel.
+ flag: "-g0"
+
+ # Conservative choice for -O
+ # -O3 can increase binary size and even slow down the resulting binaries.
+ # Profile first and / or use FDO if you need better performance than this.
+ flag: "-O2"
+
+ # Removal of unused code and data at link time (can this increase binary size in some cases?).
+ flag: "-ffunction-sections"
+ flag: "-fdata-sections"
+ }
+ }
+ }
+
+ feature {
+ name: "fastbuild"
+ implies: "common"
+ }
+
+ feature {
+ name: "dbg"
+ implies: "common"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ flag: "-g"
+ }
+ }
+ }
+
+ # Set clang as a C/C++ compiler.
+ tool_path { name: "gcc" path: "clang/bin/crosstool_wrapper_driver_is_not_gcc" }
+
+ # Use the default system toolchain for everything else.
+ tool_path { name: "ar" path: "/usr/bin/libtool" }
+ tool_path { name: "compat-ld" path: "/usr/bin/ld" }
+ tool_path { name: "cpp" path: "/usr/bin/cpp" }
+ tool_path { name: "dwp" path: "/usr/bin/dwp" }
+ tool_path { name: "gcov" path: "/usr/bin/gcov" }
+ tool_path { name: "ld" path: "/usr/bin/ld" }
+ tool_path { name: "nm" path: "/usr/bin/nm" }
+ tool_path { name: "objcopy" path: "/usr/bin/objcopy" }
+ tool_path { name: "objdump" path: "/usr/bin/objdump" }
+ tool_path { name: "strip" path: "/usr/bin/strip" }
+
+ # Enable dynamic linking.
+ linking_mode_flags { mode: DYNAMIC }
+
+
+ cxx_builtin_include_directory: "/"
+}
+
+toolchain {
+ toolchain_identifier: "local_windows"
+ host_system_name: "local"
+ target_system_name: "local"
+
+ abi_version: "local"
+ abi_libc_version: "local"
+ target_cpu: "x64_windows"
+ compiler: "msvc-cl"
+ target_libc: "msvcrt"
+
+
+
+ tool_path {
+ name: "ar"
+ path: ""
+ }
+ tool_path {
+ name: "ml"
+ path: ""
+ }
+ tool_path {
+ name: "cpp"
+ path: ""
+ }
+ tool_path {
+ name: "gcc"
+ path: ""
+ }
+ tool_path {
+ name: "gcov"
+ path: "wrapper/bin/msvc_nop.bat"
+ }
+ tool_path {
+ name: "ld"
+ path: ""
+ }
+ tool_path {
+ name: "nm"
+ path: "wrapper/bin/msvc_nop.bat"
+ }
+ tool_path {
+ name: "objcopy"
+ path: "wrapper/bin/msvc_nop.bat"
+ }
+ tool_path {
+ name: "objdump"
+ path: "wrapper/bin/msvc_nop.bat"
+ }
+ tool_path {
+ name: "strip"
+ path: "wrapper/bin/msvc_nop.bat"
+ }
+ supports_interface_shared_objects: true
+
+ # TODO(pcloudy): Review those flags below, they should be defined by cl.exe
+ compiler_flag: "/DCOMPILER_MSVC"
+
+ # Don't define min/max macros in windows.h.
+ compiler_flag: "/DNOMINMAX"
+
+ # Platform defines.
+ compiler_flag: "/D_WIN32_WINNT=0x0600"
+ # Turn off warning messages.
+ compiler_flag: "/D_CRT_SECURE_NO_DEPRECATE"
+ compiler_flag: "/D_CRT_SECURE_NO_WARNINGS"
+ compiler_flag: "/D_SILENCE_STDEXT_HASH_DEPRECATION_WARNINGS"
+
+ # Useful options to have on for compilation.
+ # Increase the capacity of object files to 2^32 sections.
+ compiler_flag: "/bigobj"
+ # Allocate 500MB for precompiled headers.
+ compiler_flag: "/Zm500"
+ # Use unsigned char by default.
+ compiler_flag: "/J"
+ # Use function level linking.
+ compiler_flag: "/Gy"
+ # Use string pooling.
+ compiler_flag: "/GF"
+ # Catch C++ exceptions only and tell the compiler to assume that functions declared
+ # as extern "C" never throw a C++ exception.
+ compiler_flag: "/EHsc"
+
+ # Globally disabled warnings.
+ # Don't warn about elements of array being default initialized.
+ compiler_flag: "/wd4351"
+ # Don't warn about no matching delete found.
+ compiler_flag: "/wd4291"
+ # Don't warn about diamond inheritance patterns.
+ compiler_flag: "/wd4250"
+ # Don't warn about insecure functions (e.g. non _s functions).
+ compiler_flag: "/wd4996"
+
+ linker_flag: "/MACHINE:X64"
+
+ feature {
+ name: "no_legacy_features"
+ }
+
+ # Suppress startup banner.
+ feature {
+ name: "nologo"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ action: "c++-module-compile"
+ action: "c++-module-codegen"
+ action: "c++-header-parsing"
+ action: "assemble"
+ action: "preprocess-assemble"
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ action: "c++-link-static-library"
+ flag_group {
+ flag: "/nologo"
+ }
+ }
+ }
+
+ feature {
+ name: 'has_configured_linker_path'
+ }
+
+ # This feature indicates strip is not supported; building a stripped binary will just result in a copy of the original binary.
+ feature {
+ name: 'no_stripping'
+ }
+
+ # This feature indicates this is a toolchain targeting Windows.
+ feature {
+ name: 'targets_windows'
+ implies: 'copy_dynamic_libraries_to_binary'
+ enabled: true
+ }
+
+ feature {
+ name: 'copy_dynamic_libraries_to_binary'
+ }
+
+ action_config {
+ config_name: 'assemble'
+ action_name: 'assemble'
+ tool {
+ tool_path: ''
+ }
+ implies: 'compiler_input_flags'
+ implies: 'compiler_output_flags'
+ implies: 'nologo'
+ implies: 'msvc_env'
+ implies: 'sysroot'
+ }
+
+ action_config {
+ config_name: 'preprocess-assemble'
+ action_name: 'preprocess-assemble'
+ tool {
+ tool_path: ''
+ }
+ implies: 'compiler_input_flags'
+ implies: 'compiler_output_flags'
+ implies: 'nologo'
+ implies: 'msvc_env'
+ implies: 'sysroot'
+ }
+
+ action_config {
+ config_name: 'c-compile'
+ action_name: 'c-compile'
+ tool {
+ tool_path: ''
+ }
+ implies: 'compiler_input_flags'
+ implies: 'compiler_output_flags'
+ implies: 'legacy_compile_flags'
+ implies: 'nologo'
+ implies: 'msvc_env'
+ implies: 'parse_showincludes'
+ implies: 'user_compile_flags'
+ implies: 'sysroot'
+ implies: 'unfiltered_compile_flags'
+ }
+
+ action_config {
+ config_name: 'c++-compile'
+ action_name: 'c++-compile'
+ tool {
+ tool_path: ''
+ }
+ implies: 'compiler_input_flags'
+ implies: 'compiler_output_flags'
+ implies: 'legacy_compile_flags'
+ implies: 'nologo'
+ implies: 'msvc_env'
+ implies: 'parse_showincludes'
+ implies: 'user_compile_flags'
+ implies: 'sysroot'
+ implies: 'unfiltered_compile_flags'
+ }
+
+ action_config {
+ config_name: 'c++-link-executable'
+ action_name: 'c++-link-executable'
+ tool {
+ tool_path: ''
+ }
+ implies: 'nologo'
+ implies: 'linkstamps'
+ implies: 'output_execpath_flags'
+ implies: 'input_param_flags'
+ implies: 'user_link_flags'
+ implies: 'legacy_link_flags'
+ implies: 'linker_subsystem_flag'
+ implies: 'linker_param_file'
+ implies: 'msvc_env'
+ implies: 'no_stripping'
+ }
+
+ action_config {
+ config_name: 'c++-link-dynamic-library'
+ action_name: 'c++-link-dynamic-library'
+ tool {
+ tool_path: ''
+ }
+ implies: 'nologo'
+ implies: 'shared_flag'
+ implies: 'linkstamps'
+ implies: 'output_execpath_flags'
+ implies: 'input_param_flags'
+ implies: 'user_link_flags'
+ implies: 'legacy_link_flags'
+ implies: 'linker_subsystem_flag'
+ implies: 'linker_param_file'
+ implies: 'msvc_env'
+ implies: 'no_stripping'
+ implies: 'has_configured_linker_path'
+ implies: 'def_file'
+ }
+
+ action_config {
+ config_name: 'c++-link-nodeps-dynamic-library'
+ action_name: 'c++-link-nodeps-dynamic-library'
+ tool {
+ tool_path: ''
+ }
+ implies: 'nologo'
+ implies: 'shared_flag'
+ implies: 'linkstamps'
+ implies: 'output_execpath_flags'
+ implies: 'input_param_flags'
+ implies: 'user_link_flags'
+ implies: 'legacy_link_flags'
+ implies: 'linker_subsystem_flag'
+ implies: 'linker_param_file'
+ implies: 'msvc_env'
+ implies: 'no_stripping'
+ implies: 'has_configured_linker_path'
+ implies: 'def_file'
+ }
+
+ action_config {
+ config_name: 'c++-link-static-library'
+ action_name: 'c++-link-static-library'
+ tool {
+ tool_path: ''
+ }
+ implies: 'nologo'
+ implies: 'archiver_flags'
+ implies: 'input_param_flags'
+ implies: 'linker_param_file'
+ implies: 'msvc_env'
+ }
+
+ # TODO(b/65151735): Remove legacy_compile_flags feature when legacy fields are
+ # not used in this crosstool
+ feature {
+ name: 'legacy_compile_flags'
+ flag_set {
+ expand_if_all_available: 'legacy_compile_flags'
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-header-parsing'
+ action: 'c++-module-compile'
+ action: 'c++-module-codegen'
+ flag_group {
+ iterate_over: 'legacy_compile_flags'
+ flag: '%{legacy_compile_flags}'
+ }
+ }
+ }
+
+ feature {
+ name: "msvc_env"
+ env_set {
+ action: "c-compile"
+ action: "c++-compile"
+ action: "c++-module-compile"
+ action: "c++-module-codegen"
+ action: "c++-header-parsing"
+ action: "assemble"
+ action: "preprocess-assemble"
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ action: "c++-link-static-library"
+ env_entry {
+ key: "PATH"
+ value: ""
+ }
+ env_entry {
+ key: "INCLUDE"
+ value: ""
+ }
+ env_entry {
+ key: "LIB"
+ value: ""
+ }
+ env_entry {
+ key: "TMP"
+ value: ""
+ }
+ env_entry {
+ key: "TEMP"
+ value: ""
+ }
+ }
+ }
+
+ feature {
+ name: 'include_paths'
+ flag_set {
+ action: "assemble"
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-header-parsing'
+ action: 'c++-module-compile'
+ flag_group {
+ iterate_over: 'quote_include_paths'
+ flag: '/I%{quote_include_paths}'
+ }
+ flag_group {
+ iterate_over: 'include_paths'
+ flag: '/I%{include_paths}'
+ }
+ flag_group {
+ iterate_over: 'system_include_paths'
+ flag: '/I%{system_include_paths}'
+ }
+ }
+ }
+
+ feature {
+ name: "preprocessor_defines"
+ flag_set {
+ action: "assemble"
+ action: "preprocess-assemble"
+ action: "c-compile"
+ action: "c++-compile"
+ action: "c++-header-parsing"
+ action: "c++-module-compile"
+ flag_group {
+ flag: "/D%{preprocessor_defines}"
+ iterate_over: "preprocessor_defines"
+ }
+ }
+ }
+
+ # Tell Bazel to parse the output of /showIncludes
+ feature {
+ name: 'parse_showincludes'
+ flag_set {
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-module-compile'
+ action: 'c++-header-parsing'
+ flag_group {
+ flag: "/showIncludes"
+ }
+ }
+ }
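The "parse_showincludes" feature above turns on cl.exe's /showIncludes output, which Bazel parses to discover header dependencies. As a rough, hypothetical illustration of that parsing (not Bazel's code, and assuming the English "Note: including file:" prefix, which is localized on non-English MSVC installs):

    # Sketch: recover header paths from cl.exe /showIncludes output.
    def parse_showincludes(output):
        prefix = "Note: including file:"
        return [line[len(prefix):].strip()
                for line in output.splitlines()
                if line.startswith(prefix)]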
+
+
+ feature {
+ name: 'generate_pdb_file'
+ requires: {
+ feature: 'dbg'
+ }
+ requires: {
+ feature: 'fastbuild'
+ }
+ }
+
+ feature {
+ name: 'shared_flag'
+ flag_set {
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: '/DLL'
+ }
+ }
+ }
+
+ feature {
+ name: 'linkstamps'
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ expand_if_all_available: 'linkstamp_paths'
+ flag_group {
+ iterate_over: 'linkstamp_paths'
+ flag: '%{linkstamp_paths}'
+ }
+ }
+ }
+
+ feature {
+ name: 'output_execpath_flags'
+ flag_set {
+ expand_if_all_available: 'output_execpath'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: '/OUT:%{output_execpath}'
+ }
+ }
+ }
+
+ feature {
+ name: 'archiver_flags'
+ flag_set {
+ expand_if_all_available: 'output_execpath'
+ action: 'c++-link-static-library'
+ flag_group {
+ flag: '/OUT:%{output_execpath}'
+ }
+ }
+ }
+
+ feature {
+ name: 'input_param_flags'
+ flag_set {
+ expand_if_all_available: 'interface_library_output_path'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/IMPLIB:%{interface_library_output_path}"
+ }
+ }
+ flag_set {
+ expand_if_all_available: 'libopts'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ iterate_over: 'libopts'
+ flag: '%{libopts}'
+ }
+ }
+ flag_set {
+ expand_if_all_available: 'libraries_to_link'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ action: 'c++-link-static-library'
+ flag_group {
+ iterate_over: 'libraries_to_link'
+ flag_group {
+ expand_if_equal: {
+ variable: 'libraries_to_link.type'
+ value: 'object_file_group'
+ }
+ iterate_over: 'libraries_to_link.object_files'
+ flag_group {
+ flag: '%{libraries_to_link.object_files}'
+ }
+ }
+ flag_group {
+ expand_if_equal: {
+ variable: 'libraries_to_link.type'
+ value: 'object_file'
+ }
+ flag_group {
+ flag: '%{libraries_to_link.name}'
+ }
+ }
+ flag_group {
+ expand_if_equal: {
+ variable: 'libraries_to_link.type'
+ value: 'interface_library'
+ }
+ flag_group {
+ flag: '%{libraries_to_link.name}'
+ }
+ }
+ flag_group {
+ expand_if_equal: {
+ variable: 'libraries_to_link.type'
+ value: 'static_library'
+ }
+ flag_group {
+ expand_if_false: 'libraries_to_link.is_whole_archive'
+ flag: '%{libraries_to_link.name}'
+ }
+ flag_group {
+ expand_if_true: 'libraries_to_link.is_whole_archive'
+ flag: '/WHOLEARCHIVE:%{libraries_to_link.name}'
+ }
+ }
+ }
+ }
+ }
+
+ # Since this feature is declared earlier in the CROSSTOOL than
+ # "user_link_flags", it is applied before "user_link_flags" anywhere they
+ # are both implied. And since "user_link_flags" contains the linkopts from
+ # the build rule, this allows the user to override the /SUBSYSTEM flag in the
+ # BUILD file.
+ feature {
+ name: 'linker_subsystem_flag'
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: '/SUBSYSTEM:CONSOLE'
+ }
+ }
+ }
+
+ # The "user_link_flags" contains user-defined linkopts (from build rules)
+ # so it should be defined after features that declare user-overridable flags.
+ # For example the "linker_subsystem_flag" defines a default "/SUBSYSTEM" flag
+ # but we want to let the user override it, therefore "link_flag_subsystem" is
+ # defined earlier in the CROSSTOOL file than "user_link_flags".
+ feature {
+ name: 'user_link_flags'
+ flag_set {
+ expand_if_all_available: 'user_link_flags'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ iterate_over: 'user_link_flags'
+ flag: '%{user_link_flags}'
+ }
+ }
+ }
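As the comments above describe, the default /SUBSYSTEM:CONSOLE from "linker_subsystem_flag" lands on the link line before any linkopts expanded by "user_link_flags", so a BUILD-file linkopt can override it. A tiny, hypothetical sketch of that ordering (the flag values are made up):

    # Defaults are emitted first, user linkopts last, so the later value wins.
    default_link_flags = ["/SUBSYSTEM:CONSOLE"]
    user_linkopts = ["/SUBSYSTEM:WINDOWS"]        # e.g. from a cc_binary's linkopts
    link_args = default_link_flags + user_linkopts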
+ feature {
+ name: 'legacy_link_flags'
+ flag_set {
+ expand_if_all_available: 'legacy_link_flags'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ iterate_over: 'legacy_link_flags'
+ flag: '%{legacy_link_flags}'
+ }
+ }
+ }
+
+ feature {
+ name: 'linker_param_file'
+ flag_set {
+ expand_if_all_available: 'linker_param_file'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ action: 'c++-link-static-library'
+ flag_group {
+ flag: '@%{linker_param_file}'
+ }
+ }
+ }
+
+ feature {
+ name: 'static_link_msvcrt'
+ }
+
+ feature {
+ name: 'static_link_msvcrt_no_debug'
+ flag_set {
+ action: 'c-compile'
+ action: 'c++-compile'
+ flag_group {
+ flag: "/MT"
+ }
+ }
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/DEFAULTLIB:libcmt.lib"
+ }
+ }
+ requires: { feature: 'fastbuild'}
+ requires: { feature: 'opt'}
+ }
+
+ feature {
+ name: 'dynamic_link_msvcrt_no_debug'
+ flag_set {
+ action: 'c-compile'
+ action: 'c++-compile'
+ flag_group {
+ flag: "/MD"
+ }
+ }
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/DEFAULTLIB:msvcrt.lib"
+ }
+ }
+ requires: { feature: 'fastbuild'}
+ requires: { feature: 'opt'}
+ }
+
+ feature {
+ name: 'static_link_msvcrt_debug'
+ flag_set {
+ action: 'c-compile'
+ action: 'c++-compile'
+ flag_group {
+ flag: "/MTd"
+ }
+ }
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/DEFAULTLIB:libcmtd.lib"
+ }
+ }
+ requires: { feature: 'dbg'}
+ }
+
+ feature {
+ name: 'dynamic_link_msvcrt_debug'
+ flag_set {
+ action: 'c-compile'
+ action: 'c++-compile'
+ flag_group {
+ flag: "/MDd"
+ }
+ }
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/DEFAULTLIB:msvcrtd.lib"
+ }
+ }
+ requires: { feature: 'dbg'}
+ }
+
+ feature {
+ name: 'dbg'
+ flag_set {
+ action: 'c-compile'
+ action: 'c++-compile'
+ flag_group {
+ flag: "/Od"
+ flag: "/Z7"
+ flag: "/DDEBUG"
+ }
+ }
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/DEBUG:FULL"
+ flag: "/INCREMENTAL:NO"
+ }
+ }
+ implies: 'generate_pdb_file'
+ }
+
+ feature {
+ name: 'fastbuild'
+ flag_set {
+ action: 'c-compile'
+ action: 'c++-compile'
+ flag_group {
+ flag: "/Od"
+ flag: "/Z7"
+ flag: "/DDEBUG"
+ }
+ }
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/DEBUG:FASTLINK"
+ flag: "/INCREMENTAL:NO"
+ }
+ }
+ implies: 'generate_pdb_file'
+ }
+
+ feature {
+ name: 'opt'
+ flag_set {
+ action: 'c-compile'
+ action: 'c++-compile'
+ flag_group {
+ flag: "/O2"
+ flag: "/DNDEBUG"
+ }
+ }
+ }
+
+ feature {
+ name: 'user_compile_flags'
+ flag_set {
+ expand_if_all_available: 'user_compile_flags'
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-header-parsing'
+ action: 'c++-module-compile'
+ action: 'c++-module-codegen'
+ flag_group {
+ iterate_over: 'user_compile_flags'
+ flag: '%{user_compile_flags}'
+ }
+ }
+ }
+
+ feature {
+ name: 'sysroot'
+ flag_set {
+ expand_if_all_available: 'sysroot'
+ action: 'assemble'
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-header-parsing'
+ action: 'c++-module-compile'
+ action: 'c++-module-codegen'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ iterate_over: 'sysroot'
+ flag: '--sysroot=%{sysroot}'
+ }
+ }
+ }
+
+ feature {
+ name: 'unfiltered_compile_flags'
+ flag_set {
+ expand_if_all_available: 'unfiltered_compile_flags'
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-header-parsing'
+ action: 'c++-module-compile'
+ action: 'c++-module-codegen'
+ flag_group {
+ iterate_over: 'unfiltered_compile_flags'
+ flag: '%{unfiltered_compile_flags}'
+ }
+ }
+ }
+
+ feature {
+ name: 'compiler_output_flags'
+ flag_set {
+ action: 'assemble'
+ flag_group {
+ expand_if_all_available: 'output_file'
+ expand_if_none_available: 'output_assembly_file'
+ expand_if_none_available: 'output_preprocess_file'
+ flag: '/Fo%{output_file}'
+ flag: '/Zi'
+ }
+ }
+ flag_set {
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-header-parsing'
+ action: 'c++-module-compile'
+ action: 'c++-module-codegen'
+ flag_group {
+ expand_if_all_available: 'output_file'
+ expand_if_none_available: 'output_assembly_file'
+ expand_if_none_available: 'output_preprocess_file'
+ flag: '/Fo%{output_file}'
+ }
+ flag_group {
+ expand_if_all_available: 'output_file'
+ expand_if_all_available: 'output_assembly_file'
+ flag: '/Fa%{output_file}'
+ }
+ flag_group {
+ expand_if_all_available: 'output_file'
+ expand_if_all_available: 'output_preprocess_file'
+ flag: '/P'
+ flag: '/Fi%{output_file}'
+ }
+ }
+ }
+
+ feature {
+ name: 'compiler_input_flags'
+ flag_set {
+ action: 'assemble'
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-header-parsing'
+ action: 'c++-module-compile'
+ action: 'c++-module-codegen'
+ flag_group {
+ expand_if_all_available: 'source_file'
+ flag: '/c'
+ flag: '%{source_file}'
+ }
+ }
+ }
+
+ feature {
+ name : 'def_file',
+ flag_set {
+ expand_if_all_available: 'def_file_path'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/DEF:%{def_file_path}"
+ # We can specify a different DLL name in the DEF file; /ignore:4070 suppresses
+ # the warning that the DLL name doesn't match the default one.
+ # See https://msdn.microsoft.com/en-us/library/sfkk2fz7.aspx
+ flag: "/ignore:4070"
+ }
+ }
+ }
+
+ feature {
+ name: 'windows_export_all_symbols'
+ }
+
+ feature {
+ name: 'no_windows_export_all_symbols'
+ }
+
+ linking_mode_flags { mode: DYNAMIC }
+}
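Most features in the MSVC toolchain above are built from flag_groups that expand build variables such as %{user_link_flags}, guarded by expand_if_all_available and repeated with iterate_over. The snippet below is a simplified, hypothetical model of just those two constructs (Bazel's real expansion also handles nested groups, expand_if_equal, structured variables such as libraries_to_link, and so on):

    import re

    # Simplified sketch of flag_group expansion; not Bazel's implementation.
    def expand_flag_group(flags, variables, iterate_over=None,
                          expand_if_all_available=()):
        if any(v not in variables for v in expand_if_all_available):
            return []
        values = variables.get(iterate_over, []) if iterate_over else [None]
        out = []
        for value in values:
            scope = dict(variables)
            if iterate_over:
                scope[iterate_over] = value
            for flag in flags:
                out.append(re.sub(r"%\{(\w+)\}",
                                  lambda m: str(scope[m.group(1)]), flag))
        return out

    # expand_flag_group(["%{user_link_flags}"],
    #                   {"user_link_flags": ["/DEBUG", "/OPT:REF"]},
    #                   iterate_over="user_link_flags",
    #                   expand_if_all_available=("user_link_flags",))
    # -> ["/DEBUG", "/OPT:REF"]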
diff --git a/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/clang/bin/crosstool_wrapper_driver_is_not_gcc b/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/clang/bin/crosstool_wrapper_driver_is_not_gcc
new file mode 100755
index 0000000000..63893d3722
--- /dev/null
+++ b/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/clang/bin/crosstool_wrapper_driver_is_not_gcc
@@ -0,0 +1,264 @@
+#!/usr/bin/env python
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Crosstool wrapper for compiling CUDA programs.
+
+SYNOPSIS:
+ crosstool_wrapper_is_not_gcc [options passed in by cc_library()
+ or cc_binary() rule]
+
+DESCRIPTION:
+ This script is expected to be called by the cc_library() or cc_binary() bazel
+ rules. When the option "-x cuda" is present in the list of arguments passed
+ to this script, it invokes the nvcc CUDA compiler. Most arguments are passed
+ as is as a string to --compiler-options of nvcc. When "-x cuda" is not
+ present, this wrapper invokes the CPU compiler (CPU_COMPILER) with the input
+ arguments as is.
+
+NOTES:
+ Changes to the contents of this file must be propagated from
+ //third_party/gpus/crosstool/crosstool_wrapper_is_not_gcc to
+ //third_party/gpus/crosstool/v*/*/clang/bin/crosstool_wrapper_is_not_gcc
+"""
+
+from __future__ import print_function
+
+__author__ = 'keveman@google.com (Manjunath Kudlur)'
+
+from argparse import ArgumentParser
+import os
+import subprocess
+import re
+import sys
+import pipes
+
+# Template values set by cuda_autoconf.
+CPU_COMPILER = ('/usr/bin/gcc')
+GCC_HOST_COMPILER_PATH = ('/usr/bin/gcc')
+
+NVCC_PATH = '/usr/local/cuda-9.0/bin/nvcc'
+PREFIX_DIR = os.path.dirname(GCC_HOST_COMPILER_PATH)
+NVCC_VERSION = '9.0'
+
+def Log(s):
+ print('gpus/crosstool: {0}'.format(s))
+
+
+def GetOptionValue(argv, option):
+ """Extract the list of values for option from the argv list.
+
+ Args:
+ argv: A list of strings, possibly the argv passed to main().
+ option: The option whose value to extract, without the leading '-'.
+
+ Returns:
+ A list of values, either directly following the option
+ (e.g., -opt val1 val2) or collected from multiple occurrences of
+ the option (e.g., -opt val1 -opt val2).
+ """
+
+ parser = ArgumentParser()
+ parser.add_argument('-' + option, nargs='*', action='append')
+ args, _ = parser.parse_known_args(argv)
+ if not args or not vars(args)[option]:
+ return []
+ else:
+ return sum(vars(args)[option], [])
+
+
+def GetHostCompilerOptions(argv):
+ """Collect the -isystem, -iquote, and --sysroot option values from argv.
+
+ Args:
+ argv: A list of strings, possibly the argv passed to main().
+
+ Returns:
+ The string that can be used as the --compiler-options to nvcc.
+ """
+
+ parser = ArgumentParser()
+ parser.add_argument('-isystem', nargs='*', action='append')
+ parser.add_argument('-iquote', nargs='*', action='append')
+ parser.add_argument('--sysroot', nargs=1)
+ parser.add_argument('-g', nargs='*', action='append')
+ parser.add_argument('-fno-canonical-system-headers', action='store_true')
+
+ args, _ = parser.parse_known_args(argv)
+
+ opts = ''
+
+ if args.isystem:
+ opts += ' -isystem ' + ' -isystem '.join(sum(args.isystem, []))
+ if args.iquote:
+ opts += ' -iquote ' + ' -iquote '.join(sum(args.iquote, []))
+ if args.g:
+ opts += ' -g' + ' -g'.join(sum(args.g, []))
+ if args.fno_canonical_system_headers:
+ opts += ' -fno-canonical-system-headers'
+ if args.sysroot:
+ opts += ' --sysroot ' + args.sysroot[0]
+
+ return opts
+
+def _update_options(nvcc_options):
+ if NVCC_VERSION in ("7.0",):
+ return nvcc_options
+
+ update_options = { "relaxed-constexpr" : "expt-relaxed-constexpr" }
+ return [ update_options[opt] if opt in update_options else opt
+ for opt in nvcc_options ]
+
+def GetNvccOptions(argv):
+ """Collect the -nvcc_options values from argv.
+
+ Args:
+ argv: A list of strings, possibly the argv passed to main().
+
+ Returns:
+ The string that can be passed directly to nvcc.
+ """
+
+ parser = ArgumentParser()
+ parser.add_argument('-nvcc_options', nargs='*', action='append')
+
+ args, _ = parser.parse_known_args(argv)
+
+ if args.nvcc_options:
+ options = _update_options(sum(args.nvcc_options, []))
+ return ' '.join(['--'+a for a in options])
+ return ''
+
+
+def InvokeNvcc(argv, log=False):
+ """Call nvcc with arguments assembled from argv.
+
+ Args:
+ argv: A list of strings, possibly the argv passed to main().
+ log: True if logging is requested.
+
+ Returns:
+ The return value of calling os.system('nvcc ' + args)
+ """
+
+ host_compiler_options = GetHostCompilerOptions(argv)
+ nvcc_compiler_options = GetNvccOptions(argv)
+ opt_option = GetOptionValue(argv, 'O')
+ m_options = GetOptionValue(argv, 'm')
+ m_options = ''.join([' -m' + m for m in m_options if m in ['32', '64']])
+ include_options = GetOptionValue(argv, 'I')
+ out_file = GetOptionValue(argv, 'o')
+ depfiles = GetOptionValue(argv, 'MF')
+ defines = GetOptionValue(argv, 'D')
+ defines = ''.join([' -D' + define for define in defines])
+ undefines = GetOptionValue(argv, 'U')
+ undefines = ''.join([' -U' + define for define in undefines])
+ std_options = GetOptionValue(argv, 'std')
+ # Currently only c++11 is supported by the CUDA 7.0 -std argument.
+ nvcc_allowed_std_options = ["c++11"]
+ std_options = ''.join([' -std=' + define
+ for define in std_options if define in nvcc_allowed_std_options])
+
+ # The list of source files gets passed after the -c option. I don't know of
+ # any other reliable way to just get the list of source files to be compiled.
+ src_files = GetOptionValue(argv, 'c')
+
+ # Pass -w through from host to nvcc, but don't do anything fancier with
+ # warnings-related flags, since they're not necessarily the same across
+ # compilers.
+ warning_options = ' -w' if '-w' in argv else ''
+
+ if len(src_files) == 0:
+ return 1
+ if len(out_file) != 1:
+ return 1
+
+ opt = (' -O2' if (len(opt_option) > 0 and int(opt_option[0]) > 0)
+ else ' -g -G')
+
+ includes = (' -I ' + ' -I '.join(include_options)
+ if len(include_options) > 0
+ else '')
+
+ # Unfortunately, there are other options that have a -c prefix too,
+ # so allow only the arguments that look like C/C++ source files.
+ src_files = [f for f in src_files if
+ re.search(r'\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)]
+ srcs = ' '.join(src_files)
+ out = ' -o ' + out_file[0]
+
+ supported_cuda_compute_capabilities = [ "3.0" ]
+ nvccopts = '-D_FORCE_INLINES '
+ for capability in supported_cuda_compute_capabilities:
+ capability = capability.replace('.', '')
+ nvccopts += r'-gencode=arch=compute_%s,\"code=sm_%s,compute_%s\" ' % (
+ capability, capability, capability)
+ nvccopts += ' ' + nvcc_compiler_options
+ nvccopts += undefines
+ nvccopts += defines
+ nvccopts += std_options
+ nvccopts += m_options
+ nvccopts += warning_options
+
+ if depfiles:
+ # Generate the dependency file
+ depfile = depfiles[0]
+ cmd = (NVCC_PATH + ' ' + nvccopts +
+ ' --compiler-options "' + host_compiler_options + '"' +
+ ' --compiler-bindir=' + GCC_HOST_COMPILER_PATH +
+ ' -I .' +
+ ' -x cu ' + opt + includes + ' ' + srcs + ' -M -o ' + depfile)
+ if log: Log(cmd)
+ exit_status = os.system(cmd)
+ if exit_status != 0:
+ return exit_status
+
+ cmd = (NVCC_PATH + ' ' + nvccopts +
+ ' --compiler-options "' + host_compiler_options + ' -fPIC"' +
+ ' --compiler-bindir=' + GCC_HOST_COMPILER_PATH +
+ ' -I .' +
+ ' -x cu ' + opt + includes + ' -c ' + srcs + out)
+
+ # TODO(zhengxq): for some reason, 'gcc' needs this help to find 'as'.
+ # Need to investigate and fix.
+ cmd = 'PATH=' + PREFIX_DIR + ':$PATH ' + cmd
+ if log: Log(cmd)
+ return os.system(cmd)
+
+
+def main():
+ parser = ArgumentParser()
+ parser.add_argument('-x', nargs=1)
+ parser.add_argument('--cuda_log', action='store_true')
+ args, leftover = parser.parse_known_args(sys.argv[1:])
+
+ if args.x and args.x[0] == 'cuda':
+ if args.cuda_log: Log('-x cuda')
+ leftover = [pipes.quote(s) for s in leftover]
+ if args.cuda_log: Log('using nvcc')
+ return InvokeNvcc(leftover, log=args.cuda_log)
+
+ # Strip our flags before passing through to the CPU compiler for files which
+ # are not -x cuda. We can't just pass 'leftover' because it also strips -x.
+ # We not only want to pass -x to the CPU compiler, but also keep it in its
+ # relative location in the argv list (the compiler is actually sensitive to
+ # this).
+ cpu_compiler_flags = [flag for flag in sys.argv[1:]
+ if not flag.startswith(('--cuda_log'))]
+
+ return subprocess.call([CPU_COMPILER] + cpu_compiler_flags)
+
+if __name__ == '__main__':
+ sys.exit(main())
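The wrapper above relies on argparse's parse_known_args to pull individual flags (defines, include paths, -c, -o, and so on) out of the full compiler command line while leaving everything else untouched, then flattens the per-occurrence lists with sum(..., []). A small self-contained illustration of that pattern, using made-up argv values:

    from argparse import ArgumentParser

    # Collect repeated -D flags the same way GetOptionValue() does above.
    parser = ArgumentParser()
    parser.add_argument("-D", nargs="*", action="append")
    args, leftover = parser.parse_known_args(
        ["-DGOOGLE_CUDA=1", "-DNDEBUG", "-I.", "-c", "foo.cu.cc", "-o", "foo.o"])
    defines = sum(args.D, [])
    # defines  == ["GOOGLE_CUDA=1", "NDEBUG"]
    # leftover == ["-I.", "-c", "foo.cu.cc", "-o", "foo.o"]   (unrecognized args)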
diff --git a/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/windows/msvc_wrapper_for_nvcc.bat b/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/windows/msvc_wrapper_for_nvcc.bat
new file mode 100755
index 0000000000..e896e654fd
--- /dev/null
+++ b/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/windows/msvc_wrapper_for_nvcc.bat
@@ -0,0 +1,20 @@
+:: Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+::
+:: Licensed under the Apache License, Version 2.0 (the "License");
+:: you may not use this file except in compliance with the License.
+:: You may obtain a copy of the License at
+::
+:: http://www.apache.org/licenses/LICENSE-2.0
+::
+:: Unless required by applicable law or agreed to in writing, software
+:: distributed under the License is distributed on an "AS IS" BASIS,
+:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+:: See the License for the specific language governing permissions and
+:: limitations under the License.
+:: =============================================================================
+
+:: Invoke msvc_wrapper_for_nvcc.py, which is located in the same directory.
+@echo OFF
+set arg0=%~0
+for %%F in ("%arg0%") do set DRIVER_BIN=%%~dpF
+"/usr/bin/python3" -B "%DRIVER_BIN%\msvc_wrapper_for_nvcc.py" %*
diff --git a/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/windows/msvc_wrapper_for_nvcc.py b/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/windows/msvc_wrapper_for_nvcc.py
new file mode 100755
index 0000000000..859b3196d5
--- /dev/null
+++ b/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/windows/msvc_wrapper_for_nvcc.py
@@ -0,0 +1,192 @@
+#!/usr/bin/env python
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Crosstool wrapper for compiling CUDA programs with nvcc on Windows.
+
+DESCRIPTION:
+ This script is the Windows version of //third_party/gpus/crosstool/crosstool_wrapper_is_not_gcc
+"""
+
+from __future__ import print_function
+
+from argparse import ArgumentParser
+import os
+import subprocess
+import re
+import sys
+import pipes
+
+# Template values set by cuda_autoconf.
+CPU_COMPILER = ('/usr/bin/gcc')
+GCC_HOST_COMPILER_PATH = ('/usr/bin/gcc')
+
+NVCC_PATH = '/usr/local/cuda-9.0/bin/nvcc'
+NVCC_VERSION = '9.0'
+NVCC_TEMP_DIR = "C:\\Windows\\Temp\\nvcc_inter_files_tmp_dir"
+supported_cuda_compute_capabilities = [ "3.0" ]
+
+def Log(s):
+ print('gpus/crosstool: {0}'.format(s))
+
+
+def GetOptionValue(argv, option):
+ """Extract the list of values for option from options.
+
+ Args:
+ option: The option whose value to extract, without the leading '/'.
+
+ Returns:
+ 1. A list of values, either directly following the option,
+ (eg., /opt val1 val2) or values collected from multiple occurrences of
+ the option (eg., /opt val1 /opt val2).
+ 2. The leftover options.
+ """
+
+ parser = ArgumentParser(prefix_chars='/')
+ parser.add_argument('/' + option, nargs='*', action='append')
+ args, leftover = parser.parse_known_args(argv)
+ if args and vars(args)[option]:
+ return (sum(vars(args)[option], []), leftover)
+ return ([], leftover)
+
+def _update_options(nvcc_options):
+ if NVCC_VERSION in ("7.0",):
+ return nvcc_options
+
+ update_options = { "relaxed-constexpr" : "expt-relaxed-constexpr" }
+ return [ update_options[opt] if opt in update_options else opt
+ for opt in nvcc_options ]
+
+def GetNvccOptions(argv):
+ """Collect the -nvcc_options values from argv.
+
+ Args:
+ argv: A list of strings, possibly the argv passed to main().
+
+ Returns:
+ 1. The string that can be passed directly to nvcc.
+ 2. The leftover options.
+ """
+
+ parser = ArgumentParser()
+ parser.add_argument('-nvcc_options', nargs='*', action='append')
+
+ args, leftover = parser.parse_known_args(argv)
+
+ if args.nvcc_options:
+ options = _update_options(sum(args.nvcc_options, []))
+ return (['--' + a for a in options], leftover)
+ return ([], leftover)
+
+
+def InvokeNvcc(argv, log=False):
+ """Call nvcc with arguments assembled from argv.
+
+ Args:
+ argv: A list of strings, possibly the argv passed to main().
+ log: True if logging is requested.
+
+ Returns:
+ The return value of calling os.system('nvcc ' + args)
+ """
+
+ src_files = [f for f in argv if
+ re.search(r'\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)]
+ if len(src_files) == 0:
+ raise RuntimeError('No source files found for cuda compilation.')
+
+ out_file = [ f for f in argv if f.startswith('/Fo') ]
+ if len(out_file) != 1:
+ raise RuntimeError('Please specify exactly one output file for cuda compilation.')
+ out = ['-o', out_file[0][len('/Fo'):]]
+
+ nvcc_compiler_options, argv = GetNvccOptions(argv)
+
+ opt_option, argv = GetOptionValue(argv, 'O')
+ opt = ['-g', '-G']
+ if (len(opt_option) > 0 and opt_option[0] != 'd'):
+ opt = ['-O2']
+
+ include_options, argv = GetOptionValue(argv, 'I')
+ includes = ["-I " + include for include in include_options]
+
+ defines, argv = GetOptionValue(argv, 'D')
+ defines = ['-D' + define for define in defines]
+
+ undefines, argv = GetOptionValue(argv, 'U')
+ undefines = ['-U' + define for define in undefines]
+
+ # The rest of the unrecognized options should be passed to the host compiler.
+ host_compiler_options = [option for option in argv if option not in (src_files + out_file)]
+
+ m_options = ["-m64"]
+
+ nvccopts = ['-D_FORCE_INLINES']
+ for capability in supported_cuda_compute_capabilities:
+ capability = capability.replace('.', '')
+ nvccopts += [r'-gencode=arch=compute_%s,"code=sm_%s,compute_%s"' % (
+ capability, capability, capability)]
+ nvccopts += nvcc_compiler_options
+ nvccopts += undefines
+ nvccopts += defines
+ nvccopts += m_options
+ nvccopts += ['--compiler-options="' + " ".join(host_compiler_options) + '"']
+ nvccopts += ['-x', 'cu'] + opt + includes + out + ['-c'] + src_files
+ # If we don't specify --keep-dir, nvcc will generate intermediate files under TEMP.
+ # Put them under NVCC_TEMP_DIR instead, so that Bazel can ignore files under NVCC_TEMP_DIR during its dependency check.
+ # http://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#options-for-guiding-compiler-driver
+ # Different actions share NVCC_TEMP_DIR, so we cannot remove it if the directory already exists.
+ if os.path.isfile(NVCC_TEMP_DIR):
+ os.remove(NVCC_TEMP_DIR)
+ if not os.path.exists(NVCC_TEMP_DIR):
+ os.makedirs(NVCC_TEMP_DIR)
+ nvccopts += ['--keep', '--keep-dir', NVCC_TEMP_DIR]
+ cmd = [NVCC_PATH] + nvccopts
+ if log:
+ Log(cmd)
+ proc = subprocess.Popen(cmd,
+ stdout=sys.stdout,
+ stderr=sys.stderr,
+ env=os.environ.copy(),
+ shell=True)
+ proc.wait()
+ return proc.returncode
+
+def main():
+ parser = ArgumentParser()
+ parser.add_argument('-x', nargs=1)
+ parser.add_argument('--cuda_log', action='store_true')
+ args, leftover = parser.parse_known_args(sys.argv[1:])
+
+ if args.x and args.x[0] == 'cuda':
+ if args.cuda_log: Log('-x cuda')
+ leftover = [pipes.quote(s) for s in leftover]
+ if args.cuda_log: Log('using nvcc')
+ return InvokeNvcc(leftover, log=args.cuda_log)
+
+ # Strip our flags before passing through to the CPU compiler for files which
+ # are not -x cuda. We can't just pass 'leftover' because it also strips -x.
+ # We not only want to pass -x to the CPU compiler, but also keep it in its
+ # relative location in the argv list (the compiler is actually sensitive to
+ # this).
+ cpu_compiler_flags = [flag for flag in sys.argv[1:]
+ if not flag.startswith(('--cuda_log'))
+ and not flag.startswith(('-nvcc_options'))]
+
+ return subprocess.call([CPU_COMPILER] + cpu_compiler_flags)
+
+if __name__ == '__main__':
+ sys.exit(main())
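InvokeNvcc() above translates MSVC-style options into the nvcc/GCC style; the output flag is the clearest case, with cl.exe's /Fo<path> becoming nvcc's -o <path>. A tiny self-contained illustration with a made-up path:

    # Mirrors the '/Fo' handling in InvokeNvcc() above (hypothetical argv).
    argv = ["/nologo", "/DNDEBUG", "/Fobazel-out/x64_windows-opt/bin/foo.obj", "foo.cu.cc"]
    out_file = [f for f in argv if f.startswith("/Fo")]
    assert len(out_file) == 1
    out = ["-o", out_file[0][len("/Fo"):]]
    # out == ["-o", "bazel-out/x64_windows-opt/bin/foo.obj"]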
diff --git a/third_party/toolchains/preconfig/ubuntu14.04/nccl2/BUILD b/third_party/toolchains/preconfig/ubuntu14.04/nccl2/BUILD
new file mode 100755
index 0000000000..96ed60d3cf
--- /dev/null
+++ b/third_party/toolchains/preconfig/ubuntu14.04/nccl2/BUILD
@@ -0,0 +1,25 @@
+filegroup(
+ name = "LICENSE",
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "nccl",
+ srcs = ["libnccl.so.2"],
+ hdrs = ["nccl.h"],
+ include_prefix = "third_party/nccl",
+ visibility = ["//visibility:public"],
+ deps = [
+ "@local_config_cuda//cuda:cuda_headers",
+ ],
+)
+
+genrule(
+ name = "nccl-files",
+ outs = [
+ "libnccl.so.2",
+ "nccl.h",
+ ],
+ cmd = """cp "/usr/include/nccl.h" "$(@D)/nccl.h" &&
+ cp "/usr/lib/libnccl.so.2" "$(@D)/libnccl.so.2" """,
+)
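The nccl-files genrule just copies the system NCCL header and shared library into Bazel's output tree so the cc_library above can consume them as ordinary files. A rough Python equivalent of that shell command, with a placeholder standing in for the $(@D) output directory:

    import os
    import shutil

    # Hypothetical stand-in for the genrule's shell command; DEST plays the
    # role of Bazel's $(@D).
    DEST = "/tmp/local_config_nccl"   # placeholder
    os.makedirs(DEST, exist_ok=True)
    shutil.copy("/usr/include/nccl.h", os.path.join(DEST, "nccl.h"))
    shutil.copy("/usr/lib/libnccl.so.2", os.path.join(DEST, "libnccl.so.2"))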
diff --git a/third_party/toolchains/preconfig/ubuntu14.04/nccl2/WORKSPACE b/third_party/toolchains/preconfig/ubuntu14.04/nccl2/WORKSPACE
new file mode 100644
index 0000000000..1e6662ac91
--- /dev/null
+++ b/third_party/toolchains/preconfig/ubuntu14.04/nccl2/WORKSPACE
@@ -0,0 +1,2 @@
+# DO NOT EDIT: automatically generated WORKSPACE file for nccl_configure rule
+workspace(name = "local_config_nccl")
diff --git a/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD b/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD
new file mode 100755
index 0000000000..e021df9e1e
--- /dev/null
+++ b/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD
@@ -0,0 +1,176 @@
+licenses(["restricted"])
+
+package(default_visibility = ["//visibility:public"])
+
+# To build Python C/C++ extensions on Windows, we need to link against the Python import library pythonXY.lib.
+# See https://docs.python.org/3/extending/windows.html
+cc_import(
+ name = "python_lib",
+ interface_library = select({
+ ":windows": ":python_import_lib",
+ # A placeholder for Unix platforms which makes --no_build happy.
+ "//conditions:default": "not-existing.lib",
+ }),
+ system_provided = 1,
+)
+
+cc_library(
+ name = "python_headers",
+ hdrs = [":python_include"],
+ includes = ["python_include"],
+ deps = select({
+ ":windows": [":python_lib"],
+ "//conditions:default": [],
+ }),
+)
+
+cc_library(
+ name = "numpy_headers",
+ hdrs = [":numpy_include"],
+ includes = ["numpy_include"],
+)
+
+config_setting(
+ name = "windows",
+ values = {"cpu": "x64_windows"},
+ visibility = ["//visibility:public"],
+)
+
+genrule(
+ name = "python_include",
+ outs = [
+ "python_include/Python-ast.h",
+ "python_include/Python.h",
+ "python_include/abstract.h",
+ "python_include/accu.h",
+ "python_include/asdl.h",
+ "python_include/ast.h",
+ "python_include/bitset.h",
+ "python_include/bltinmodule.h",
+ "python_include/boolobject.h",
+ "python_include/bytearrayobject.h",
+ "python_include/bytes_methods.h",
+ "python_include/bytesobject.h",
+ "python_include/cellobject.h",
+ "python_include/ceval.h",
+ "python_include/classobject.h",
+ "python_include/code.h",
+ "python_include/codecs.h",
+ "python_include/compile.h",
+ "python_include/complexobject.h",
+ "python_include/datetime.h",
+ "python_include/descrobject.h",
+ "python_include/dictobject.h",
+ "python_include/dtoa.h",
+ "python_include/dynamic_annotations.h",
+ "python_include/enumobject.h",
+ "python_include/errcode.h",
+ "python_include/eval.h",
+ "python_include/fileobject.h",
+ "python_include/fileutils.h",
+ "python_include/floatobject.h",
+ "python_include/frameobject.h",
+ "python_include/funcobject.h",
+ "python_include/genobject.h",
+ "python_include/graminit.h",
+ "python_include/grammar.h",
+ "python_include/import.h",
+ "python_include/intrcheck.h",
+ "python_include/iterobject.h",
+ "python_include/listobject.h",
+ "python_include/longintrepr.h",
+ "python_include/longobject.h",
+ "python_include/marshal.h",
+ "python_include/memoryobject.h",
+ "python_include/metagrammar.h",
+ "python_include/methodobject.h",
+ "python_include/modsupport.h",
+ "python_include/moduleobject.h",
+ "python_include/namespaceobject.h",
+ "python_include/node.h",
+ "python_include/object.h",
+ "python_include/objimpl.h",
+ "python_include/opcode.h",
+ "python_include/osdefs.h",
+ "python_include/parsetok.h",
+ "python_include/patchlevel.h",
+ "python_include/pgen.h",
+ "python_include/pgenheaders.h",
+ "python_include/py_curses.h",
+ "python_include/pyarena.h",
+ "python_include/pyatomic.h",
+ "python_include/pycapsule.h",
+ "python_include/pyconfig.h",
+ "python_include/pyctype.h",
+ "python_include/pydebug.h",
+ "python_include/pyerrors.h",
+ "python_include/pyexpat.h",
+ "python_include/pyfpe.h",
+ "python_include/pygetopt.h",
+ "python_include/pyhash.h",
+ "python_include/pymacconfig.h",
+ "python_include/pymacro.h",
+ "python_include/pymath.h",
+ "python_include/pymem.h",
+ "python_include/pyport.h",
+ "python_include/pystate.h",
+ "python_include/pystrcmp.h",
+ "python_include/pystrtod.h",
+ "python_include/pythonrun.h",
+ "python_include/pythread.h",
+ "python_include/pytime.h",
+ "python_include/rangeobject.h",
+ "python_include/setobject.h",
+ "python_include/sliceobject.h",
+ "python_include/structmember.h",
+ "python_include/structseq.h",
+ "python_include/symtable.h",
+ "python_include/sysmodule.h",
+ "python_include/token.h",
+ "python_include/traceback.h",
+ "python_include/tupleobject.h",
+ "python_include/typeslots.h",
+ "python_include/ucnhash.h",
+ "python_include/unicodeobject.h",
+ "python_include/warnings.h",
+ "python_include/weakrefobject.h",
+ ],
+ cmd = """
+cp "/usr/include/python3.4m/Python-ast.h" "$(@D)/python_include/Python-ast.h" && cp "/usr/include/python3.4m/Python.h" "$(@D)/python_include/Python.h" && cp "/usr/include/python3.4m/abstract.h" "$(@D)/python_include/abstract.h" && cp "/usr/include/python3.4m/accu.h" "$(@D)/python_include/accu.h" && cp "/usr/include/python3.4m/asdl.h" "$(@D)/python_include/asdl.h" && cp "/usr/include/python3.4m/ast.h" "$(@D)/python_include/ast.h" && cp "/usr/include/python3.4m/bitset.h" "$(@D)/python_include/bitset.h" && cp "/usr/include/python3.4m/bltinmodule.h" "$(@D)/python_include/bltinmodule.h" && cp "/usr/include/python3.4m/boolobject.h" "$(@D)/python_include/boolobject.h" && cp "/usr/include/python3.4m/bytearrayobject.h" "$(@D)/python_include/bytearrayobject.h" && cp "/usr/include/python3.4m/bytes_methods.h" "$(@D)/python_include/bytes_methods.h" && cp "/usr/include/python3.4m/bytesobject.h" "$(@D)/python_include/bytesobject.h" && cp "/usr/include/python3.4m/cellobject.h" "$(@D)/python_include/cellobject.h" && cp "/usr/include/python3.4m/ceval.h" "$(@D)/python_include/ceval.h" && cp "/usr/include/python3.4m/classobject.h" "$(@D)/python_include/classobject.h" && cp "/usr/include/python3.4m/code.h" "$(@D)/python_include/code.h" && cp "/usr/include/python3.4m/codecs.h" "$(@D)/python_include/codecs.h" && cp "/usr/include/python3.4m/compile.h" "$(@D)/python_include/compile.h" && cp "/usr/include/python3.4m/complexobject.h" "$(@D)/python_include/complexobject.h" && cp "/usr/include/python3.4m/datetime.h" "$(@D)/python_include/datetime.h" && cp "/usr/include/python3.4m/descrobject.h" "$(@D)/python_include/descrobject.h" && cp "/usr/include/python3.4m/dictobject.h" "$(@D)/python_include/dictobject.h" && cp "/usr/include/python3.4m/dtoa.h" "$(@D)/python_include/dtoa.h" && cp "/usr/include/python3.4m/dynamic_annotations.h" "$(@D)/python_include/dynamic_annotations.h" && cp "/usr/include/python3.4m/enumobject.h" "$(@D)/python_include/enumobject.h" && cp "/usr/include/python3.4m/errcode.h" "$(@D)/python_include/errcode.h" && cp "/usr/include/python3.4m/eval.h" "$(@D)/python_include/eval.h" && cp "/usr/include/python3.4m/fileobject.h" "$(@D)/python_include/fileobject.h" && cp "/usr/include/python3.4m/fileutils.h" "$(@D)/python_include/fileutils.h" && cp "/usr/include/python3.4m/floatobject.h" "$(@D)/python_include/floatobject.h" && cp "/usr/include/python3.4m/frameobject.h" "$(@D)/python_include/frameobject.h" && cp "/usr/include/python3.4m/funcobject.h" "$(@D)/python_include/funcobject.h" && cp "/usr/include/python3.4m/genobject.h" "$(@D)/python_include/genobject.h" && cp "/usr/include/python3.4m/graminit.h" "$(@D)/python_include/graminit.h" && cp "/usr/include/python3.4m/grammar.h" "$(@D)/python_include/grammar.h" && cp "/usr/include/python3.4m/import.h" "$(@D)/python_include/import.h" && cp "/usr/include/python3.4m/intrcheck.h" "$(@D)/python_include/intrcheck.h" && cp "/usr/include/python3.4m/iterobject.h" "$(@D)/python_include/iterobject.h" && cp "/usr/include/python3.4m/listobject.h" "$(@D)/python_include/listobject.h" && cp "/usr/include/python3.4m/longintrepr.h" "$(@D)/python_include/longintrepr.h" && cp "/usr/include/python3.4m/longobject.h" "$(@D)/python_include/longobject.h" && cp "/usr/include/python3.4m/marshal.h" "$(@D)/python_include/marshal.h" && cp "/usr/include/python3.4m/memoryobject.h" "$(@D)/python_include/memoryobject.h" && cp "/usr/include/python3.4m/metagrammar.h" "$(@D)/python_include/metagrammar.h" && cp "/usr/include/python3.4m/methodobject.h" "$(@D)/python_include/methodobject.h" && cp 
"/usr/include/python3.4m/modsupport.h" "$(@D)/python_include/modsupport.h" && cp "/usr/include/python3.4m/moduleobject.h" "$(@D)/python_include/moduleobject.h" && cp "/usr/include/python3.4m/namespaceobject.h" "$(@D)/python_include/namespaceobject.h" && cp "/usr/include/python3.4m/node.h" "$(@D)/python_include/node.h" && cp "/usr/include/python3.4m/object.h" "$(@D)/python_include/object.h" && cp "/usr/include/python3.4m/objimpl.h" "$(@D)/python_include/objimpl.h" && cp "/usr/include/python3.4m/opcode.h" "$(@D)/python_include/opcode.h" && cp "/usr/include/python3.4m/osdefs.h" "$(@D)/python_include/osdefs.h" && cp "/usr/include/python3.4m/parsetok.h" "$(@D)/python_include/parsetok.h" && cp "/usr/include/python3.4m/patchlevel.h" "$(@D)/python_include/patchlevel.h" && cp "/usr/include/python3.4m/pgen.h" "$(@D)/python_include/pgen.h" && cp "/usr/include/python3.4m/pgenheaders.h" "$(@D)/python_include/pgenheaders.h" && cp "/usr/include/python3.4m/py_curses.h" "$(@D)/python_include/py_curses.h" && cp "/usr/include/python3.4m/pyarena.h" "$(@D)/python_include/pyarena.h" && cp "/usr/include/python3.4m/pyatomic.h" "$(@D)/python_include/pyatomic.h" && cp "/usr/include/python3.4m/pycapsule.h" "$(@D)/python_include/pycapsule.h" && cp "/usr/include/python3.4m/pyconfig.h" "$(@D)/python_include/pyconfig.h" && cp "/usr/include/python3.4m/pyctype.h" "$(@D)/python_include/pyctype.h" && cp "/usr/include/python3.4m/pydebug.h" "$(@D)/python_include/pydebug.h" && cp "/usr/include/python3.4m/pyerrors.h" "$(@D)/python_include/pyerrors.h" && cp "/usr/include/python3.4m/pyexpat.h" "$(@D)/python_include/pyexpat.h" && cp "/usr/include/python3.4m/pyfpe.h" "$(@D)/python_include/pyfpe.h" && cp "/usr/include/python3.4m/pygetopt.h" "$(@D)/python_include/pygetopt.h" && cp "/usr/include/python3.4m/pyhash.h" "$(@D)/python_include/pyhash.h" && cp "/usr/include/python3.4m/pymacconfig.h" "$(@D)/python_include/pymacconfig.h" && cp "/usr/include/python3.4m/pymacro.h" "$(@D)/python_include/pymacro.h" && cp "/usr/include/python3.4m/pymath.h" "$(@D)/python_include/pymath.h" && cp "/usr/include/python3.4m/pymem.h" "$(@D)/python_include/pymem.h" && cp "/usr/include/python3.4m/pyport.h" "$(@D)/python_include/pyport.h" && cp "/usr/include/python3.4m/pystate.h" "$(@D)/python_include/pystate.h" && cp "/usr/include/python3.4m/pystrcmp.h" "$(@D)/python_include/pystrcmp.h" && cp "/usr/include/python3.4m/pystrtod.h" "$(@D)/python_include/pystrtod.h" && cp "/usr/include/python3.4m/pythonrun.h" "$(@D)/python_include/pythonrun.h" && cp "/usr/include/python3.4m/pythread.h" "$(@D)/python_include/pythread.h" && cp "/usr/include/python3.4m/pytime.h" "$(@D)/python_include/pytime.h" && cp "/usr/include/python3.4m/rangeobject.h" "$(@D)/python_include/rangeobject.h" && cp "/usr/include/python3.4m/setobject.h" "$(@D)/python_include/setobject.h" && cp "/usr/include/python3.4m/sliceobject.h" "$(@D)/python_include/sliceobject.h" && cp "/usr/include/python3.4m/structmember.h" "$(@D)/python_include/structmember.h" && cp "/usr/include/python3.4m/structseq.h" "$(@D)/python_include/structseq.h" && cp "/usr/include/python3.4m/symtable.h" "$(@D)/python_include/symtable.h" && cp "/usr/include/python3.4m/sysmodule.h" "$(@D)/python_include/sysmodule.h" && cp "/usr/include/python3.4m/token.h" "$(@D)/python_include/token.h" && cp "/usr/include/python3.4m/traceback.h" "$(@D)/python_include/traceback.h" && cp "/usr/include/python3.4m/tupleobject.h" "$(@D)/python_include/tupleobject.h" && cp "/usr/include/python3.4m/typeslots.h" "$(@D)/python_include/typeslots.h" && cp 
"/usr/include/python3.4m/ucnhash.h" "$(@D)/python_include/ucnhash.h" && cp "/usr/include/python3.4m/unicodeobject.h" "$(@D)/python_include/unicodeobject.h" && cp "/usr/include/python3.4m/warnings.h" "$(@D)/python_include/warnings.h" && cp "/usr/include/python3.4m/weakrefobject.h" "$(@D)/python_include/weakrefobject.h"
+ """,
+)
+
+genrule(
+ name = "numpy_include",
+ outs = [
+ "numpy_include/numpy/__multiarray_api.h",
+ "numpy_include/numpy/__ufunc_api.h",
+ "numpy_include/numpy/_neighborhood_iterator_imp.h",
+ "numpy_include/numpy/_numpyconfig.h",
+ "numpy_include/numpy/arrayobject.h",
+ "numpy_include/numpy/arrayscalars.h",
+ "numpy_include/numpy/halffloat.h",
+ "numpy_include/numpy/multiarray_api.txt",
+ "numpy_include/numpy/ndarrayobject.h",
+ "numpy_include/numpy/ndarraytypes.h",
+ "numpy_include/numpy/noprefix.h",
+ "numpy_include/numpy/npy_1_7_deprecated_api.h",
+ "numpy_include/numpy/npy_3kcompat.h",
+ "numpy_include/numpy/npy_common.h",
+ "numpy_include/numpy/npy_cpu.h",
+ "numpy_include/numpy/npy_endian.h",
+ "numpy_include/numpy/npy_interrupt.h",
+ "numpy_include/numpy/npy_math.h",
+ "numpy_include/numpy/npy_no_deprecated_api.h",
+ "numpy_include/numpy/npy_os.h",
+ "numpy_include/numpy/numpyconfig.h",
+ "numpy_include/numpy/old_defines.h",
+ "numpy_include/numpy/oldnumeric.h",
+ "numpy_include/numpy/ufunc_api.txt",
+ "numpy_include/numpy/ufuncobject.h",
+ "numpy_include/numpy/utils.h",
+ ],
+ cmd = """
+cp "/usr/local/lib/python3.4/dist-packages/numpy/core/include/numpy/__multiarray_api.h" "$(@D)/numpy_include/numpy/__multiarray_api.h" && cp "/usr/local/lib/python3.4/dist-packages/numpy/core/include/numpy/__ufunc_api.h" "$(@D)/numpy_include/numpy/__ufunc_api.h" && cp "/usr/local/lib/python3.4/dist-packages/numpy/core/include/numpy/_neighborhood_iterator_imp.h" "$(@D)/numpy_include/numpy/_neighborhood_iterator_imp.h" && cp "/usr/local/lib/python3.4/dist-packages/numpy/core/include/numpy/_numpyconfig.h" "$(@D)/numpy_include/numpy/_numpyconfig.h" && cp "/usr/local/lib/python3.4/dist-packages/numpy/core/include/numpy/arrayobject.h" "$(@D)/numpy_include/numpy/arrayobject.h" && cp "/usr/local/lib/python3.4/dist-packages/numpy/core/include/numpy/arrayscalars.h" "$(@D)/numpy_include/numpy/arrayscalars.h" && cp "/usr/local/lib/python3.4/dist-packages/numpy/core/include/numpy/halffloat.h" "$(@D)/numpy_include/numpy/halffloat.h" && cp "/usr/local/lib/python3.4/dist-packages/numpy/core/include/numpy/multiarray_api.txt" "$(@D)/numpy_include/numpy/multiarray_api.txt" && cp "/usr/local/lib/python3.4/dist-packages/numpy/core/include/numpy/ndarrayobject.h" "$(@D)/numpy_include/numpy/ndarrayobject.h" && cp "/usr/local/lib/python3.4/dist-packages/numpy/core/include/numpy/ndarraytypes.h" "$(@D)/numpy_include/numpy/ndarraytypes.h" && cp "/usr/local/lib/python3.4/dist-packages/numpy/core/include/numpy/noprefix.h" "$(@D)/numpy_include/numpy/noprefix.h" && cp "/usr/local/lib/python3.4/dist-packages/numpy/core/include/numpy/npy_1_7_deprecated_api.h" "$(@D)/numpy_include/numpy/npy_1_7_deprecated_api.h" && cp "/usr/local/lib/python3.4/dist-packages/numpy/core/include/numpy/npy_3kcompat.h" "$(@D)/numpy_include/numpy/npy_3kcompat.h" && cp "/usr/local/lib/python3.4/dist-packages/numpy/core/include/numpy/npy_common.h" "$(@D)/numpy_include/numpy/npy_common.h" && cp "/usr/local/lib/python3.4/dist-packages/numpy/core/include/numpy/npy_cpu.h" "$(@D)/numpy_include/numpy/npy_cpu.h" && cp "/usr/local/lib/python3.4/dist-packages/numpy/core/include/numpy/npy_endian.h" "$(@D)/numpy_include/numpy/npy_endian.h" && cp "/usr/local/lib/python3.4/dist-packages/numpy/core/include/numpy/npy_interrupt.h" "$(@D)/numpy_include/numpy/npy_interrupt.h" && cp "/usr/local/lib/python3.4/dist-packages/numpy/core/include/numpy/npy_math.h" "$(@D)/numpy_include/numpy/npy_math.h" && cp "/usr/local/lib/python3.4/dist-packages/numpy/core/include/numpy/npy_no_deprecated_api.h" "$(@D)/numpy_include/numpy/npy_no_deprecated_api.h" && cp "/usr/local/lib/python3.4/dist-packages/numpy/core/include/numpy/npy_os.h" "$(@D)/numpy_include/numpy/npy_os.h" && cp "/usr/local/lib/python3.4/dist-packages/numpy/core/include/numpy/numpyconfig.h" "$(@D)/numpy_include/numpy/numpyconfig.h" && cp "/usr/local/lib/python3.4/dist-packages/numpy/core/include/numpy/old_defines.h" "$(@D)/numpy_include/numpy/old_defines.h" && cp "/usr/local/lib/python3.4/dist-packages/numpy/core/include/numpy/oldnumeric.h" "$(@D)/numpy_include/numpy/oldnumeric.h" && cp "/usr/local/lib/python3.4/dist-packages/numpy/core/include/numpy/ufunc_api.txt" "$(@D)/numpy_include/numpy/ufunc_api.txt" && cp "/usr/local/lib/python3.4/dist-packages/numpy/core/include/numpy/ufuncobject.h" "$(@D)/numpy_include/numpy/ufuncobject.h" && cp "/usr/local/lib/python3.4/dist-packages/numpy/core/include/numpy/utils.h" "$(@D)/numpy_include/numpy/utils.h"
+ """,
+)
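The python_include and numpy_include genrules above carry one long shell command each, copying every header of the preconfigured Python 3.4 install into the Bazel output tree. Those command strings are emitted by the python_configure repository rule; the sketch below is a simplified, hypothetical generator for strings of that shape (not the actual repository-rule code):

    # Hypothetical generator for the long 'cp "..." "..." && ...' cmd strings above.
    def make_copy_cmd(src_dir, out_subdir, headers):
        copies = ['cp "{0}/{2}" "$(@D)/{1}/{2}"'.format(src_dir, out_subdir, h)
                  for h in headers]
        return " && ".join(copies)

    # make_copy_cmd("/usr/include/python3.4m", "python_include",
    #               ["Python.h", "object.h"])
    # -> 'cp "/usr/include/python3.4m/Python.h" "$(@D)/python_include/Python.h" && ...'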
diff --git a/third_party/toolchains/preconfig/ubuntu14.04/py3/WORKSPACE b/third_party/toolchains/preconfig/ubuntu14.04/py3/WORKSPACE
new file mode 100644
index 0000000000..1d298fefa3
--- /dev/null
+++ b/third_party/toolchains/preconfig/ubuntu14.04/py3/WORKSPACE
@@ -0,0 +1,2 @@
+# DO NOT EDIT: automatically generated WORKSPACE file for python_configure rule
+workspace(name = "local_config_python")