aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/BUILD16
-rw-r--r--tensorflow/c/c_api.cc14
-rw-r--r--tensorflow/c/c_api.h5
-rw-r--r--tensorflow/c/c_api_experimental.cc12
-rw-r--r--tensorflow/c/c_api_experimental.h6
-rw-r--r--tensorflow/c/c_api_function.cc4
-rw-r--r--tensorflow/c/c_api_function_test.cc1
-rw-r--r--tensorflow/c/c_api_test.cc2
-rw-r--r--tensorflow/c/eager/c_api.cc4
-rw-r--r--tensorflow/c/eager/c_api.h3
-rw-r--r--tensorflow/c/eager/c_api_test.cc54
-rw-r--r--tensorflow/cc/gradients/math_grad_test.cc6
-rw-r--r--tensorflow/cc/saved_model/loader.cc56
-rw-r--r--tensorflow/compiler/jit/BUILD1
-rw-r--r--tensorflow/compiler/jit/kernels/xla_launch_op.cc4
-rw-r--r--tensorflow/compiler/jit/xla_compilation_cache.cc4
-rw-r--r--tensorflow/compiler/jit/xla_device.h7
-rw-r--r--tensorflow/compiler/tests/image_ops_test.py136
-rw-r--r--tensorflow/compiler/tf2xla/BUILD6
-rw-r--r--tensorflow/compiler/tf2xla/dump_graph.cc53
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow.cc37
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc5
-rw-r--r--tensorflow/compiler/tf2xla/graph_compiler.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/BUILD9
-rw-r--r--tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/bias_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/binary_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/bucketize_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/cast_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/categorical_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/concat_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/const_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/conv_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/cross_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/cwise_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/cwise_ops.h2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/diag_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/elu_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/fft_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/fill_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/gather_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/if_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/image_ops.cc152
-rw-r--r--tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/l2loss_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/listdiff_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/lrn_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/matmul_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/pack_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/pad_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/pooling_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/random_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reduction_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reduction_ops.h2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/relu_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reshape_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/retval_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reverse_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/scan_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/select_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/shape_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/slice_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/softmax_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/sort_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/split_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/tile_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/topk_op.cc29
-rw-r--r--tensorflow/compiler/tf2xla/kernels/training_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/transpose_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/unary_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/unpack_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/variable_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/while_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.cc63
-rw-r--r--tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.h49
-rw-r--r--tensorflow/compiler/tf2xla/lib/BUILD18
-rw-r--r--tensorflow/compiler/tf2xla/lib/batch_dot.cc2
-rw-r--r--tensorflow/compiler/tf2xla/lib/batch_dot.h2
-rw-r--r--tensorflow/compiler/tf2xla/lib/cholesky.cc2
-rw-r--r--tensorflow/compiler/tf2xla/lib/cholesky.h2
-rw-r--r--tensorflow/compiler/tf2xla/lib/qr.cc2
-rw-r--r--tensorflow/compiler/tf2xla/lib/qr.h2
-rw-r--r--tensorflow/compiler/tf2xla/lib/random.cc2
-rw-r--r--tensorflow/compiler/tf2xla/lib/random.h2
-rw-r--r--tensorflow/compiler/tf2xla/lib/scatter.cc2
-rw-r--r--tensorflow/compiler/tf2xla/lib/scatter.h2
-rw-r--r--tensorflow/compiler/tf2xla/lib/triangular_solve.cc2
-rw-r--r--tensorflow/compiler/tf2xla/lib/triangular_solve.h2
-rw-r--r--tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc2
-rw-r--r--tensorflow/compiler/tf2xla/lib/util.cc2
-rw-r--r--tensorflow/compiler/tf2xla/lib/util.h2
-rw-r--r--tensorflow/compiler/tf2xla/lib/while_loop.cc2
-rw-r--r--tensorflow/compiler/tf2xla/lib/while_loop.h2
-rw-r--r--tensorflow/compiler/tf2xla/xla_compilation_device.cc2
-rw-r--r--tensorflow/compiler/tf2xla/xla_compilation_device.h2
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.cc15
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.h6
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler_test.cc13
-rw-r--r--tensorflow/compiler/tf2xla/xla_context.cc2
-rw-r--r--tensorflow/compiler/tf2xla/xla_context.h2
-rw-r--r--tensorflow/compiler/tf2xla/xla_gpu_backend.cc15
-rw-r--r--tensorflow/compiler/tf2xla/xla_helpers.cc2
-rw-r--r--tensorflow/compiler/tf2xla/xla_helpers.h2
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_kernel.cc2
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_kernel.h2
-rw-r--r--tensorflow/compiler/tf2xla/xla_resource.cc2
-rw-r--r--tensorflow/compiler/tf2xla/xla_resource.h2
-rw-r--r--tensorflow/compiler/xla/BUILD2
-rw-r--r--tensorflow/compiler/xla/client/BUILD45
-rw-r--r--tensorflow/compiler/xla/client/lib/BUILD49
-rw-r--r--tensorflow/compiler/xla/client/lib/arithmetic.cc2
-rw-r--r--tensorflow/compiler/xla/client/lib/arithmetic.h2
-rw-r--r--tensorflow/compiler/xla/client/lib/constants.h2
-rw-r--r--tensorflow/compiler/xla/client/lib/constants_test.cc2
-rw-r--r--tensorflow/compiler/xla/client/lib/math.h2
-rw-r--r--tensorflow/compiler/xla/client/lib/math_test.cc2
-rw-r--r--tensorflow/compiler/xla/client/lib/numeric.h2
-rw-r--r--tensorflow/compiler/xla/client/lib/numeric_test.cc2
-rw-r--r--tensorflow/compiler/xla/client/lib/prng.cc2
-rw-r--r--tensorflow/compiler/xla/client/lib/prng.h2
-rw-r--r--tensorflow/compiler/xla/client/lib/sorting.cc46
-rw-r--r--tensorflow/compiler/xla/client/lib/sorting.h (renamed from tensorflow/compiler/xla/service/pool_test.cc)31
-rw-r--r--tensorflow/compiler/xla/client/lib/sorting_test.cc60
-rw-r--r--tensorflow/compiler/xla/client/lib/testing.cc3
-rw-r--r--tensorflow/compiler/xla/client/local_client.cc7
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.cc (renamed from tensorflow/compiler/xla/client/xla_client/xla_builder.cc)2
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.h2241
-rw-r--r--tensorflow/compiler/xla/client/xla_builder_test.cc (renamed from tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc)2
-rw-r--r--tensorflow/compiler/xla/client/xla_client/BUILD37
-rw-r--r--tensorflow/compiler/xla/client/xla_client/xla_builder.h2222
-rw-r--r--tensorflow/compiler/xla/python/BUILD2
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.cc3
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.h2
-rw-r--r--tensorflow/compiler/xla/reference_util.cc2
-rw-r--r--tensorflow/compiler/xla/rpc/BUILD2
-rw-r--r--tensorflow/compiler/xla/rpc/grpc_client_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/BUILD30
-rw-r--r--tensorflow/compiler/xla/service/backend.cc17
-rw-r--r--tensorflow/compiler/xla/service/backend.h14
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc3
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_normalization.cc65
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_normalization_test.cc30
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_propagation.cc14
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_propagation_test.cc69
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment.cc25
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment.h4
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment_test.cc68
-rw-r--r--tensorflow/compiler/xla/service/computation_placer.cc17
-rw-r--r--tensorflow/compiler/xla/service/computation_placer.h2
-rw-r--r--tensorflow/compiler/xla/service/cpu/BUILD3
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.cc16
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_executable.cc5
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.cc44
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.h6
-rw-r--r--tensorflow/compiler/xla/service/cpu/sample_harness.cc2
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/BUILD2
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/execution_tracker.cc9
-rw-r--r--tensorflow/compiler/xla/service/execution_tracker.h8
-rw-r--r--tensorflow/compiler/xla/service/gpu/BUILD42
-rw-r--r--tensorflow/compiler/xla/service/gpu/buffer_allocations.cc23
-rw-r--r--tensorflow/compiler/xla/service/gpu/buffer_allocations.h6
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_constants.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_constants.h3
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc79
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_executable.cc62
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_executable.h21
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc58
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc11
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h12
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc25
-rw-r--r--tensorflow/compiler/xla/service/gpu/instruction_fusion.cc1
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter.cc13
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc162
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h3
-rw-r--r--tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc23
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc233
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h45
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc164
-rw-r--r--tensorflow/compiler/xla/service/gpu/stream_executor_util.cc15
-rw-r--r--tensorflow/compiler/xla/service/gpu/stream_executor_util.h15
-rw-r--r--tensorflow/compiler/xla/service/gpu/tests/BUILD2
-rw-r--r--tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/while_thunk.cc3
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc54
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc42
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h3
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.h3
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc11
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser_test.cc11
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/BUILD9
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc59
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h34
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/sort_util.cc40
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/sort_util.h2
-rw-r--r--tensorflow/compiler/xla/service/pool.h84
-rw-r--r--tensorflow/compiler/xla/service/service.cc11
-rw-r--r--tensorflow/compiler/xla/service/service_executable_run_options.h7
-rw-r--r--tensorflow/compiler/xla/service/stream_pool.cc56
-rw-r--r--tensorflow/compiler/xla/service/stream_pool.h64
-rw-r--r--tensorflow/compiler/xla/service/stream_pool_test.cc136
-rw-r--r--tensorflow/compiler/xla/service/transpose_folding_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/tuple_points_to_analysis.cc16
-rw-r--r--tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc42
-rw-r--r--tensorflow/compiler/xla/tests/BUILD148
-rw-r--r--tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/axpy_simple_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/batch_normalization_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/bfloat16_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/binop_scaling_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/bitcast_convert_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/broadcast_simple_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/call_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/check_execution_arity_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.cc3
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.h2
-rw-r--r--tensorflow/compiler/xla/tests/client_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/compilation_cache_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/compute_constant_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/concat_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/conditional_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/constants_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/convert_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/convolution_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/convolution_variants_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/copy_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/custom_call_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/deallocation_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/deep_graph_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/dot_operation_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/dynamic_ops_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/execution_profile_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/floor_ceil_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/fmax_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/fusion_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/gather_operation_test.cc3
-rw-r--r--tensorflow/compiler/xla/tests/half_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/hlo_metadata_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/local_client_allocation_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc5
-rw-r--r--tensorflow/compiler/xla/tests/local_client_execute_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/log_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/map_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/multidimensional_slice_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/pad_test.cc3
-rw-r--r--tensorflow/compiler/xla/tests/params_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/pred_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/prng_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/query_inferred_shape_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/reduce_precision_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/reduce_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/reduce_window_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/replay_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/reshape_motion_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/reshape_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/reverse_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/scalar_computations_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/select_and_scatter_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/select_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/slice_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/test_utils_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/transfer_manager_test.cc3
-rw-r--r--tensorflow/compiler/xla/tests/transpose_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/tuple_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/unary_op_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/vector_ops_simple_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/while_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc5
-rw-r--r--tensorflow/contrib/autograph/converters/break_statements.py6
-rw-r--r--tensorflow/contrib/autograph/converters/break_statements_test.py38
-rw-r--r--tensorflow/contrib/autograph/converters/continue_statements.py4
-rw-r--r--tensorflow/contrib/autograph/converters/continue_statements_test.py32
-rw-r--r--tensorflow/contrib/autograph/converters/error_handlers.py3
-rw-r--r--tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb2
-rw-r--r--tensorflow/contrib/autograph/impl/api.py12
-rw-r--r--tensorflow/contrib/autograph/impl/api_test.py14
-rw-r--r--tensorflow/contrib/autograph/impl/conversion.py18
-rw-r--r--tensorflow/contrib/autograph/impl/conversion_test.py12
-rw-r--r--tensorflow/contrib/autograph/pyct/common_transformers/BUILD1
-rw-r--r--tensorflow/contrib/autograph/pyct/common_transformers/anf.py381
-rw-r--r--tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py356
-rw-r--r--tensorflow/contrib/autograph/utils/builtins.py10
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/BUILD1
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py265
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py70
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/estimator.py129
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py30
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/model.py80
-rw-r--r--tensorflow/contrib/boosted_trees/kernels/training_ops.cc17
-rw-r--r--tensorflow/contrib/boosted_trees/ops/training_ops.cc3
-rw-r--r--tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py84
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py11
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py6
-rw-r--r--tensorflow/contrib/cmake/CMakeLists.txt1
-rw-r--r--tensorflow/contrib/copy_graph/python/util/copy_elements.py4
-rw-r--r--tensorflow/contrib/data/python/ops/batching.py13
-rw-r--r--tensorflow/contrib/distribute/python/BUILD46
-rw-r--r--tensorflow/contrib/distribute/python/checkpoint_utils_test.py8
-rw-r--r--tensorflow/contrib/distribute/python/cross_tower_ops.py39
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py297
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py4
-rw-r--r--tensorflow/contrib/distribute/python/multi_worker_test_base.py38
-rw-r--r--tensorflow/contrib/distribute/python/parameter_server_strategy.py355
-rw-r--r--tensorflow/contrib/distribute/python/parameter_server_strategy_test.py455
-rw-r--r--tensorflow/contrib/distribute/python/values.py16
-rw-r--r--tensorflow/contrib/distributions/python/ops/quantized_distribution.py20
-rw-r--r--tensorflow/contrib/eager/python/examples/l2hmc/README.md11
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/README.md70
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/blocks.py111
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/blocks_test.py106
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/config.py6
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/main.py20
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/main_estimator.py22
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py21
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/revnet.py126
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/revnet_test.py26
-rw-r--r--tensorflow/contrib/eager/python/tfe.py2
-rw-r--r--tensorflow/contrib/estimator/python/estimator/boosted_trees.py30
-rw-r--r--tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py68
-rw-r--r--tensorflow/contrib/estimator/python/estimator/saved_model_estimator.py14
-rw-r--r--tensorflow/contrib/gan/BUILD5
-rw-r--r--tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py86
-rw-r--r--tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py114
-rw-r--r--tensorflow/contrib/gan/python/train.py128
-rw-r--r--tensorflow/contrib/gan/python/train_test.py21
-rw-r--r--tensorflow/contrib/lite/allocation.cc4
-rw-r--r--tensorflow/contrib/lite/build_def.bzl37
-rw-r--r--tensorflow/contrib/lite/builtin_op_data.h4
-rw-r--r--tensorflow/contrib/lite/builtin_ops.h1
-rw-r--r--tensorflow/contrib/lite/context.h7
-rw-r--r--tensorflow/contrib/lite/delegates/eager/BUILD37
-rw-r--r--tensorflow/contrib/lite/delegates/eager/buffer_map.cc4
-rw-r--r--tensorflow/contrib/lite/delegates/eager/delegate_data.cc3
-rw-r--r--tensorflow/contrib/lite/delegates/eager/kernel.cc289
-rw-r--r--tensorflow/contrib/lite/delegates/eager/kernel.h34
-rw-r--r--tensorflow/contrib/lite/delegates/eager/kernel_test.cc351
-rw-r--r--tensorflow/contrib/lite/experimental/c/BUILD8
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api.cc4
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/.gitignore13
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite.meta8
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples.meta8
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite.meta8
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes.meta8
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity242
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity.meta7
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/add.bytesbin0 -> 476 bytes
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/add.bytes.meta7
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts.meta8
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs70
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs.meta11
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK.meta8
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts.meta8
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs145
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs.meta11
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/AudioManager.asset17
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ClusterInputManager.asset6
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/DynamicsManager.asset29
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/EditorBuildSettings.asset7
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/EditorSettings.asset21
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/GraphicsSettings.asset61
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/InputManager.asset295
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/NavMeshAreas.asset91
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/NetworkManager.asset8
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/Physics2DSettings.asset37
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ProjectSettings.asset641
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ProjectVersion.txt1
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/QualitySettings.asset191
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/TagManager.asset43
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/TimeManager.asset9
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/UnityConnectSettings.asset34
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/README.md24
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/UnityPackageManager/manifest.json4
-rw-r--r--tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md13
-rw-r--r--tensorflow/contrib/lite/kernels/BUILD29
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h8
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h4
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc11
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h8
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h48
-rw-r--r--tensorflow/contrib/lite/kernels/internal/spectrogram.cc10
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor_utils.h8
-rw-r--r--tensorflow/contrib/lite/kernels/logical.cc121
-rw-r--r--tensorflow/contrib/lite/kernels/logical_test.cc87
-rw-r--r--tensorflow/contrib/lite/kernels/one_hot.cc199
-rw-r--r--tensorflow/contrib/lite/kernels/one_hot_test.cc182
-rw-r--r--tensorflow/contrib/lite/kernels/register.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/reshape.cc17
-rw-r--r--tensorflow/contrib/lite/kernels/reshape_test.cc37
-rw-r--r--tensorflow/contrib/lite/kernels/resize_bilinear_test.cc32
-rw-r--r--tensorflow/contrib/lite/kernels/space_to_batch_nd.cc22
-rw-r--r--tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc142
-rw-r--r--tensorflow/contrib/lite/model.cc8
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.cc1
-rw-r--r--tensorflow/contrib/lite/profiling/time.cc18
-rw-r--r--tensorflow/contrib/lite/schema/schema.fbs6
-rwxr-xr-xtensorflow/contrib/lite/schema/schema_generated.h141
-rw-r--r--tensorflow/contrib/lite/simple_memory_arena.cc1
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py124
-rw-r--r--tensorflow/contrib/lite/testing/generated_examples_zip_test.cc3
-rw-r--r--tensorflow/contrib/lite/testing/tf_driver.cc2
-rw-r--r--tensorflow/contrib/lite/toco/export_tensorflow.cc36
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc11
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h4
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc16
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc59
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc37
-rw-r--r--tensorflow/contrib/lite/toco/model.h34
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc21
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator_test.cc10
-rw-r--r--tensorflow/contrib/lite/toco/toco_tooling.cc7
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc13
-rw-r--r--tensorflow/contrib/optimizer_v2/optimizer_v2_test.py40
-rw-r--r--tensorflow/contrib/recurrent/python/ops/recurrent.py2
-rw-r--r--tensorflow/contrib/saved_model/BUILD29
-rw-r--r--tensorflow/contrib/saved_model/__init__.py3
-rw-r--r--tensorflow/contrib/saved_model/python/saved_model/__init__.py1
-rw-r--r--tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py108
-rw-r--r--tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py201
-rw-r--r--tensorflow/contrib/tensorrt/BUILD2
-rw-r--r--tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc138
-rw-r--r--tensorflow/contrib/tensorrt/kernels/trt_engine_op.h6
-rw-r--r--tensorflow/contrib/tensorrt/test/memory_alignment_test.py72
-rw-r--r--tensorflow/contrib/timeseries/examples/multivariate.py4
-rw-r--r--tensorflow/contrib/tpu/BUILD1
-rw-r--r--tensorflow/contrib/tpu/profiler/tpu_profiler.proto26
-rw-r--r--tensorflow/contrib/tpu/proto/optimization_parameters.proto114
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu.py4
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_estimator.py132
-rw-r--r--tensorflow/core/BUILD8
-rw-r--r--tensorflow/core/api_def/base_api/api_def_NonMaxSuppressionV4.pbtxt78
-rw-r--r--tensorflow/core/api_def/python_api/api_def_NonMaxSuppressionV4.pbtxt4
-rw-r--r--tensorflow/core/common_runtime/broadcaster.cc247
-rw-r--r--tensorflow/core/common_runtime/broadcaster.h17
-rw-r--r--tensorflow/core/common_runtime/broadcaster_test.cc168
-rw-r--r--tensorflow/core/common_runtime/collective_param_resolver_local.cc193
-rw-r--r--tensorflow/core/common_runtime/collective_param_resolver_local.h8
-rw-r--r--tensorflow/core/common_runtime/collective_param_resolver_local_test.cc129
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc8
-rw-r--r--tensorflow/core/common_runtime/eager/execute.cc14
-rw-r--r--tensorflow/core/common_runtime/function_test.cc5
-rw-r--r--tensorflow/core/common_runtime/mkl_cpu_allocator.h1
-rw-r--r--tensorflow/core/common_runtime/placer.cc59
-rw-r--r--tensorflow/core/common_runtime/session_ref.cc170
-rw-r--r--tensorflow/core/common_runtime/session_ref.h86
-rw-r--r--tensorflow/core/framework/node_def_util.cc13
-rw-r--r--tensorflow/core/framework/node_def_util.h6
-rw-r--r--tensorflow/core/framework/node_def_util_test.cc52
-rw-r--r--tensorflow/core/framework/op_compatibility_test.cc65
-rw-r--r--tensorflow/core/framework/op_kernel.cc6
-rw-r--r--tensorflow/core/framework/op_kernel.h43
-rw-r--r--tensorflow/core/graph/control_flow.cc37
-rw-r--r--tensorflow/core/graph/control_flow_test.cc17
-rw-r--r--tensorflow/core/graph/mkl_graph_util.h1
-rw-r--r--tensorflow/core/graph/mkl_layout_pass.cc81
-rw-r--r--tensorflow/core/graph/mkl_layout_pass_test.cc31
-rw-r--r--tensorflow/core/graph/mkl_tfconversion_pass.cc1
-rw-r--r--tensorflow/core/graph/mkl_tfconversion_pass_test.cc1
-rw-r--r--tensorflow/core/grappler/clusters/cluster.cc5
-rw-r--r--tensorflow/core/grappler/clusters/cluster.h3
-rw-r--r--tensorflow/core/kernels/BUILD1
-rw-r--r--tensorflow/core/kernels/cwise_op_tan.cc3
-rw-r--r--tensorflow/core/kernels/fused_batch_norm_op.cc24
-rw-r--r--tensorflow/core/kernels/mkl_concat_op.cc8
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc9
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_input_ops.cc9
-rw-r--r--tensorflow/core/kernels/mkl_conv_ops.cc17
-rw-r--r--tensorflow/core/kernels/mkl_conv_ops.h1
-rw-r--r--tensorflow/core/kernels/mkl_pooling_ops_common.h1
-rw-r--r--tensorflow/core/kernels/mkl_reshape_op.cc6
-rw-r--r--tensorflow/core/kernels/non_max_suppression_op.cc111
-rw-r--r--tensorflow/core/kernels/non_max_suppression_op_test.cc55
-rw-r--r--tensorflow/core/kernels/resource_variable_ops.cc68
-rw-r--r--tensorflow/core/kernels/save_restore_tensor.cc150
-rw-r--r--tensorflow/core/kernels/strided_slice_op.cc4
-rw-r--r--tensorflow/core/kernels/training_op_helpers.cc9
-rw-r--r--tensorflow/core/kernels/training_op_helpers.h14
-rw-r--r--tensorflow/core/lib/core/errors.h20
-rw-r--r--tensorflow/core/ops/compat/ops_history.v1.pbtxt38
-rw-r--r--tensorflow/core/ops/image_ops.cc39
-rw-r--r--tensorflow/core/ops/math_ops.cc1
-rw-r--r--tensorflow/core/ops/ops.pbtxt38
-rw-r--r--tensorflow/core/platform/default/build_config_root.bzl6
-rw-r--r--tensorflow/core/platform/env_test.cc2
-rw-r--r--tensorflow/core/protobuf/config.proto4
-rw-r--r--tensorflow/core/util/equal_graph_def_test.cc6
-rw-r--r--tensorflow/core/util/mkl_util.h22
-rw-r--r--tensorflow/docs_src/deploy/distributed.md2
-rw-r--r--tensorflow/docs_src/guide/using_gpu.md2
-rw-r--r--tensorflow/docs_src/performance/xla/broadcasting.md2
-rw-r--r--tensorflow/docs_src/performance/xla/operation_semantics.md84
-rw-r--r--tensorflow/examples/saved_model/saved_model_half_plus_two.py116
-rw-r--r--tensorflow/go/op/wrappers.go280
-rw-r--r--tensorflow/java/maven/README.md22
-rw-r--r--tensorflow/java/maven/run_inside_container.sh68
-rw-r--r--tensorflow/java/maven/tensorflow-android/update.py17
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Graph.java67
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/Scope.java4
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java42
-rw-r--r--tensorflow/java/src/main/native/graph_jni.cc10
-rw-r--r--tensorflow/java/src/main/native/graph_jni.h6
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/GraphTest.java7
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java1
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java64
-rw-r--r--tensorflow/python/BUILD1
-rw-r--r--tensorflow/python/client/session.py16
-rw-r--r--tensorflow/python/client/tf_session.i1
-rw-r--r--tensorflow/python/client/tf_session_helper.cc14
-rw-r--r--tensorflow/python/client/tf_session_helper.h3
-rw-r--r--tensorflow/python/data/kernel_tests/BUILD2
-rw-r--r--tensorflow/python/data/ops/iterator_ops.py11
-rw-r--r--tensorflow/python/distribute/BUILD43
-rw-r--r--tensorflow/python/distribute/distribute_coordinator.py361
-rw-r--r--tensorflow/python/distribute/distribute_coordinator_test.py291
-rw-r--r--tensorflow/python/eager/function.py13
-rw-r--r--tensorflow/python/eager/graph_callable.py4
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc4
-rw-r--r--tensorflow/python/estimator/BUILD3
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees.py45
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees_test.py3
-rw-r--r--tensorflow/python/estimator/estimator.py80
-rw-r--r--tensorflow/python/estimator/estimator_test.py13
-rw-r--r--tensorflow/python/estimator/keras.py47
-rw-r--r--tensorflow/python/estimator/keras_test.py6
-rw-r--r--tensorflow/python/estimator/model_fn.py58
-rw-r--r--tensorflow/python/framework/error_interpolation.py53
-rw-r--r--tensorflow/python/framework/error_interpolation_test.py105
-rw-r--r--tensorflow/python/framework/function.py26
-rw-r--r--tensorflow/python/framework/ops.py181
-rw-r--r--tensorflow/python/framework/ops_test.py67
-rw-r--r--tensorflow/python/framework/test_util_test.py2
-rw-r--r--tensorflow/python/keras/callbacks.py73
-rw-r--r--tensorflow/python/keras/callbacks_test.py78
-rw-r--r--tensorflow/python/keras/engine/base_layer.py64
-rw-r--r--tensorflow/python/keras/engine/network.py109
-rw-r--r--tensorflow/python/keras/engine/saving_test.py41
-rw-r--r--tensorflow/python/keras/engine/topology_test.py99
-rw-r--r--tensorflow/python/keras/engine/training.py200
-rw-r--r--tensorflow/python/keras/engine/training_arrays.py18
-rw-r--r--tensorflow/python/keras/engine/training_eager.py44
-rw-r--r--tensorflow/python/keras/engine/training_eager_test.py278
-rw-r--r--tensorflow/python/keras/engine/training_test.py78
-rw-r--r--tensorflow/python/keras/engine/training_utils.py63
-rw-r--r--tensorflow/python/keras/model_subclassing_test.py3
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py6
-rw-r--r--tensorflow/python/kernel_tests/conv_ops_test.py2
-rw-r--r--tensorflow/python/kernel_tests/decode_jpeg_op_test.py3
-rw-r--r--tensorflow/python/kernel_tests/random/random_ops_test.py96
-rw-r--r--tensorflow/python/ops/boosted_trees_ops.py11
-rw-r--r--tensorflow/python/ops/cond_v2_impl.py10
-rw-r--r--tensorflow/python/ops/image_ops_impl.py59
-rw-r--r--tensorflow/python/ops/rnn.py12
-rw-r--r--tensorflow/python/saved_model/constants.py6
-rw-r--r--tensorflow/python/summary/writer/writer.py2
-rw-r--r--tensorflow/python/tools/freeze_graph.py32
-rw-r--r--tensorflow/python/tools/freeze_graph_test.py67
-rw-r--r--tensorflow/python/training/checkpoint_utils.py36
-rw-r--r--tensorflow/python/training/checkpoint_utils_test.py4
-rw-r--r--tensorflow/python/training/checkpointable/base.py2
-rw-r--r--tensorflow/python/training/checkpointable/util.py37
-rw-r--r--tensorflow/python/training/saver_test.py31
-rw-r--r--tensorflow/python/util/function_utils.py35
-rw-r--r--tensorflow/python/util/function_utils_test.py78
-rw-r--r--tensorflow/python/util/nest.py56
-rw-r--r--tensorflow/python/util/nest_test.py29
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.cc1
-rw-r--r--tensorflow/stream_executor/cuda/cuda_driver.cc117
-rw-r--r--tensorflow/stream_executor/module_spec.h1
-rw-r--r--tensorflow/stream_executor/stream.cc11
-rw-r--r--tensorflow/stream_executor/stream.h2
-rw-r--r--tensorflow/stream_executor/stream_test.cc139
-rw-r--r--tensorflow/tensorflow.bzl63
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-config-proto.-experimental.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.image.pbtxt4
-rwxr-xr-xtensorflow/tools/ci_build/ci_build.sh3
-rw-r--r--tensorflow/tools/docker/Dockerfile.devel-cpu-mkl83
-rw-r--r--tensorflow/workspace.bzl8
613 files changed, 17766 insertions, 6316 deletions
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 388ca3f293..60db234c9c 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -381,6 +381,14 @@ config_setting(
},
)
+# Setting to use when loading kernels dynamically
+config_setting(
+ name = "dynamic_loaded_kernels",
+ define_values = {
+ "dynamic_loaded_kernels": "true",
+ },
+)
+
config_setting(
name = "using_cuda_nvcc",
define_values = {
@@ -408,14 +416,6 @@ config_setting(
visibility = ["//visibility:public"],
)
-# TODO(laigd): consider removing this option and make TensorRT enabled
-# automatically when CUDA is enabled.
-config_setting(
- name = "with_tensorrt_support",
- values = {"define": "with_tensorrt_support=true"},
- visibility = ["//visibility:public"],
-)
-
package_group(
name = "internal",
packages = [
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index bcecbb0bc6..19ccb6e71d 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -52,8 +52,8 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -2394,8 +2394,8 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
}
void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y,
- int ny, TF_Output* x, int nx, TF_Output* dx,
- TF_Status* status, TF_Output* dy) {
+ int ny, TF_Output* x, int nx, TF_Output* dx,
+ TF_Status* status, TF_Output* dy) {
#ifdef __ANDROID__
status->status = tensorflow::errors::Unimplemented(
"Adding gradients is not supported in Android. File a bug at "
@@ -2420,13 +2420,13 @@ void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y,
prefix_cmp = string(prefix) + "/";
// The operation should fail if the provided name prefix has already been
// used in this graph
- for (const auto& pair: g->name_map) {
+ for (const auto& pair : g->name_map) {
const string& name = pair.first;
if (name.compare(prefix) == 0 ||
tensorflow::str_util::StartsWith(name, prefix_cmp)) {
- status->status = InvalidArgument("prefix [", prefix,
- "] conflicts with existing node in the graph named [",
- name, "]");
+ status->status = InvalidArgument(
+ "prefix [", prefix,
+ "] conflicts with existing node in the graph named [", name, "]");
return;
}
}
diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h
index 0a9fa9ddbc..850f6ecd63 100644
--- a/tensorflow/c/c_api.h
+++ b/tensorflow/c/c_api.h
@@ -1270,6 +1270,11 @@ TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunction(
int noutputs, const TF_Output* outputs, const char* const* output_names,
const TF_FunctionOptions* opts, const char* description, TF_Status* status);
+// Returns the name of the graph function.
+// The return value points to memory that is only usable until the next
+// mutation to *func.
+TF_CAPI_EXPORT extern const char* TF_FunctionName(TF_Function* func);
+
// Write out a serialized representation of `func` (as a FunctionDef protocol
// message) to `output_func_def` (allocated by TF_NewBuffer()).
// `output_func_def`'s underlying buffer will be freed when TF_DeleteBuffer()
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index 170046c802..69b3ffe2a1 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -84,6 +84,18 @@ TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation,
return ret;
}
+TF_Buffer* TF_CreateRunOptions(unsigned char enable_full_trace) {
+ tensorflow::RunOptions options;
+ if (enable_full_trace) {
+ options.set_trace_level(tensorflow::RunOptions::FULL_TRACE);
+ } else {
+ options.set_trace_level(tensorflow::RunOptions::NO_TRACE);
+ }
+ TF_Buffer* ret = TF_NewBuffer();
+ TF_CHECK_OK(MessageToBuffer(options, ret));
+ return ret;
+}
+
const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) {
tensorflow::mutex_lock c(graph->mu);
const auto& debug_str = graph->graph.ToGraphDefDebug().DebugString();
diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h
index 2d81c01e0d..6617c5a572 100644
--- a/tensorflow/c/c_api_experimental.h
+++ b/tensorflow/c/c_api_experimental.h
@@ -70,6 +70,12 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_CreateConfig(
unsigned char enable_xla_compilation,
unsigned char gpu_memory_allow_growth);
+// Create a serialized tensorflow.RunOptions proto, where RunOptions.trace_level
+// is set to FULL_TRACE if `enable_full_trace` is non-zero, and NO_TRACE
+// otherwise.
+TF_CAPI_EXPORT extern TF_Buffer* TF_CreateRunOptions(
+ unsigned char enable_full_trace);
+
// Returns the graph content in a human-readable format, with length set in
// `len`. The format is subject to change in the future.
// The returned string is heap-allocated, and caller should call free() on it.
diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc
index 384e6c8cb9..a2c5a42c11 100644
--- a/tensorflow/c/c_api_function.cc
+++ b/tensorflow/c/c_api_function.cc
@@ -536,6 +536,10 @@ TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name,
return tf_function;
}
+const char* TF_FunctionName(TF_Function* func) {
+ return func->fdef.signature().name().c_str();
+}
+
void TF_GraphCopyFunction(TF_Graph* g, const TF_Function* func,
const TF_Function* grad, TF_Status* status) {
if (func == nullptr) {
diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc
index f7ca219c89..bb9433ce25 100644
--- a/tensorflow/c/c_api_function_test.cc
+++ b/tensorflow/c/c_api_function_test.cc
@@ -193,6 +193,7 @@ class CApiFunctionTest : public ::testing::Test {
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
ASSERT_NE(func_, nullptr);
+ ASSERT_EQ(std::string(func_name_), std::string(TF_FunctionName(func_)));
TF_GraphCopyFunction(host_graph_, func_, nullptr, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
}
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index d8d2533c60..aa2a537f03 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -1709,7 +1709,7 @@ class CApiGradientsTest : public ::testing::Test {
}
void BuildGraphAndAddGradientsWithPrefixes(const char* prefix1,
- const char* prefix2 = nullptr) {
+ const char* prefix2 = nullptr) {
TF_Output inputs[2];
TF_Output outputs[1];
TF_Output grad_outputs[2];
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 6c510536d6..7321b4b791 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -288,7 +288,7 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
opts->async, std::move(device_mgr), r);
}
-void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) { delete ctx; }
+void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; }
TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
TF_DeviceList* list = new TF_DeviceList;
@@ -336,7 +336,7 @@ TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
}
void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
- DCHECK(h);
+ if (h == nullptr) return;
if (h->handle) {
h->handle->Unref();
}
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index fdbd5374b2..ea019a5711 100644
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -102,8 +102,7 @@ typedef struct TFE_Context TFE_Context;
TF_CAPI_EXPORT extern TFE_Context* TFE_NewContext(
const TFE_ContextOptions* opts, TF_Status* status);
-TF_CAPI_EXPORT extern void TFE_DeleteContext(TFE_Context* ctx,
- TF_Status* status);
+TF_CAPI_EXPORT extern void TFE_DeleteContext(TFE_Context* ctx);
TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx,
TF_Status* status);
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 3504a8b5e7..0bdea70fe6 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -49,7 +49,7 @@ void BM_InitOp(int iters) {
}
tensorflow::testing::StopTiming();
TFE_DeleteTensorHandle(m);
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
}
@@ -80,7 +80,7 @@ void BM_Execute(int iters, int async) {
tensorflow::testing::StopTiming();
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(m);
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
}
@@ -95,7 +95,7 @@ TEST(CAPI, Context) {
TF_DeviceList* devices = TFE_ContextListDevices(ctx, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const int num_devices = TF_DeviceListCount(devices);
@@ -195,7 +195,7 @@ void TestRemoteExecute(bool async) {
TFE_DeleteOp(matmul);
TFE_ContextAsyncWait(ctx, status);
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
@@ -281,7 +281,7 @@ void TestRemoteExecuteSilentCopies(bool async) {
TFE_DeleteOp(matmul);
TFE_ContextAsyncWait(ctx, status);
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
@@ -380,8 +380,7 @@ void TensorHandleCopyBetweenDevices(bool async) {
TF_DeleteDeviceList(devices);
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(hcpu);
- TFE_DeleteContext(ctx, status.get());
- EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ TFE_DeleteContext(ctx);
}
TEST(CAPI, TensorHandleCopyBetweenDevices) {
@@ -418,7 +417,7 @@ void TensorHandleCopyBetweenDevicesError(bool async) {
TFE_DeleteTensorHandle(hcopy);
TFE_DeleteTensorHandle(hcpu);
if (hdevice != nullptr) TFE_DeleteTensorHandle(hdevice);
- TFE_DeleteContext(ctx, status.get());
+ TFE_DeleteContext(ctx);
}
TEST(CAPI, TensorHandleCopyBetweenDevicesError) {
@@ -451,7 +450,7 @@ void TensorHandleCopyBetweenTwoGPUDevices(bool async) {
TF_DeleteDeviceList(devices);
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(hcpu);
- TFE_DeleteContext(ctx, status.get());
+ TFE_DeleteContext(ctx);
return;
}
const string gpu_1_name(TF_DeviceListName(devices, 1, status.get()));
@@ -484,8 +483,7 @@ void TensorHandleCopyBetweenTwoGPUDevices(bool async) {
TF_DeleteDeviceList(devices);
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(hcpu);
- TFE_DeleteContext(ctx, status.get());
- EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ TFE_DeleteContext(ctx);
}
TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevices) {
@@ -533,8 +531,7 @@ void TensorHandleSilentCopy(bool async) {
TFE_DeleteTensorHandle(hcpu);
TFE_ContextAsyncWait(ctx, status.get());
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
- TFE_DeleteContext(ctx, status.get());
- EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ TFE_DeleteContext(ctx);
}
TEST(CAPI, TensorHandleSilentCopy) { TensorHandleSilentCopy(false); }
@@ -580,8 +577,7 @@ void TensorHandleSilentCopyLocal(bool async) {
TFE_DeleteTensorHandle(hcpu);
TFE_ContextAsyncWait(ctx, status.get());
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
- TFE_DeleteContext(ctx, status.get());
- EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ TFE_DeleteContext(ctx);
}
TEST(CAPI, TensorHandleSilentCopyLocal) { TensorHandleSilentCopyLocal(false); }
TEST(CAPI, TensorHandleSilentCopyLocalAsync) {
@@ -614,7 +610,7 @@ void SetAndGetOpDevices(bool async) {
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(m);
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
}
@@ -640,7 +636,7 @@ void Execute_MatMul_CPU(bool async) {
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(retvals[0]);
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
@@ -712,7 +708,7 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) {
TFE_DeleteTensorHandle(m1);
TFE_DeleteTensorHandle(m2);
TFE_DeleteTensorHandle(retvals[0]);
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}
TEST(CAPI, Execute_MatMul_CPU_Runtime_Error) {
@@ -743,7 +739,7 @@ void Execute_MatMul_CPU_Type_Error(bool async) {
if (retvals[0] != nullptr) {
TFE_DeleteTensorHandle(retvals[0]);
}
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}
@@ -781,7 +777,7 @@ TEST(CAPI, Execute_Min_CPU) {
TF_DeleteTensor(t);
EXPECT_EQ(1, output[0]);
EXPECT_EQ(3, output[1]);
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
}
@@ -823,7 +819,7 @@ void Execute_MatMul_XLA_CPU(bool async) {
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}
TEST(CAPI, Execute_MatMul_XLA_CPU) { Execute_MatMul_XLA_CPU(false); }
@@ -862,7 +858,7 @@ void Execute_Min_XLA_CPU(bool async) {
TF_DeleteTensor(t);
EXPECT_EQ(1, output[0]);
EXPECT_EQ(3, output[1]);
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}
TEST(CAPI, Execute_Min_XLA_CPU) { Execute_Min_XLA_CPU(false); }
@@ -898,7 +894,7 @@ void ExecuteWithTracing(bool async) {
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
TFE_DeleteTensorHandle(retvals[0]);
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
@@ -974,7 +970,7 @@ TEST(CAPI, Function_ident_CPU) {
TF_DeleteTensor(r);
TFE_DeleteTensorHandle(result[0]);
}
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteStatus(status);
}
@@ -1044,7 +1040,7 @@ TEST(CAPI, Function_ident_XLA_CPU) {
TF_DeleteTensor(r);
TFE_DeleteTensorHandle(result[0]);
}
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteStatus(status);
}
@@ -1120,7 +1116,7 @@ void FunctionDefAndExecute(bool async) {
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
}
@@ -1161,7 +1157,7 @@ void BM_ExecuteFunction(int iters, int async) {
tensorflow::testing::StopTiming();
TFE_DeleteTensorHandle(m);
TFE_DeleteTensorHandle(retval[0]);
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
}
@@ -1249,7 +1245,7 @@ TEST(CAPI, Variables) {
TFE_DeleteTensorHandle(var_handle);
TFE_DeleteTensorHandle(value_handle);
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
}
@@ -1288,7 +1284,7 @@ void BM_ReadVariable(int iters) {
TFE_DeleteOp(op);
TFE_DeleteTensorHandle(var_handle);
- TFE_DeleteContext(ctx, status);
+ TFE_DeleteContext(ctx);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
}
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc
index fd7b6fe662..1c9bdff5e1 100644
--- a/tensorflow/cc/gradients/math_grad_test.cc
+++ b/tensorflow/cc/gradients/math_grad_test.cc
@@ -475,11 +475,7 @@ TEST_F(CWiseUnaryGradTest, Tan_Complex) {
auto x_fn = [this](const int i) {
return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
};
- // TODO(kbsriram)
- // Enable when tan kernel supports complex inputs
- if (false) {
- TestCWiseGrad<complex64, complex64>(TAN, x_fn);
- }
+ TestCWiseGrad<complex64, complex64>(TAN, x_fn);
}
TEST_F(CWiseUnaryGradTest, Atan) {
diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc
index d47b025743..98be66a6ad 100644
--- a/tensorflow/cc/saved_model/loader.cc
+++ b/tensorflow/cc/saved_model/loader.cc
@@ -74,6 +74,54 @@ void AddAssetsTensorsToInputs(const StringPiece export_dir,
}
}
+// Like Session::Run(), but uses the Make/Run/ReleaseCallable() API to avoid
+// leaving behind non-GC'ed state.
+//
+// Detailed motivation behind this approach, from ashankar@:
+//
+// Each call to Session::Run() that identifies a new subgraph (based on feeds
+// and fetches) creates some datastructures that live as long as the session
+// (the partitioned graph, associated executors etc.).
+//
+// A pathological case of this would be if say the initialization op
+// (main_op/legacy_init_op) involves the use of a large constant. Then we
+// allocate memory for that large constant that will just stick around till the
+// session dies. With this Callable mechanism, that memory will be released
+// right after ReleaseCallable returns.
+//
+// However, the resource manager state remains.
+Status RunOnce(const RunOptions& run_options,
+ const std::vector<std::pair<string, Tensor>>& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
+ std::vector<Tensor>* outputs, RunMetadata* run_metadata,
+ Session* session) {
+ CallableOptions callable_options;
+ std::vector<Tensor> feed_tensors;
+ *callable_options.mutable_run_options() = run_options;
+ for (const auto& input : inputs) {
+ const string& name = input.first;
+ const Tensor& tensor = input.second;
+ callable_options.add_feed(name);
+ feed_tensors.push_back(tensor);
+ }
+ for (const string& output_tensor_name : output_tensor_names) {
+ callable_options.add_fetch(output_tensor_name);
+ }
+ for (const string& target_node_name : target_node_names) {
+ callable_options.add_target(target_node_name);
+ }
+
+ Session::CallableHandle callable_handle;
+ TF_RETURN_IF_ERROR(session->MakeCallable(callable_options, &callable_handle));
+ const Status run_status = session->RunCallable(callable_handle, feed_tensors,
+ outputs, run_metadata);
+ // Be sure to call ReleaseCallable() regardless of the outcome of
+ // RunCallable().
+ session->ReleaseCallable(callable_handle).IgnoreError();
+ return run_status;
+}
+
bool HasMainOp(const MetaGraphDef& meta_graph_def) {
const auto& collection_def_map = meta_graph_def.collection_def();
if (collection_def_map.find(kSavedModelMainOpKey) !=
@@ -100,8 +148,8 @@ Status RunMainOp(const RunOptions& run_options, const string& export_dir,
AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);
RunMetadata run_metadata;
const StringPiece main_op_name = main_op_it->second.node_list().value(0);
- return session->Run(run_options, inputs, {}, {main_op_name.ToString()},
- nullptr /* outputs */, &run_metadata);
+ return RunOnce(run_options, inputs, {}, {main_op_name.ToString()},
+ nullptr /* outputs */, &run_metadata, session);
}
return Status::OK();
}
@@ -138,8 +186,8 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir,
AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);
RunMetadata run_metadata;
- return session->Run(run_options, inputs, {}, {restore_op_name.ToString()},
- nullptr /* outputs */, &run_metadata);
+ return RunOnce(run_options, inputs, {}, {restore_op_name.ToString()},
+ nullptr /* outputs */, &run_metadata, session);
}
Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 9174a67cc6..e34347b9d4 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -166,6 +166,7 @@ cc_library(
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/service:stream_pool",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index c5d0e4f8fb..b313d48011 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -153,6 +153,10 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
XlaCompiler::Options options;
options.client = client;
+ if (ctx->op_device_context() != nullptr) {
+ options.device_ordinal =
+ ctx->op_device_context()->stream()->parent()->device_ordinal();
+ }
options.device_type = cache->device_type();
options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
options.graph_def_version = ctx->function_library()->graph_def_version();
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc
index 54a41a4daa..08c357c879 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.cc
+++ b/tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -209,7 +209,9 @@ Status XlaCompilationCache::BuildExecutable(
argument_layouts[i] = &result.xla_input_shapes[i];
}
xla::ExecutableBuildOptions build_options;
- build_options.set_device_ordinal(client_->default_device_ordinal());
+ build_options.set_device_ordinal(options.device_ordinal != -1
+ ? options.device_ordinal
+ : client_->default_device_ordinal());
build_options.set_result_layout(result.xla_output_shape);
build_options.set_device_allocator(options.device_allocator);
diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h
index fccdb14368..4a5942fbd7 100644
--- a/tensorflow/compiler/jit/xla_device.h
+++ b/tensorflow/compiler/jit/xla_device.h
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/allocator.h"
@@ -153,17 +154,17 @@ class XlaDevice : public LocalDevice {
// stream are executed on the device. Operations include data
// copying back and forth between CPU and the device, and
// computations enqueued by XLA.
- xla::Backend::StreamPtr stream_;
+ xla::StreamPool::Ptr stream_;
// If true, only stream_ is valid and all computation and transfers use
// stream_. If false, computation is performed by stream_ and transfers are
// performed by host_to_device/device_to_host_stream.
bool use_multiple_streams_;
// If use_multiple_streams_, host to device transfers are performed using this
// stream.
- xla::Backend::StreamPtr host_to_device_stream_;
+ xla::StreamPool::Ptr host_to_device_stream_;
// If use_multiple_streams_, device to host transfers are performed using this
// stream.
- xla::Backend::StreamPtr device_to_host_stream_;
+ xla::StreamPool::Ptr device_to_host_stream_;
// Must we use XLA's transfer manager for correct host<->device transfers? if
// false, we can use ThenMemcpy() instead.
bool transfer_as_literal_;
diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py
index 8b01ef96db..bf986ade06 100644
--- a/tensorflow/compiler/tests/image_ops_test.py
+++ b/tensorflow/compiler/tests/image_ops_test.py
@@ -26,6 +26,7 @@ import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.compiler.tests import xla_test
+from tensorflow.python.compat import compat
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -579,5 +580,140 @@ class ResizeBilinearTest(xla_test.XLATestCase):
large_tolerance=True)
+class NonMaxSuppressionTest(xla_test.XLATestCase):
+
+ def testNMS128From1024(self):
+ # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
+ if self.device in ["XLA_CPU", "XLA_GPU"]:
+ return
+
+ with compat.forward_compatibility_horizon(2018, 8, 8):
+ num_boxes = 1024
+ boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4")
+ scores_np = np.random.normal(0.5, 0.1, (num_boxes,)).astype("f4")
+
+ max_output_size = 128
+ iou_threshold_np = np.array(0.5, dtype=np.float32)
+ score_threshold_np = np.array(0.0, dtype=np.float32)
+
+ with self.test_session() as sess:
+ boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape)
+ scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape)
+ iou_threshold = array_ops.placeholder(iou_threshold_np.dtype,
+ iou_threshold_np.shape)
+ score_threshold = array_ops.placeholder(score_threshold_np.dtype,
+ score_threshold_np.shape)
+ with self.test_scope():
+ selected_indices = image_ops.non_max_suppression_padded(
+ boxes=boxes,
+ scores=scores,
+ max_output_size=max_output_size,
+ iou_threshold=iou_threshold,
+ score_threshold=score_threshold,
+ pad_to_max_output_size=True)
+ inputs_feed = {
+ boxes: boxes_np,
+ scores: scores_np,
+ score_threshold: score_threshold_np,
+ iou_threshold: iou_threshold_np
+ }
+ (indices_tf, _) = sess.run(selected_indices, feed_dict=inputs_feed)
+
+ self.assertEqual(indices_tf.size, max_output_size)
+
+ def testNMS3From6Boxes(self):
+ # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
+ if self.device in ["XLA_CPU", "XLA_GPU"]:
+ return
+
+ with compat.forward_compatibility_horizon(2018, 8, 8):
+ # Three boxes are selected based on IOU.
+ boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
+ [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]]
+ boxes_np = np.array(boxes_data, dtype=np.float32)
+
+ scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3]
+ scores_np = np.array(scores_data, dtype=np.float32)
+
+ max_output_size = 3
+ iou_threshold_np = np.array(0.5, dtype=np.float32)
+ score_threshold_np = np.array(0.0, dtype=np.float32)
+
+ with self.test_session() as sess:
+ boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape)
+ scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape)
+ iou_threshold = array_ops.placeholder(iou_threshold_np.dtype,
+ iou_threshold_np.shape)
+ score_threshold = array_ops.placeholder(score_threshold_np.dtype,
+ score_threshold_np.shape)
+ with self.test_scope():
+ selected_indices = image_ops.non_max_suppression_padded(
+ boxes=boxes,
+ scores=scores,
+ max_output_size=max_output_size,
+ iou_threshold=iou_threshold,
+ score_threshold=score_threshold,
+ pad_to_max_output_size=True)
+ inputs_feed = {
+ boxes: boxes_np,
+ scores: scores_np,
+ score_threshold: score_threshold_np,
+ iou_threshold: iou_threshold_np
+ }
+ (indices_tf, num_valid) = sess.run(
+ selected_indices, feed_dict=inputs_feed)
+
+ self.assertEqual(indices_tf.size, max_output_size)
+ self.assertEqual(num_valid, 3)
+ self.assertAllClose(indices_tf[:num_valid], [3, 0, 5])
+
+ def testNMS3Then2WithScoreThresh(self):
+ # Three boxes are selected based on IOU.
+ # One is filtered out by score threshold.
+
+ # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
+ if self.device in ["XLA_CPU", "XLA_GPU"]:
+ return
+
+ with compat.forward_compatibility_horizon(2018, 8, 8):
+ boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
+ [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]]
+ boxes_np = np.array(boxes_data, dtype=np.float32)
+
+ scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3]
+ scores_np = np.array(scores_data, dtype=np.float32)
+ max_output_size = 3
+ iou_threshold_np = np.array(0.5, dtype=np.float32)
+ score_threshold_np = np.array(0.4, dtype=np.float32)
+
+ with self.test_session() as sess:
+ boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape)
+ scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape)
+ iou_threshold = array_ops.placeholder(iou_threshold_np.dtype,
+ iou_threshold_np.shape)
+ score_threshold = array_ops.placeholder(score_threshold_np.dtype,
+ score_threshold_np.shape)
+ with self.test_scope():
+ selected_indices = image_ops.non_max_suppression_padded(
+ boxes=boxes,
+ scores=scores,
+ max_output_size=max_output_size,
+ iou_threshold=iou_threshold,
+ score_threshold=score_threshold,
+ pad_to_max_output_size=True)
+ inputs_feed = {
+ boxes: boxes_np,
+ scores: scores_np,
+ iou_threshold: iou_threshold_np,
+ score_threshold: score_threshold_np
+ }
+ (indices_tf, num_valid) = sess.run(
+ selected_indices, feed_dict=inputs_feed)
+
+ self.assertEqual(indices_tf.size, max_output_size)
+ self.assertEqual(num_valid, 2)
+ self.assertAllClose(indices_tf[:num_valid], [3, 0])
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 881624fff8..338943201b 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -140,14 +140,12 @@ cc_library(
"xla_op_registry.cc",
"xla_resource.cc",
"xla_cpu_backend.cc",
- "legacy_flags/backend_registration_flags.cc",
] + if_cuda_is_configured([
"xla_gpu_backend.cc",
]),
hdrs = [
"const_analysis.h",
"graph_compiler.h",
- "legacy_flags/backend_registration_flags.h",
"xla_compilation_device.h",
"xla_compiler.h",
"xla_context.h",
@@ -173,16 +171,14 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:numeric",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
- "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
- "//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
diff --git a/tensorflow/compiler/tf2xla/dump_graph.cc b/tensorflow/compiler/tf2xla/dump_graph.cc
index 03603ee9ba..24616c01c7 100644
--- a/tensorflow/compiler/tf2xla/dump_graph.cc
+++ b/tensorflow/compiler/tf2xla/dump_graph.cc
@@ -33,7 +33,7 @@ struct NameCounts {
std::unordered_map<string, int> counts;
};
-string MakeUniquePath(string name) {
+string MakeUniqueFilename(string name) {
static NameCounts& instance = *new NameCounts;
// Remove illegal characters from `name`.
@@ -50,26 +50,41 @@ string MakeUniquePath(string name) {
count = instance.counts[name]++;
}
- legacy_flags::DumpGraphFlags* flags = legacy_flags::GetDumpGraphFlags();
- string path = strings::StrCat(flags->tf_dump_graph_prefix, "/", name);
+ string filename = name;
if (count > 0) {
- strings::StrAppend(&path, "_", count);
+ strings::StrAppend(&filename, "_", count);
}
- strings::StrAppend(&path, ".pbtxt");
- return path;
+ strings::StrAppend(&filename, ".pbtxt");
+ return filename;
+}
+
+string WriteTextProtoToUniqueFile(
+ Env* env, const string& name, const char* proto_type,
+ const ::tensorflow::protobuf::Message& proto) {
+ const string& dirname =
+ legacy_flags::GetDumpGraphFlags()->tf_dump_graph_prefix;
+ Status status = env->RecursivelyCreateDir(dirname);
+ if (!status.ok()) {
+ LOG(WARNING) << "Failed to create " << dirname << " for dumping "
+ << proto_type << ": " << status;
+ return "(unavailable)";
+ }
+ string filepath = strings::StrCat(dirname, "/", MakeUniqueFilename(name));
+ status = WriteTextProto(Env::Default(), filepath, proto);
+ if (!status.ok()) {
+ LOG(WARNING) << "Failed to dump " << proto_type << " to file: " << filepath
+ << " : " << status;
+ return "(unavailable)";
+ }
+ LOG(INFO) << "Dumped " << proto_type << " to " << filepath;
+ return filepath;
}
} // anonymous namespace
string DumpGraphDefToFile(const string& name, GraphDef const& graph_def) {
- string path = MakeUniquePath(name);
- Status status = WriteTextProto(Env::Default(), path, graph_def);
- if (!status.ok()) {
- VLOG(1) << "Failed to dump GraphDef to file: " << path << " : " << status;
- path.clear();
- path = "(unavailable)";
- }
- return path;
+ return WriteTextProtoToUniqueFile(Env::Default(), name, "GraphDef",
+ graph_def);
}
string DumpGraphToFile(const string& name, Graph const& graph,
@@ -83,15 +98,7 @@ string DumpGraphToFile(const string& name, Graph const& graph,
}
string DumpFunctionDefToFile(const string& name, FunctionDef const& fdef) {
- string path = MakeUniquePath(name);
- Status status = WriteTextProto(Env::Default(), path, fdef);
- if (!status.ok()) {
- VLOG(1) << "Failed to dump FunctionDef to file: " << path << " : "
- << status;
- path.clear();
- path = "(unavailable)";
- }
- return path;
+ return WriteTextProtoToUniqueFile(Env::Default(), name, "FunctionDef", fdef);
}
} // namespace dump_graph
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
index 6cc95149a1..0904778f97 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
@@ -177,8 +177,8 @@ Status CheckNoCycleContains(const Node* node, const int num_nodes) {
visited[current_node->id()] = true;
for (const Edge* out : current_node->out_edges()) {
if (out->dst() == node) {
- return errors::Internal("Detect a cycle: Node \"", node->name(), "\"(",
- node->def().op(), ") feeds into itself.");
+ return errors::Internal("Detected a cycle: ", FormatNodeForError(*node),
+ "(", node->def().op(), ") feeds into itself.");
} else if (!visited[out->dst()->id()]) {
ready.push_back(out->dst());
}
@@ -324,7 +324,7 @@ Status AddMissingFunctionDef(const FunctionDef& fdef,
if (library->Find(node.op())) {
continue;
}
- // The function refered by 'SymbolicGradient' node is specified in its
+ // The function referred by 'SymbolicGradient' node is specified in its
// attribute 'f'.
if (node.op() == FunctionLibraryDefinition::kGradientOp) {
const AttrValue* attr =
@@ -437,22 +437,24 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
continue;
}
if (enter_merge != nullptr) {
- return errors::Internal(
- "Enter node for loop-varying argument ", arg.enter->name(),
- " has multiple successors: ", enter_merge->dst()->name(), " and ",
- e->dst()->name());
+ return errors::Internal("Enter node for loop-varying argument ",
+ FormatNodeForError(*arg.enter),
+ " has multiple successors: ",
+ FormatNodeForError(*enter_merge->dst()),
+ " and ", FormatNodeForError(*e->dst()));
}
enter_merge = e;
}
if (enter_merge == nullptr) {
return errors::Internal("Enter node for loop-varying argument ",
- arg.enter->name(), " has zero successors");
+ FormatNodeForError(*arg.enter),
+ " has zero successors");
}
arg.merge = enter_merge->dst();
if (!IsMerge(arg.merge)) {
return errors::InvalidArgument(
"Successor of Enter node for loop-varying argument ",
- arg.merge->name(),
+ FormatNodeForError(*arg.merge),
" is not a Merge node; got: ", arg.merge->type_string());
}
@@ -462,7 +464,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
return errors::InvalidArgument(
"Unexpected number of inputs to Merge node for loop-varying "
"argument ",
- arg.merge->name(), "; expected 2, got ",
+ FormatNodeForError(*arg.merge), "; expected 2, got ",
arg.merge->input_types().size());
}
TF_RETURN_IF_ERROR(arg.merge->input_node(1 - enter_merge->dst_input(),
@@ -470,7 +472,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
if (!IsNextIteration(arg.next_iteration)) {
return errors::InvalidArgument(
"Expected NextIteration node as input to Merge node; got node ",
- arg.next_iteration->name(), " with kind ",
+ FormatNodeForError(*arg.next_iteration), " with kind ",
arg.next_iteration->type_string());
}
@@ -481,14 +483,14 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
switches.find(edge->dst()) != switches.end()) {
if (arg.switch_node != nullptr) {
return errors::InvalidArgument("Duplicate Switch successors to ",
- arg.merge->name());
+ FormatNodeForError(*arg.merge));
}
arg.switch_node = edge->dst();
}
}
if (arg.switch_node == nullptr) {
return errors::InvalidArgument("Missing Switch successor to ",
- arg.merge->name());
+ FormatNodeForError(*arg.merge));
}
// Update the device on the Identity outputs of the switch to match their
@@ -516,14 +518,15 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
possible_exit.pop_front();
if (IsExit(edge->dst())) {
if (arg.exit != nullptr) {
- return errors::InvalidArgument("Duplicate Exit successors to ",
- arg.switch_node->name());
+ return errors::InvalidArgument(
+ "Duplicate Exit successors to ",
+ FormatNodeForError(*arg.switch_node));
}
arg.exit = edge->dst();
} else {
if (!IsIdentity(edge->dst())) {
return errors::Unimplemented("General graph between switch (",
- arg.switch_node->name(),
+ FormatNodeForError(*arg.switch_node),
") and exit node of frame ",
frame->name, " not supported yet.");
}
@@ -1470,7 +1473,7 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
if (!unreachable_nodes.empty()) {
return errors::InvalidArgument(
"The following nodes are unreachable from the source in the graph: ",
- tensorflow::str_util::Join(unreachable_nodes, ", "));
+ errors::FormatNodeNamesForError(unreachable_nodes));
}
// Builds Frames, indexed by name.
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
index aae2f8ee5a..ccf249b35d 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
@@ -1064,7 +1064,10 @@ TEST(FunctionalizeControlFlow, Cycle) {
// less -> XlaIf <--> identity.
Status status = FunctionalizeControlFlow(graph.get(), &library);
EXPECT_FALSE(status.ok());
- EXPECT_TRUE(str_util::StrContains(status.error_message(), "Detect a cycle"))
+ EXPECT_TRUE(str_util::StrContains(status.error_message(), "Detected a cycle"))
+ << status.error_message();
+ EXPECT_TRUE(
+ str_util::StrContains(status.error_message(), "{{node cond/Less_5_If}}"))
<< status.error_message();
}
diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc
index e1cea03865..e4fdf0a618 100644
--- a/tensorflow/compiler/tf2xla/graph_compiler.cc
+++ b/tensorflow/compiler/tf2xla/graph_compiler.cc
@@ -29,7 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 7f3e32d96d..0609e22338 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -123,13 +123,14 @@ tf_kernel_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/lib:numeric",
"//tensorflow/compiler/xla/client/lib:prng",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/lib:sorting",
"//tensorflow/core:framework",
"//tensorflow/core:image_ops_op_lib",
"//tensorflow/core:lib",
@@ -165,8 +166,8 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:literal",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@@ -182,7 +183,7 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:literal",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@@ -219,8 +220,8 @@ tf_kernel_library(
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:arithmetic",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/kernels:argmax_op",
diff --git a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc
index e335328280..41a453da80 100644
--- a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
namespace tensorflow {
namespace {
diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
index c4af79281d..b3ad0aea84 100644
--- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
index 26130fd9e7..48f2a005ab 100644
--- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
namespace tensorflow {
namespace {
diff --git a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc
index e9b2c0b16d..41f540506b 100644
--- a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/util/tensor_format.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
index d6d4ae8937..2c328102e0 100644
--- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc
index efbdb76eaa..5078f8662b 100644
--- a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/tf2xla/kernels/cast_op.cc b/tensorflow/compiler/tf2xla/kernels/cast_op.cc
index 62eebf762b..8cc2479dd5 100644
--- a/tensorflow/compiler/tf2xla/kernels/cast_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/cast_op.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
index 1784e712b5..e7fef77edc 100644
--- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc
index 4e6d33304c..547fe48046 100644
--- a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/tensor_shape.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
index e3a32a5c0e..f410605104 100644
--- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc
index f4360d8c3f..da8cf3fc6f 100644
--- a/tensorflow/compiler/tf2xla/kernels/const_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/tensor.pb.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
index 48ac4867ed..5da7972397 100644
--- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/numeric.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/cross_op.cc b/tensorflow/compiler/tf2xla/kernels/cross_op.cc
index 500a564f3f..db579a5b35 100644
--- a/tensorflow/compiler/tf2xla/kernels/cross_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/cross_op.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
namespace tensorflow {
namespace {
diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
index 9ff3e02228..ef1015552d 100644
--- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/util/bcast.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
index 4f92dbc874..a5b870f8db 100644
--- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
+++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/bcast.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc
index f314920025..12b0e38288 100644
--- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc
index 22cda27567..ed44ad218b 100644
--- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/numeric.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/op_kernel.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc
index 3b86ea34c9..a3389d5b90 100644
--- a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
index 958231505b..cb73053666 100644
--- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/elu_op.cc b/tensorflow/compiler/tf2xla/kernels/elu_op.cc
index 81f42e504e..5fdb1d972c 100644
--- a/tensorflow/compiler/tf2xla/kernels/elu_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/elu_op.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc
index 65d42a302f..c68b0bfd79 100644
--- a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/numeric.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc
index 2fd1a34741..cdba6680de 100644
--- a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc
index b2b00e51e3..80bcef9663 100644
--- a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc
index 95faa1d058..54b21a2782 100644
--- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/register_types.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
index 5f041be5df..35de96e0aa 100644
--- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h
index d898e43b85..92346283c3 100644
--- a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h
+++ b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/bcast.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc
index e2160feba0..ceb2af756c 100644
--- a/tensorflow/compiler/tf2xla/kernels/if_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc
index cb4caf7bcb..33a73fe5fd 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc
@@ -17,7 +17,12 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/sorting.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/core/framework/tensor_shape.h"
namespace tensorflow {
namespace {
@@ -311,5 +316,150 @@ class AdjustHueOp : public XlaOpKernel {
};
REGISTER_XLA_OP(Name("AdjustHue"), AdjustHueOp);
+class NonMaxSuppressionOp : public XlaOpKernel {
+ public:
+ explicit NonMaxSuppressionOp(OpKernelConstruction* context)
+ : XlaOpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("pad_to_max_output_size",
+ &pad_to_max_output_size_));
+ }
+
+ void Compile(XlaOpKernelContext* context) override {
+ // TODO(b/111646731): Improve scalability of this op, using blocking.
+ int num_boxes_dim = 0;
+ int coords_dim = 1;
+ const TensorShape& boxes_shape = context->InputShape("boxes");
+ OP_REQUIRES(context, TensorShapeUtils::IsMatrix(boxes_shape),
+ errors::InvalidArgument("boxes must be 2-D, currently: ",
+ boxes_shape.DebugString()));
+ const int64 num_boxes = boxes_shape.dim_size(num_boxes_dim);
+ OP_REQUIRES(context, boxes_shape.dim_size(coords_dim) == 4,
+ errors::InvalidArgument("boxes must have 4 columns",
+ boxes_shape.DebugString()));
+ const TensorShape& scores_shape = context->InputShape("scores");
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(scores_shape),
+ errors::InvalidArgument("scores must be 1-D, currently: ",
+ scores_shape.DebugString()));
+ OP_REQUIRES(
+ context, scores_shape.dim_size(0) == num_boxes,
+ errors::InvalidArgument("scores size must equal number of boxes",
+ scores_shape.DebugString()));
+ OP_REQUIRES(context, pad_to_max_output_size_,
+ errors::InvalidArgument(
+ "XLA compilation requires pad_to_max_output_size == True"));
+
+ xla::XlaOp boxes = context->Input("boxes");
+ xla::XlaOp scores = context->Input("scores");
+ int64 output_size;
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &output_size));
+ OP_REQUIRES(
+ context, output_size >= 0,
+ errors::InvalidArgument("Need output_size >= 0, got ", output_size));
+ xla::XlaOp score_thresh = context->Input("score_threshold");
+ xla::XlaOp iou_thresh = context->Input("iou_threshold");
+
+ xla::XlaBuilder* const builder = context->builder();
+
+ // Choose a more convenient layout.
+ xla::XlaOp boxes_t = xla::Transpose(boxes, {1, 0});
+ coords_dim = 0;
+ num_boxes_dim = 1;
+
+ // Shapes are henceforth [1, num_boxes].
+ xla::XlaOp coord_y0 = xla::SliceInDim(boxes_t,
+ /*start_index=*/0,
+ /*limit_index=*/1,
+ /*stride=*/1,
+ /*dimno=*/coords_dim);
+ xla::XlaOp coord_x0 = xla::SliceInDim(boxes_t,
+ /*start_index=*/1,
+ /*limit_index=*/2,
+ /*stride=*/1,
+ /*dimno=*/coords_dim);
+ xla::XlaOp coord_y1 = xla::SliceInDim(boxes_t,
+ /*start_index=*/2,
+ /*limit_index=*/3,
+ /*stride=*/1,
+ /*dimno=*/coords_dim);
+ xla::XlaOp coord_x1 = xla::SliceInDim(boxes_t,
+ /*start_index=*/3,
+ /*limit_index=*/4,
+ /*stride=*/1,
+ /*dimno=*/coords_dim);
+ xla::XlaOp y1 =
+ xla::Select(xla::Le(coord_y0, coord_y1), coord_y0, coord_y1);
+ xla::XlaOp y2 =
+ xla::Select(xla::Le(coord_y0, coord_y1), coord_y1, coord_y0);
+ xla::XlaOp x1 =
+ xla::Select(xla::Le(coord_x0, coord_x1), coord_x0, coord_x1);
+ xla::XlaOp x2 =
+ xla::Select(xla::Le(coord_x0, coord_x1), coord_x1, coord_x0);
+ xla::XlaOp area = (y2 - y1) * (x2 - x1);
+
+ // Transpose the 1xN tensors, instead of the NxN tensors.
+ xla::XlaOp y1_t = xla::Transpose(y1, {1, 0});
+ xla::XlaOp y2_t = xla::Transpose(y2, {1, 0});
+ xla::XlaOp x1_t = xla::Transpose(x1, {1, 0});
+ xla::XlaOp x2_t = xla::Transpose(x2, {1, 0});
+ xla::XlaOp area_t = xla::Transpose(area, {1, 0});
+
+ // Shapes are henceforth [num_boxes, num_boxes].
+ xla::XlaOp i_xmin = xla::Max(x1, x1_t);
+ xla::XlaOp i_ymin = xla::Max(y1, y1_t);
+ xla::XlaOp i_xmax = xla::Min(x2, x2_t);
+ xla::XlaOp i_ymax = xla::Min(y2, y2_t);
+ auto square_zero = xla::ZerosLike(i_xmin);
+
+ xla::XlaOp i_area = xla::Max(i_xmax - i_xmin, square_zero) *
+ xla::Max(i_ymax - i_ymin, square_zero);
+ xla::XlaOp u_area = area + area_t - i_area;
+ xla::XlaOp iou = i_area / u_area;
+
+ xla::XlaOp iou_thresh_mask = xla::Gt(iou, iou_thresh + square_zero);
+ xla::XlaOp scores_2d = xla::Reshape(scores, {num_boxes, 1});
+ xla::XlaOp score_cmp_mask =
+ xla::Gt(scores_2d, xla::Transpose(scores_2d, {1, 0}));
+ xla::XlaOp suppress = xla::And(iou_thresh_mask, score_cmp_mask);
+
+ // Shapes are [num_boxes] after the reduce.
+ xla::XlaOp included_iou = xla::Not(xla::Reduce(
+ suppress,
+ /*init_value=*/xla::ConstantR0<bool>(builder, false),
+ /*computation=*/CreateScalarOrComputation(xla::PRED, builder),
+ /*dimensions_to_reduce=*/{0}));
+ xla::XlaOp included_score =
+ xla::Gt(scores, xla::Broadcast(score_thresh, {num_boxes}));
+ xla::XlaOp included = xla::And(included_iou, included_score);
+ xla::XlaOp neg_inf =
+ xla::Broadcast(xla::MinValue(builder, xla::F32), {num_boxes});
+ xla::XlaOp scores_included = xla::Select(included, scores, neg_inf);
+
+ xla::XlaOp ones_included = xla::Select(
+ included,
+ xla::Broadcast(xla::ConstantR0<int32>(builder, 1), {num_boxes}),
+ xla::Broadcast(xla::ConstantR0<int32>(builder, 0), {num_boxes}));
+
+ // num_valid is scalar.
+ xla::XlaOp num_valid = xla::Reduce(
+ ones_included,
+ /*init_value=*/xla::ConstantR0<int>(builder, 0),
+ /*computation=*/CreateScalarAddComputation(xla::S32, builder),
+ /*dimensions_to_reduce=*/{0});
+
+ xla::XlaOp output_tuple = TopK(scores_included, output_size);
+ xla::XlaOp selected_indices = xla::GetTupleElement(output_tuple, 1);
+
+ context->SetOutput(0, selected_indices);
+ context->SetOutput(1, num_valid);
+ }
+
+ private:
+ bool pad_to_max_output_size_;
+};
+
+REGISTER_XLA_OP(
+ Name("NonMaxSuppressionV4").CompileTimeConstInput("max_output_size"),
+ NonMaxSuppressionOp);
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
index d6bf92fb3d..8d75624e74 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/lib/numeric.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/lib/math/math_util.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc
index 9e64711051..f028e361bc 100644
--- a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/no_op.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc
index 2fb072f827..a11bbe918f 100644
--- a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/lib/core/errors.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc
index dc934543cb..87ee2d3aed 100644
--- a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc
index aa45b02551..6440770c29 100644
--- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc
index e06c87db7a..8dfd7de591 100644
--- a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/numeric.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/tensor_shape.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc
index e2ab4b83cf..c0ca881ff8 100644
--- a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/numeric.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc
index 529959dbd9..eedfc3c914 100644
--- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/util/mirror_pad_mode.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/tf2xla/kernels/pack_op.cc b/tensorflow/compiler/tf2xla/kernels/pack_op.cc
index 3aed47de26..a9b519d892 100644
--- a/tensorflow/compiler/tf2xla/kernels/pack_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/pack_op.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc
index 89fd610bc6..e5937b56c1 100644
--- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/register_types.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
index 2a4c0cab4b..3d506e71e0 100644
--- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/util.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc
index 2e632e185d..6f4ed496a1 100644
--- a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/core/platform/macros.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
index 607cad798a..2da9340625 100644
--- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
@@ -27,7 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/numeric.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
index 23ac45beb7..b11a4ce36d 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
index be7f2bce8c..0d260fa8fc 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h
index 8333f9b288..466e79828d 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h
@@ -19,7 +19,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_REDUCTION_OPS_H_
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
index bb8dd3ac90..b52f0a0ab6 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc
index f4b804e546..d35777ccb1 100644
--- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
index 354fec9be7..121750a82a 100644
--- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc
index 5be70a4ded..1911e6ea36 100644
--- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc
index ec15b4cc7a..d962ef4a5f 100644
--- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc
index c810456f94..03a50ef8a0 100644
--- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/numeric.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/tensor_shape.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
index 56f237d588..ab094d7dd1 100644
--- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/op_kernel.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc
index 14709bb6cb..f1f32699fe 100644
--- a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
index e2ac7da2c2..b22ecb7c6d 100644
--- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
namespace tensorflow {
namespace {
diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc
index 5c010c9df2..6ce50efb4a 100644
--- a/tensorflow/compiler/tf2xla/kernels/select_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/kernels/bounds_check.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc
index 6281d6c653..a7f5a8f169 100644
--- a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
index 5798823cd5..4e0cf99d8e 100644
--- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/kernels/bounds_check.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc
index 1864584ade..6adc3c58de 100644
--- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
index 60c6a5d349..1d7a63dc31 100644
--- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc
index faaf8964ff..aaeeae01cc 100644
--- a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
namespace tensorflow {
namespace {
diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
index 8a8525efa1..7327258c31 100644
--- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
namespace tensorflow {
namespace {
diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc
index 47d282fe9e..4493539fe3 100644
--- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc
index 242638f981..93fc14e9ef 100644
--- a/tensorflow/compiler/tf2xla/kernels/split_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
index cc4b13d3b9..5412e13547 100644
--- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
@@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/lib/prng.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
index c2165ccd86..1062399d91 100644
--- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
index 26326f18b8..be1814d8e3 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/tf2xla/xla_resource.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
index c9e5694262..1233a37565 100644
--- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc
index 82d4a69777..183879c760 100644
--- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc
@@ -13,11 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/numeric.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/lib/sorting.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
@@ -47,31 +47,12 @@ class TopKOp : public XlaOpKernel {
context, last_dim_size >= k,
errors::InvalidArgument("input must have at least k columns. Had ",
last_dim_size, ", needed ", k));
-
- xla::XlaBuilder* const b = context->builder();
if (last_dim_size < k) {
k = last_dim_size;
}
- const xla::XlaOp input = context->Input(0);
-
- xla::XlaOp iota_s32 = xla::Iota(b, xla::S32, last_dim_size);
- auto input_dims = input_shape.dim_sizes();
- std::vector<int64> broadcast_dims(input_dims.begin(), input_dims.end() - 1);
- xla::XlaOp broadcast_s32 = xla::Broadcast(iota_s32, broadcast_dims);
- xla::XlaOp sort_result = xla::Sort(xla::Neg(input), broadcast_s32);
-
- std::vector<int64> start_indices(input_shape.dims(), 0);
- std::vector<int64> limit_indices(input_dims.begin(), input_dims.end());
- limit_indices[last_dim] = k;
- std::vector<int64> strides(input_shape.dims(), 1);
-
- xla::XlaOp values =
- xla::Neg(xla::Slice(xla::GetTupleElement(sort_result, 0), start_indices,
- limit_indices, strides));
- xla::XlaOp indices = xla::Slice(xla::GetTupleElement(sort_result, 1),
- start_indices, limit_indices, strides);
- context->SetOutput(0, values);
- context->SetOutput(1, indices);
+ xla::XlaOp output_tuple = TopK(context->Input(0), k);
+ context->SetOutput(0, xla::GetTupleElement(output_tuple, 0));
+ context->SetOutput(1, xla::GetTupleElement(output_tuple, 1));
}
private:
diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
index 98df730249..be5e911386 100644
--- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc
index 6c721c48fe..f9148b3942 100644
--- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/bounds_check.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
index e6ec794cfd..0bdfc05726 100644
--- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc
index f951127bb9..8671632976 100644
--- a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
index bb27b5d56f..2c92a585f5 100644
--- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc
index c653a11029..1e8a376765 100644
--- a/tensorflow/compiler/tf2xla/kernels/while_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/function.h"
diff --git a/tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.cc b/tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.cc
deleted file mode 100644
index 661505021f..0000000000
--- a/tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.cc
+++ /dev/null
@@ -1,63 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// Legacy flags for the XLA bridge's backend registration modules.
-
-#include <mutex> // NOLINT
-#include <vector>
-
-#include "tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.h"
-#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/util/command_line_flags.h"
-
-namespace tensorflow {
-namespace legacy_flags {
-
-// Pointers to the parsed value of the flags and flag descriptors, initialized
-// via flags_init.
-static BackendRegistrationFlags* flags;
-static std::vector<Flag>* flag_list;
-static std::once_flag flags_init;
-
-// Allocate *flags. Called via call_once(&flags_init,...).
-static void AllocateFlags() {
- flags = new BackendRegistrationFlags;
- flags->tf_enable_prng_ops_gpu = false;
- flag_list = new std::vector<Flag>({
- Flag("tf_enable_prng_ops_gpu", &flags->tf_enable_prng_ops_gpu,
- "Whether to enable PRNG ops: [RandomStandardNormal | RandomUniform "
- "| RandomUniformInt | TruncatedNormal] on GPU."),
- });
- xla::legacy_flags::ParseFlagsFromEnv(*flag_list);
-}
-
-// Append to *append_to flag definitions associated with the XLA bridge's
-// backend registration modules.
-void AppendBackendRegistrationFlags(std::vector<Flag>* append_to) {
- std::call_once(flags_init, &AllocateFlags);
- append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
-}
-
-// Return a pointer to the BackendRegistrationFlags struct;
-// repeated calls return the same pointer.
-// This should be called only after Flags::Parse() has returned.
-BackendRegistrationFlags* GetBackendRegistrationFlags() {
- std::call_once(flags_init, &AllocateFlags);
- return flags;
-}
-
-} // namespace legacy_flags
-} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.h b/tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.h
deleted file mode 100644
index 861c923dd5..0000000000
--- a/tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.h
+++ /dev/null
@@ -1,49 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_TF2XLA_LEGACY_FLAGS_BACKEND_REGISTRATION_FLAGS_H_
-#define TENSORFLOW_COMPILER_TF2XLA_LEGACY_FLAGS_BACKEND_REGISTRATION_FLAGS_H_
-
-// Legacy flags for the XLA bridge's backend registration modules.
-
-#include <vector>
-
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/util/command_line_flags.h"
-
-namespace tensorflow {
-namespace legacy_flags {
-
-// Append to *flag_list flag definitions associated with the XLA bridge's
-// backend registration modules.
-void AppendBackendRegistrationFlags(std::vector<tensorflow::Flag>* append_to);
-
-// The values of flags associated with the XLA bridge's backend registration
-// module.
-typedef struct {
- // Whether to enable RandomUniform op on GPU backend.
- // TODO (b/32333178): Remove this flag or set its default to true.
- bool tf_enable_prng_ops_gpu;
-} BackendRegistrationFlags;
-
-// Return a pointer to the BackendRegistrationFlags struct;
-// repeated calls return the same pointer.
-// This should be called only after Flags::Parse() has returned.
-BackendRegistrationFlags* GetBackendRegistrationFlags();
-
-} // namespace legacy_flags
-} // namespace tensorflow
-
-#endif // TENSORFLOW_COMPILER_TF2XLA_LEGACY_FLAGS_BACKEND_REGISTRATION_FLAGS_H_
diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD
index e35a457f09..cb7a40e23d 100644
--- a/tensorflow/compiler/tf2xla/lib/BUILD
+++ b/tensorflow/compiler/tf2xla/lib/BUILD
@@ -25,8 +25,8 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:lib",
],
)
@@ -44,9 +44,9 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:constants",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:lib",
],
)
@@ -59,9 +59,9 @@ cc_library(
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:math",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:protos_all_cc",
],
)
@@ -78,12 +78,12 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/lib:numeric",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:lib",
],
)
@@ -100,9 +100,9 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:lib",
],
)
@@ -119,10 +119,10 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:numeric",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:lib",
],
)
@@ -142,7 +142,7 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -162,8 +162,8 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:lib",
],
)
@@ -200,8 +200,8 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:lib",
],
)
diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
index 3c4eec081b..f666d22ea4 100644
--- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc
+++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <memory>
#include <vector>
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h
index dbba5eaf26..8757b16a1c 100644
--- a/tensorflow/compiler/tf2xla/lib/batch_dot.h
+++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h
@@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_
#define TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc
index 35b137aa2c..87d73eb3f0 100644
--- a/tensorflow/compiler/tf2xla/lib/cholesky.cc
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h
index bc1b0ed82f..1bef9bb166 100644
--- a/tensorflow/compiler/tf2xla/lib/cholesky.h
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.h
@@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_
#define TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/tf2xla/lib/qr.cc b/tensorflow/compiler/tf2xla/lib/qr.cc
index 9c8ac7af25..fc0c1ee838 100644
--- a/tensorflow/compiler/tf2xla/lib/qr.cc
+++ b/tensorflow/compiler/tf2xla/lib/qr.cc
@@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/numeric.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
diff --git a/tensorflow/compiler/tf2xla/lib/qr.h b/tensorflow/compiler/tf2xla/lib/qr.h
index 3aa6a9b075..abd2316ac9 100644
--- a/tensorflow/compiler/tf2xla/lib/qr.h
+++ b/tensorflow/compiler/tf2xla/lib/qr.h
@@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_
#define TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/tf2xla/lib/random.cc b/tensorflow/compiler/tf2xla/lib/random.cc
index 8ff10fbd3f..5e7cf00ee5 100644
--- a/tensorflow/compiler/tf2xla/lib/random.cc
+++ b/tensorflow/compiler/tf2xla/lib/random.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/status_macros.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/tf2xla/lib/random.h b/tensorflow/compiler/tf2xla/lib/random.h
index 2c573fd85b..59fc5d0433 100644
--- a/tensorflow/compiler/tf2xla/lib/random.h
+++ b/tensorflow/compiler/tf2xla/lib/random.h
@@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_RANDOM_H_
#define TENSORFLOW_COMPILER_TF2XLA_LIB_RANDOM_H_
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/types.pb.h"
diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc
index 739032fef7..ba22eff73a 100644
--- a/tensorflow/compiler/tf2xla/lib/scatter.cc
+++ b/tensorflow/compiler/tf2xla/lib/scatter.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
diff --git a/tensorflow/compiler/tf2xla/lib/scatter.h b/tensorflow/compiler/tf2xla/lib/scatter.h
index 452fda565d..13a5f1b850 100644
--- a/tensorflow/compiler/tf2xla/lib/scatter.h
+++ b/tensorflow/compiler/tf2xla/lib/scatter.h
@@ -18,7 +18,7 @@ limitations under the License.
#include <functional>
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/statusor.h"
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
index 05dad759df..04fa10108c 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/numeric.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h
index 9c4314e275..555760b7ef 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h
@@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_
#define TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc
index a29496dec4..aeebf16028 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/array2d.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc
index a6f5d346cb..8b5beba383 100644
--- a/tensorflow/compiler/tf2xla/lib/util.cc
+++ b/tensorflow/compiler/tf2xla/lib/util.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <memory>
#include <vector>
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h
index a139873d32..b4905c9528 100644
--- a/tensorflow/compiler/tf2xla/lib/util.h
+++ b/tensorflow/compiler/tf2xla/lib/util.h
@@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_
#define TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/tf2xla/lib/while_loop.cc
index 574e70ddee..d64394f140 100644
--- a/tensorflow/compiler/tf2xla/lib/while_loop.cc
+++ b/tensorflow/compiler/tf2xla/lib/while_loop.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.h b/tensorflow/compiler/tf2xla/lib/while_loop.h
index 69cc70bfaf..9493b1f109 100644
--- a/tensorflow/compiler/tf2xla/lib/while_loop.h
+++ b/tensorflow/compiler/tf2xla/lib/while_loop.h
@@ -19,7 +19,7 @@ limitations under the License.
#include <functional>
#include <vector>
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/stringpiece.h"
diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc
index fe7ec633ec..e89f473328 100644
--- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc
+++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/platform/mem.h"
diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.h b/tensorflow/compiler/tf2xla/xla_compilation_device.h
index d0b9e34e16..a6e7882533 100644
--- a/tensorflow/compiler/tf2xla/xla_compilation_device.h
+++ b/tensorflow/compiler/tf2xla/xla_compilation_device.h
@@ -19,7 +19,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/compiler/tf2xla/xla_resource.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/device_base.h"
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 678e209cf6..226c89bcf1 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -28,13 +28,14 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
@@ -689,12 +690,12 @@ Status ValidateFunctionDef(const FunctionDef* fdef,
Status ValidateGraph(const Graph* graph,
const FunctionLibraryDefinition& flib_def,
const DeviceType& device_type, const string& name) {
- auto maybe_error = [&](const string& op, const Status& s) -> Status {
+ auto maybe_error = [&](const Node* node, const Status& s) -> Status {
if (!s.ok()) {
return errors::InvalidArgument(strings::StrCat(
"Detected unsupported operations when trying to compile graph ", name,
- " on ", device_type.type_string(), ": ", op, " (", s.error_message(),
- ")"));
+ " on ", device_type.type_string(), ": ", node->def().op(), " (",
+ s.error_message(), ")", FormatNodeForError(*node)));
}
return Status::OK();
};
@@ -707,15 +708,15 @@ Status ValidateGraph(const Graph* graph,
Status s;
if (fdef) {
s = ValidateFunctionDef(fdef, flib_def);
- TF_RETURN_IF_ERROR(maybe_error(node->def().op(), s));
+ TF_RETURN_IF_ERROR(maybe_error(node, s));
continue;
}
const OpDef* op_def;
s = OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def);
- TF_RETURN_IF_ERROR(maybe_error(node->def().op(), s));
+ TF_RETURN_IF_ERROR(maybe_error(node, s));
TF_RETURN_IF_ERROR(ValidateNodeDef(node->def(), *op_def));
s = FindKernelDef(device_type, node->def(), nullptr, nullptr);
- TF_RETURN_IF_ERROR(maybe_error(node->def().op(), s));
+ TF_RETURN_IF_ERROR(maybe_error(node, s));
}
return Status::OK();
}
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index acc64d99d3..25332c8d8e 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -252,6 +252,12 @@ class XlaCompiler {
// The default empty value is invalid.
DeviceType device_type = DeviceType("");
+ // The device to use during compilation to execute instructions on, for
+ // example for auto-tuning.
+ // Valid values are defined by `xla::Backend::devices_ordinal_supported()`.
+ // -1 indicates the default device should be used.
+ int device_ordinal = -1;
+
xla::Client* client = nullptr;
// Function library in which to find function definitions. Must be non-null.
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index 2fb93be01d..be00ed8813 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -312,7 +312,7 @@ TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) {
str_util::StrContains(status.error_message(), "depends on a parameter"))
<< status.error_message();
EXPECT_TRUE(
- str_util::StrContains(status.error_message(), "[[Node: C = Reshape"))
+ str_util::StrContains(status.error_message(), "[[{{node C}} = Reshape"))
<< status.error_message();
}
@@ -1077,6 +1077,8 @@ TEST_F(XlaCompilerTest, FunctionWithInvalidOp) {
ASSERT_FALSE(status.ok());
EXPECT_TRUE(str_util::StrContains(status.error_message(), "InvalidOp"))
<< status.error_message();
+ EXPECT_TRUE(str_util::StrContains(status.error_message(), "{{node fill_fn}}"))
+ << status.error_message();
}
// Tests a graph which has a node with invalid data type.
@@ -1101,6 +1103,8 @@ TEST_F(XlaCompilerTest, NodeWithInvalidDataType) {
EXPECT_TRUE(str_util::StrContains(status.error_message(),
"is not in the list of allowed values"))
<< status.error_message();
+ EXPECT_TRUE(str_util::StrContains(status.error_message(), "{{node Shape}}"))
+ << status.error_message();
}
TEST_F(XlaCompilerTest, SingleOpWithoutInputs) {
@@ -1122,9 +1126,10 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) {
status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp",
std::move(graph_copy), args, &result);
ASSERT_FALSE(status.ok());
- EXPECT_TRUE(str_util::StrContains(status.error_message(),
- "The following nodes are unreachable "
- "from the source in the graph: NoOp"))
+ EXPECT_TRUE(
+ str_util::StrContains(status.error_message(),
+ "The following nodes are unreachable "
+ "from the source in the graph: {{node NoOp}}"))
<< status.error_message();
}
diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc
index 2836cb3df3..b24e3aabbe 100644
--- a/tensorflow/compiler/tf2xla/xla_context.cc
+++ b/tensorflow/compiler/tf2xla/xla_context.cc
@@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h
index beee7d48e8..3db37afdba 100644
--- a/tensorflow/compiler/tf2xla/xla_context.h
+++ b/tensorflow/compiler/tf2xla/xla_context.h
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
diff --git a/tensorflow/compiler/tf2xla/xla_gpu_backend.cc b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc
index dc98d4fda6..1398e9ee53 100644
--- a/tensorflow/compiler/tf2xla/xla_gpu_backend.cc
+++ b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
@@ -21,20 +20,6 @@ limitations under the License.
namespace tensorflow {
bool GpuOpFilter(KernelDef* kdef) {
- // TODO(b/31361304): The GPU backend does not parallelize PRNG ops, leading to
- // slow code.
- legacy_flags::BackendRegistrationFlags* flags =
- legacy_flags::GetBackendRegistrationFlags();
- VLOG(2) << "flags->tf_enable_prng_ops_gpu: " << flags->tf_enable_prng_ops_gpu;
- if (!flags->tf_enable_prng_ops_gpu &&
- (kdef->op() == "RandomStandardNormal" || kdef->op() == "RandomUniform" ||
- kdef->op() == "RandomUniformInt" || kdef->op() == "TruncatedNormal")) {
- return false;
- }
- // TODO(b/26783907): The GPU backend currently does not implement sort.
- if (kdef->op() == "XlaSort" || kdef->op() == "TopKV2") {
- return false;
- }
if (kdef->op() == "Const") {
AddDtypeToKernalDefConstraint("dtype", DT_STRING, kdef);
}
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc
index 225da16807..8efb3d55c8 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.cc
+++ b/tensorflow/compiler/tf2xla/xla_helpers.cc
@@ -26,7 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/numeric.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h
index d6ca4ab934..e6522157a5 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.h
+++ b/tensorflow/compiler/tf2xla/xla_helpers.h
@@ -19,7 +19,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_
#include "tensorflow/compiler/tf2xla/xla_context.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 38ec559576..82028c8b9c 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index 71990b57d9..ac9dfe3369 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -17,7 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc
index baea814965..7928fa0347 100644
--- a/tensorflow/compiler/tf2xla/xla_resource.cc
+++ b/tensorflow/compiler/tf2xla/xla_resource.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h
index 4de18a7788..2438490be1 100644
--- a/tensorflow/compiler/tf2xla/xla_resource.h
+++ b/tensorflow/compiler/tf2xla/xla_resource.h
@@ -18,7 +18,7 @@ limitations under the License.
#include <memory>
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index f1c383fd9e..fdf13bb18c 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -636,7 +636,7 @@ cc_library(
":window_util",
":xla_data_proto",
"//tensorflow/compiler/xla/client:padding",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_evaluator",
"//tensorflow/compiler/xla/service:shape_inference",
diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD
index 289d3f552a..ad3fcee05b 100644
--- a/tensorflow/compiler/xla/client/BUILD
+++ b/tensorflow/compiler/xla/client/BUILD
@@ -114,6 +114,7 @@ cc_library(
"//tensorflow/compiler/xla/service:local_service",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xla/service:source_map_util",
+ "//tensorflow/compiler/xla/service:stream_pool",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"@llvm//:support",
@@ -187,3 +188,47 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_proto",
],
)
+
+cc_library(
+ name = "xla_builder",
+ srcs = ["xla_builder.cc"],
+ hdrs = ["xla_builder.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":padding",
+ ":sharding_builder",
+ ":xla_computation",
+ "//tensorflow/compiler/xla:execution_options_util",
+ "//tensorflow/compiler/xla:literal",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_proto",
+ "//tensorflow/compiler/xla/service:shape_inference",
+ "//tensorflow/core:lib",
+ ],
+)
+
+tf_cc_test(
+ name = "xla_builder_test",
+ srcs = ["xla_builder_test.cc"],
+ deps = [
+ ":xla_builder",
+ ":xla_computation",
+ "//tensorflow/compiler/xla:literal",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_matchers",
+ "//tensorflow/core:test",
+ ],
+)
diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD
index 45506986c8..39d5582d19 100644
--- a/tensorflow/compiler/xla/client/lib/BUILD
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -29,8 +29,8 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:lib",
],
)
@@ -45,7 +45,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_builder",
],
)
@@ -58,7 +58,7 @@ xla_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
@@ -72,7 +72,7 @@ cc_library(
":constants",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_builder",
],
)
@@ -86,7 +86,7 @@ xla_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
@@ -101,7 +101,7 @@ cc_library(
":constants",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/core:lib",
],
)
@@ -115,7 +115,7 @@ xla_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
@@ -131,12 +131,43 @@ cc_library(
":numeric",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/core:lib",
],
)
cc_library(
+ name = "sorting",
+ srcs = ["sorting.cc"],
+ hdrs = ["sorting.h"],
+ deps = [
+ ":numeric",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ ],
+)
+
+xla_test(
+ name = "sorting_test",
+ srcs = ["sorting_test.cc"],
+ blacklisted_backends = [
+ "cpu",
+ "gpu",
+ ],
+ tags = ["enable_for_xla_interpreter"],
+ deps = [
+ ":sorting",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ ],
+)
+
+cc_library(
name = "testing",
srcs = ["testing.cc"],
hdrs = ["testing.h"],
@@ -150,8 +181,8 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client",
"//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
],
diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc
index 1872925aba..9225b1acd6 100644
--- a/tensorflow/compiler/xla/client/lib/arithmetic.cc
+++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <string>
#include "tensorflow/compiler/xla/client/lib/constants.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h
index 80d3f8b95a..632e8cc8bc 100644
--- a/tensorflow/compiler/xla/client/lib/arithmetic.h
+++ b/tensorflow/compiler/xla/client/lib/arithmetic.h
@@ -18,7 +18,7 @@ limitations under the License.
#include <memory>
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
diff --git a/tensorflow/compiler/xla/client/lib/constants.h b/tensorflow/compiler/xla/client/lib/constants.h
index b47f5243f0..0c8a9b8cc0 100644
--- a/tensorflow/compiler/xla/client/lib/constants.h
+++ b/tensorflow/compiler/xla/client/lib/constants.h
@@ -18,7 +18,7 @@ limitations under the License.
#include <type_traits>
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
diff --git a/tensorflow/compiler/xla/client/lib/constants_test.cc b/tensorflow/compiler/xla/client/lib/constants_test.cc
index f1e3439862..f4320f65c1 100644
--- a/tensorflow/compiler/xla/client/lib/constants_test.cc
+++ b/tensorflow/compiler/xla/client/lib/constants_test.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/lib/constants.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h
index d003d529cc..13db232556 100644
--- a/tensorflow/compiler/xla/client/lib/math.h
+++ b/tensorflow/compiler/xla/client/lib/math.h
@@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc
index 1df287d7db..14c259a7fa 100644
--- a/tensorflow/compiler/xla/client/lib/math_test.cc
+++ b/tensorflow/compiler/xla/client/lib/math_test.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/lib/math.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
diff --git a/tensorflow/compiler/xla/client/lib/numeric.h b/tensorflow/compiler/xla/client/lib/numeric.h
index 212f658313..efd8cdc257 100644
--- a/tensorflow/compiler/xla/client/lib/numeric.h
+++ b/tensorflow/compiler/xla/client/lib/numeric.h
@@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
diff --git a/tensorflow/compiler/xla/client/lib/numeric_test.cc b/tensorflow/compiler/xla/client/lib/numeric_test.cc
index f56cadc547..8a96ec68d2 100644
--- a/tensorflow/compiler/xla/client/lib/numeric_test.cc
+++ b/tensorflow/compiler/xla/client/lib/numeric_test.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/lib/numeric.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
diff --git a/tensorflow/compiler/xla/client/lib/prng.cc b/tensorflow/compiler/xla/client/lib/prng.cc
index 299a6ac2b6..3a744148fb 100644
--- a/tensorflow/compiler/xla/client/lib/prng.cc
+++ b/tensorflow/compiler/xla/client/lib/prng.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/numeric.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/casts.h"
diff --git a/tensorflow/compiler/xla/client/lib/prng.h b/tensorflow/compiler/xla/client/lib/prng.h
index ac86390239..ad000b1fa1 100644
--- a/tensorflow/compiler/xla/client/lib/prng.h
+++ b/tensorflow/compiler/xla/client/lib/prng.h
@@ -18,7 +18,7 @@ limitations under the License.
#include <array>
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/client/lib/sorting.cc b/tensorflow/compiler/xla/client/lib/sorting.cc
new file mode 100644
index 0000000000..a904be259a
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/sorting.cc
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/lib/sorting.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
+
+namespace xla {
+
+XlaOp TopK(XlaOp input, int64 k) {
+ XlaBuilder* const builder = input.builder();
+ return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
+ int last_dim = input_shape.dimensions_size() - 1;
+ int last_dim_size = input_shape.dimensions(last_dim);
+
+ XlaOp iota_s32 = Iota(builder, S32, last_dim_size);
+ auto input_dims = input_shape.dimensions();
+ std::vector<int64> broadcast_dims(input_dims.begin(), input_dims.end() - 1);
+ XlaOp broadcast_s32 = Broadcast(iota_s32, broadcast_dims);
+ XlaOp sort_result = Sort(Neg(input), broadcast_s32);
+ std::vector<int64> start_indices(input_shape.dimensions_size(), 0);
+ std::vector<int64> limit_indices(input_dims.begin(), input_dims.end());
+ limit_indices[last_dim] = k;
+ std::vector<int64> strides(input_shape.dimensions_size(), 1);
+
+ XlaOp values = Neg(Slice(GetTupleElement(sort_result, 0), start_indices,
+ limit_indices, strides));
+ XlaOp indices = Slice(GetTupleElement(sort_result, 1), start_indices,
+ limit_indices, strides);
+ return Tuple(builder, {values, indices});
+ });
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/pool_test.cc b/tensorflow/compiler/xla/client/lib/sorting.h
index 8c4fe258e3..404b4783c3 100644
--- a/tensorflow/compiler/xla/service/pool_test.cc
+++ b/tensorflow/compiler/xla/client/lib/sorting.h
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -13,28 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/service/pool.h"
+#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SORTING_H_
+#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SORTING_H_
-#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
-namespace {
-using PoolTest = ::testing::Test;
+// Returns a tuple composed of the top `k` values and corresponding indices in
+// `input`. Output values are in descending order, from largest to smallest.
+XlaOp TopK(XlaOp input, int64 k);
-TEST_F(PoolTest, Test) {
- Pool<int> pool;
-
- {
- auto ptr = pool.Allocate();
- EXPECT_NE(nullptr, ptr.get());
- *ptr = 5;
- }
-
- auto ptr = pool.Allocate();
- EXPECT_NE(nullptr, ptr.get());
- EXPECT_EQ(5, *ptr);
-}
-
-} // namespace
} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SORTING_H_
diff --git a/tensorflow/compiler/xla/client/lib/sorting_test.cc b/tensorflow/compiler/xla/client/lib/sorting_test.cc
new file mode 100644
index 0000000000..b6eee762a5
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/sorting_test.cc
@@ -0,0 +1,60 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/lib/sorting.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+
+namespace xla {
+namespace {
+
+using SortingTest = ClientLibraryTestBase;
+
+XLA_TEST_F(SortingTest, TopK3From8Values) {
+ XlaBuilder builder(TestName());
+ auto x =
+ ConstantR1<float>(&builder, {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0});
+ xla::GetTupleElement(xla::TopK(x, 3), 0);
+ ComputeAndCompareR1<float>(&builder, {7.0, 6.0, 5.0}, {});
+}
+
+XLA_TEST_F(SortingTest, TopK3From8Indices) {
+ XlaBuilder builder(TestName());
+ auto x_rev =
+ ConstantR1<float>(&builder, {7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0});
+ xla::GetTupleElement(xla::TopK(x_rev, 3), 1);
+ ComputeAndCompareR1<int>(&builder, {0, 1, 2}, {});
+}
+
+XLA_TEST_F(SortingTest, TopKFullSort) {
+ XlaBuilder builder(TestName());
+ const int kSize = 16;
+ std::mt19937 eng;
+ std::uniform_real_distribution<float> u_dist(0.0, 100.0);
+ auto gen = std::bind(u_dist, eng);
+ std::vector<float> inputs(kSize);
+ std::generate(inputs.begin(), inputs.end(), gen);
+ auto x = ConstantR1<float>(&builder, inputs);
+ xla::GetTupleElement(xla::TopK(x, kSize), 0);
+
+ std::sort(inputs.begin(), inputs.end(), std::greater<float>());
+ ComputeAndCompareR1<float>(&builder, inputs, {});
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc
index 2de65016dd..b1a776b8b8 100644
--- a/tensorflow/compiler/xla/client/lib/testing.cc
+++ b/tensorflow/compiler/xla/client/lib/testing.cc
@@ -15,8 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/testing.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index 035ee9bf4c..e7250e11d5 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/service_executable_run_options.h"
#include "tensorflow/compiler/xla/service/source_map_util.h"
+#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/compiler/xla/status_macros.h"
using xla::source_map_util::InvalidParameterArgument;
@@ -30,8 +31,8 @@ using xla::source_map_util::InvalidParameterArgument;
namespace xla {
namespace {
-StatusOr<Backend::StreamPtr> BorrowStreamForDevice(int device_ordinal,
- Backend* backend) {
+StatusOr<StreamPool::Ptr> BorrowStreamForDevice(int device_ordinal,
+ Backend* backend) {
if (device_ordinal < 0) {
device_ordinal = backend->default_device_ordinal();
}
@@ -142,7 +143,7 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::Run(
TF_RETURN_IF_ERROR(
ValidateExecutionOptions(arguments, run_options, *backend_));
- Backend::StreamPtr stream;
+ StreamPool::Ptr stream;
if (run_options.stream() == nullptr) {
// NB! The lifetime of `stream` needs to match the lifetime of
// `actual_options` (otherwise we will end up using a returned stream in
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 152335e22a..53be5a79c2 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include <functional>
#include <numeric>
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
new file mode 100644
index 0000000000..ae331407d6
--- /dev/null
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -0,0 +1,2241 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_
+#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_
+
+#include <map>
+#include <string>
+#include <type_traits>
+#include <utility>
+
+#include "tensorflow/compiler/xla/client/padding.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/stacktrace.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+class XlaBuilder;
+
+// This represents an instruction that has been enqueued using the XlaBuilder.
+// This is used to pass to subsequent computations that depends upon the
+// instruction as an operand.
+class XlaOp {
+ public:
+ XlaOp() : handle_(-1), builder_(nullptr) {
+ static_assert(std::is_trivially_destructible<XlaOp>::value,
+ "XlaOp should be trivially destructible");
+ }
+ ~XlaOp() = default;
+
+ // Precondition: !IsUninitialized().
+ //
+ // It's very common to do foo.builder()->bar(). Without this precondition, if
+ // foo.builder() is null, the call to bar will segfault at some point possibly
+ // deep in the callstack when we finally dereference `this`. The precondition
+ // lets us avoid this tricky-to-debug problem.
+ XlaBuilder* builder() const {
+ CHECK(builder_ != nullptr);
+ return builder_;
+ }
+
+ // Returns true if the XlaOp represents valid, non-erroneous value.
+ bool valid() const { return handle_ >= 0; }
+
+ // Returns true if the XlaOp was created by the XlaOp() constructor and
+ // not returned by a builder.
+ bool IsUninitialized() const { return builder_ == nullptr; }
+
+ bool IsIdenticalTo(const XlaOp& rhs) const {
+ return handle_ == rhs.handle_ && builder_ == rhs.builder_;
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const XlaOp& op) {
+ out << op.handle();
+ return out;
+ }
+
+ private:
+ explicit XlaOp(XlaBuilder* builder) : handle_(-1), builder_(builder) {}
+ XlaOp(int64 handle, XlaBuilder* builder)
+ : handle_(handle), builder_(builder) {}
+
+ int64 handle() const { return handle_; }
+
+ friend class XlaBuilder;
+
+ // < 0 means "invalid handle".
+ int64 handle_;
+
+ // Not owned. Non-null for any handle returned by XlaBuilder, even if the
+ // handle is invalid.
+ XlaBuilder* builder_;
+};
+
+// Arithmetic operator overloads for the XlaOp type.
+XlaOp operator-(const XlaOp& x);
+XlaOp operator+(const XlaOp& x, const XlaOp& y);
+XlaOp operator-(const XlaOp& x, const XlaOp& y);
+XlaOp operator*(const XlaOp& x, const XlaOp& y);
+XlaOp operator/(const XlaOp& x, const XlaOp& y);
+XlaOp operator%(const XlaOp& x, const XlaOp& y);
+
+// Bitwise operator overloads for the XlaOp type.
+XlaOp operator~(const XlaOp& x);
+XlaOp operator&(const XlaOp& x, const XlaOp& y);
+XlaOp operator|(const XlaOp& x, const XlaOp& y);
+XlaOp operator^(const XlaOp& x, const XlaOp& y);
+XlaOp operator<<(const XlaOp& x, const XlaOp& y);
+// Performs a right arithmetic shift if 'x' is a signed type, otherwise performs
+// a right logical shift.
+XlaOp operator>>(const XlaOp& x, const XlaOp& y);
+
+// We don't overload the relational operators (==, !=, <, <=, >, >=) because the
+// semantics might be surprising since their result types are usually 'bool'.
+// Further programmers may expect == to be a structural equality.
+// We also choose not to overload any of the mutating operators (e.g., +=, -=)
+// because the semantics might be misleading — XLA computations are immutable.
+
+// A convenient interface for building up computations.
+//
+// Thread-compatible.
+class XlaBuilder {
+ public:
+ // computation_name: name to use for the built computation.
+ XlaBuilder(const string& computation_name);
+
+ XlaBuilder(const XlaBuilder&) = delete;
+ XlaBuilder& operator=(const XlaBuilder&) = delete;
+
+ ~XlaBuilder();
+
+ // Returns the computation name.
+ const string& name() const { return name_; }
+
+ // Sets OpMetadata that will be added to all instructions until cleared.
+ //
+ // OpMetadata is often applied to a series of XLA HLO instructions. As a
+ // result, OpMetadata is set on the Computation Builder. All subsequent
+ // instructions generated via this Computation Builder will have the same
+ // OpMetadata attached until a call to ClearOpMetadata.
+ void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; }
+
+ // Clears the HloMetadata state.
+ void ClearOpMetadata() { metadata_.Clear(); }
+
+ // Sets an OpSharding that will be attached to all instructions until cleared.
+ void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }
+
+ // Clears the sharding. Ops will be sharded according to the default placement
+ // policy.
+ void ClearSharding() { sharding_ = tensorflow::gtl::nullopt; }
+
+ // Returns the OpSharding that will be attached to all instructions.
+ const tensorflow::gtl::optional<OpSharding>& sharding() const {
+ return sharding_;
+ }
+
+ // Sets the builder to a mode where it will die immediately when an error is
+ // encountered, rather than producing it in a deferred fashion when Build() is
+ // called (which is the default).
+ void set_die_immediately_on_error(bool enabled) {
+ die_immediately_on_error_ = enabled;
+ }
+
+ // Default dimension numbers used for a 2D convolution.
+ static constexpr int64 kConvBatchDimension = 0;
+ static constexpr int64 kConvFeatureDimension = 1;
+ static constexpr int64 kConvFirstSpatialDimension = 2;
+ static constexpr int64 kConvSecondSpatialDimension = 3;
+ static constexpr int64 kConvKernelOutputDimension = 0;
+ static constexpr int64 kConvKernelInputDimension = 1;
+ static constexpr int64 kConvKernelFirstSpatialDimension = 2;
+ static constexpr int64 kConvKernelSecondSpatialDimension = 3;
+
+ // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for
+ // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for
+ // the kernel operand
+ // {output_feature, input_feature, height, width} = {0, 1, 2, 3}.
+ static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers(
+ int num_spatial_dims = 2);
+
+ // Returns an error if the convolution dimension numbers have conflicts.
+ static Status Validate(const ConvolutionDimensionNumbers& dnum);
+
+ // Returns a new XlaBuilder whose resultant Computation is used only by this
+ // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error
+ // behavior as the parent.
+ std::unique_ptr<XlaBuilder> CreateSubBuilder(const string& computation_name);
+
+ // Builds the computation with the requested operations, or returns a non-ok
+ // status. Note that all ops that have been enqueued will be moved to the
+ // computation being returned.
+ StatusOr<XlaComputation> Build();
+
+ // Builds the computation with the requested operations, or notes an error in
+ // the parent XlaBuilder and returns an empty computation if building failed.
+ // This function is intended to be used where the returned XlaComputation is
+ // only used by the parent XlaBuilder and hence further operation on the
+ // returned XlaComputation will simply be error'ed out if an error occurred
+ // while building this computation. If the built computation is to be used by
+ // a XlaBuilder other than the parent XlaBuilder then Build() should be used
+ // instead.
+ XlaComputation BuildAndNoteError();
+
+ // Returns a subgraph that roots on the given root. If the root is not a
+ // compile-time constant (see `IsConstant`), returns an error.
+ //
+ // This will copy the needed ops/computations to the subgraph.
+ StatusOr<XlaComputation> BuildConstantSubGraph(const XlaOp& root_op) const;
+
+ // Returns the first error that was encountered while building the
+ // computation. When an error is encountered, by default we return a vacuous
+ // XlaOp and inform the user of the error that occurred while
+ // building the computation when they make a final call to Build().
+ //
+ // See also set_die_immediately_on_error().
+ Status first_error() const { return first_error_; }
+
+ // Returns the shape of the given op.
+ StatusOr<Shape> GetShape(const XlaOp& op) const;
+
+ // Returns the (inferred) result for the current computation's shape.
+ StatusOr<ProgramShape> GetProgramShape() const;
+
+ // Reports an error to the builder, by
+ // * storing it internally and capturing a backtrace if it's the first error
+ // (this deferred value will be produced on the call to
+ // Build()/GetShape()/...)
+ // * dying if die_immediately_on_error_ is true.
+ // Returns an XlaOp with an invalid handle but a valid builder. This value can
+ // be returned in place of a value in APIs that return an XlaOp.
+ XlaOp ReportError(const Status& error);
+
+ // A helper function that converts a StatusOr<XlaOp> into an XlaOp.
+ // If the Status was an error, reports the error to builder and returns an
+ // invalid XlaOp handle.
+ XlaOp ReportErrorOrReturn(const StatusOr<XlaOp>& op);
+
+ // A helper function that runs a function that returns a StatusOr<XlaOp> and
+ // returns an XlaOp.
+ XlaOp ReportErrorOrReturn(const std::function<StatusOr<XlaOp>()>& op_creator);
+
+ // Returns true if 'operand' is a compile-time constant. A compile-time
+ // constant does not depend on any parameters, or on stateful operators such
+ // as `RngNormal` or `Infeed`.
+ //
+ // This tests whether a computation is a compile-time constant without
+ // evaluating the computation.
+ StatusOr<bool> IsConstant(const XlaOp& operand) const;
+
+ private:
+ // Enqueues a "retrieve parameter value" instruction for a parameter that was
+ // passed to the computation.
+ XlaOp Parameter(int64 parameter_number, const Shape& shape,
+ const string& name);
+
+ // Enqueues a constant with the value of the given literal onto the
+ // computation.
+ XlaOp ConstantLiteral(const LiteralSlice& literal);
+
+ // Enqueues a constant onto the computation. Methods are templated on the
+ // native host type (NativeT) which corresponds to a specific XLA
+ // PrimitiveType as given in the following table:
+ //
+ // Native Type PrimitiveType
+ // -----------------------------
+ // bool PRED
+ // int32 S32
+ // int64 S64
+ // uint32 U32
+ // uint64 U64
+ // float F32
+ // double F64
+ //
+ // Note: not all primitive types defined in xla_data.proto have a
+ // corresponding native type yet.
+ template <typename NativeT>
+ XlaOp ConstantR0(NativeT value);
+ template <typename NativeT>
+ XlaOp ConstantR1(tensorflow::gtl::ArraySlice<NativeT> values);
+ XlaOp ConstantR1(const tensorflow::core::Bitmap& values);
+ template <typename NativeT>
+ XlaOp ConstantR2(
+ std::initializer_list<std::initializer_list<NativeT>> values);
+ template <typename NativeT>
+ XlaOp ConstantFromArrayWithLayout(const Array<NativeT>& values,
+ const Layout& layout);
+ template <typename NativeT>
+ XlaOp ConstantFromArray(const Array<NativeT>& values);
+ template <typename NativeT>
+ XlaOp ConstantR2FromArray2DWithLayout(const Array2D<NativeT>& values,
+ const Layout& layout);
+ template <typename NativeT>
+ XlaOp ConstantR2FromArray2D(const Array2D<NativeT>& values);
+ template <typename NativeT>
+ XlaOp ConstantR3FromArray3DWithLayout(const Array3D<NativeT>& values,
+ const Layout& layout);
+ template <typename NativeT>
+ XlaOp ConstantR3FromArray3D(const Array3D<NativeT>& values);
+ template <typename NativeT>
+ XlaOp ConstantR4FromArray4DWithLayout(const Array4D<NativeT>& values,
+ const Layout& layout);
+ template <typename NativeT>
+ XlaOp ConstantR4FromArray4D(const Array4D<NativeT>& values);
+
+ // Enqueues a rank one constant (vector) onto the computation. The vector has
+ // size 'length' and every element has the value 'value'.
+ template <typename NativeT>
+ XlaOp ConstantR1(int64 length, NativeT value);
+
+ // Adds dimensions to an array by duplicating the data in the array.
+ //
+ // The new dimensions are inserted on the left, i.e. if
+ // broadcast_sizes has values {a0, ..., aN} and the operand shape
+ // has dimensions {b0, ..., bM} then the shape of the output has
+ // dimensions {a0, ..., aN, b0, ..., bM}.
+ //
+ // The new dimensions index into copies of the operand, i.e.
+ //
+ // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
+ XlaOp Broadcast(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
+
+ // Performs in-dimension-style broadcast.
+ //
+ // Operand specifies the input to be broadcast. "shape" is expected output
+ // shape. "broadcast_dimensions" are the dimensions to be broadcasting into.
+ // Dimension numbers in broadcast_dimensions map to individual dimensions
+ // of the operand, and specify what dimension of the output shape they
+ // should be broadcast.
+ // e.g.
+ // Say operand = [1, 2], i.e., a 1D tensor with 2 elements.
+ // and dimension of shape is [2,2].
+ // Specifying {1} as brodcast_dimension will generate output
+ // [1 , 2]
+ // [1 , 2]
+ // On the other hand, specifying {0} as broadcast_dimension
+ // will generate output
+ // [1 , 1]
+ // [2 , 2]
+ XlaOp BroadcastInDim(
+ const XlaOp& operand, const Shape& shape,
+ const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+
+ // Enqueues a pad operation onto the computation that pads the given value on
+ // the edges as well as between the elements of the input. padding_config
+ // specifies the padding amount for each dimension.
+ XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
+ const PaddingConfig& padding_config);
+
+ // Enqueues an operation onto the computation that flattens the operand based
+ // on the dimension order (major/slowest-varying to minor/fastest-varying)
+ // given, followed by reshaping it into the shape with the given dimension
+ // sizes (also major to minor). Conceptually, this is a limited form of
+ // "shape casting".
+ XlaOp Reshape(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::gtl::ArraySlice<int64> new_sizes);
+
+ // Enqueues an operation onto the computation that collapses the operand, from
+ // first to last dimension (C order), then reshapes it to the given dimension
+ // sizes. Conceptually, this is a limited form of "shape casting".
+ XlaOp Reshape(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> new_sizes);
+
+ // Wrapper for Reshape.
+ // Enqueues an operation to collapse the provided dimensions; e.g. an
+ // operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to
+ // {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must
+ // be a consecutive, in-order subsequence of the operand dimensions.
+ //
+ // Note that collapsing a single dimension does nothing:
+ //
+ // {256} collapsing {0} => {256}
+ // {1} collapsing {0} => {1}
+ //
+ // Collapsing multiple dimensions produces a single result dimension:
+ //
+ // {256, 2} collapsing {0,1} => {512}
+ // {256, 2, 3} collapsing {0,1} => {512, 3}
+ //
+ // This could potentially cause data to be moved -- it provides a more
+ // structured form of reshaping than an arbitrary Reshape operation.
+ XlaOp Collapse(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions);
+
+ // Enqueues a slice operation onto the computation that slices the operand
+ // from the start indices to the limit indices; e.g.
+ //
+ // x
+ // [ 0 1 2 3 ]
+ // y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ]
+ // [ 8 9 a b ]
+ //
+ // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D
+ // range notation.
+ // The strides parameter determines the stride over the slice
+ XlaOp Slice(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> start_indices,
+ tensorflow::gtl::ArraySlice<int64> limit_indices,
+ tensorflow::gtl::ArraySlice<int64> strides);
+
+ // Enqueues a slice operation in a given dimension, taking all other
+ // dimensions as they are; e.g. if dimno is 1 from start_index 2 to
+ // limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand
+ // for:
+ //
+ // array[:, 2:4:1, :]
+ XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index,
+ int64 stride, int64 dimno);
+
+ // Enqueues a slice operation onto the computation that slices the 'operand'
+ // from dynamic start indices which are passed in 'start_indices'.
+ // The size of the slice in each dimension is passed in 'slice_sizes',
+ // which specify the end point of exclusive slice intervals in each
+ // dimension [start, start + size).
+ // The shape of 'start_indices' must be rank == 1, with dimension size
+ // equal to the rank of the 'operand'.
+ // Slice index calculations are computed modulo input dimension sizes to
+ // prevent dynamic start indices from generating out-of-bound array accesses.
+ XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
+ tensorflow::gtl::ArraySlice<int64> slice_sizes);
+
+ // Enqueues a dynamic update slice operation onto the computation, which
+ // updates a slice of 'operand' with 'update' at dynamic 'start_indices'.
+ // The shape of 'update' determines the shape of the slice of 'operand'
+ // which is updated.
+ // The indices specified in 'start_indices' specify the offset of the slice
+ // of 'operand' which is updated.
+ //
+ // update = {10, 11} // calculated at runtime.
+ // [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ]
+ // [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11]
+ // [7 8 9] [7 8 9 ]
+ //
+ // The shape of 'start_indices' must be rank == 1, with dimension size
+ // equal to the rank of the 'operand'.
+ // Slice index calculations are computed modulo update dimension sizes to
+ // prevent dynamic start indices from generating out-of-bound array accesses.
+ XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
+ const XlaOp& start_indices);
+
+ // Enqueues a concatenate instruction onto the computation. 'operands' must
+ // have >= 1 entry.
+ XlaOp ConcatInDim(tensorflow::gtl::ArraySlice<XlaOp> operands,
+ int64 dimension);
+
+ // Enqueue a tracing operation onto the computation; the computation will emit
+ // a logging message with the operand.
+ void Trace(const string& tag, const XlaOp& operand);
+
+ // Enqueues a conditional-move-like select operation onto the computation;
+ // predicated on pred, selects between on_true and on_false.
+ XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false);
+
+ // Enqueues a tuple-creation instruction onto the computation.
+ XlaOp Tuple(tensorflow::gtl::ArraySlice<XlaOp> elements);
+
+ // Enqueues a tuple-element-get instruction onto the computation.
+ XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);
+
+ // Enqueues an equal-to comparison instruction onto the computation.
+ XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a not-equal comparison instruction onto the computation.
+ XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a greater-or-equal comparison instruction onto the computation.
+ XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a greater-than comparison instruction onto the computation.
+ XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a less-than comparison instruction onto the computation.
+ XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a less-or-equal comparison instruction onto the computation.
+ XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a dot instruction onto the computation.
+ XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs);
+
+ // Enqueues a general dot instruction onto the computation.
+ XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
+ const DotDimensionNumbers& dimension_numbers);
+
+ // Enqueues a convolution instruction onto the computation, which uses the
+ // default convolution dimension numbers.
+ XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ Padding padding);
+
+ // Enqueues a convolution instruction onto the computation, with the caller
+ // provided padding configuration in the format returned by MakePadding().
+ XlaOp ConvWithGeneralPadding(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
+
+ // Enqueues a convolution instruction onto the computation, with the caller
+ // provided dimension numbers configuration.
+ XlaOp ConvWithGeneralDimensions(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
+ const ConvolutionDimensionNumbers& dimension_numbers);
+
+ // Enqueues a convolution instruction onto the computation, with the caller
+ // provided padding configuration as well as the dimension numbers.
+ XlaOp ConvGeneral(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ const ConvolutionDimensionNumbers& dimension_numbers);
+
+ // Enqueues a convolution instruction onto the computation, with the caller
+ // provided padding configuration, dilation factors and dimension numbers.
+ XlaOp ConvGeneralDilated(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ tensorflow::gtl::ArraySlice<int64> lhs_dilation,
+ tensorflow::gtl::ArraySlice<int64> rhs_dilation,
+ const ConvolutionDimensionNumbers& dimension_numbers);
+
+ // Enqueues an FFT instruction onto the computation, of the given type and
+ // with the given FFT length.
+ XlaOp Fft(const XlaOp& operand, FftType fft_type,
+ tensorflow::gtl::ArraySlice<int64> fft_length);
+
+ // Enqueues an infeed instruction onto the computation, which writes data of
+ // the given shape to the infeed buffer of the device.
+ XlaOp Infeed(const Shape& shape, const string& config = "");
+ XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
+ const string& config = "");
+
+ // Enqueues an outfeed instruction onto the computation. This instruction
+ // generates outgoing data transfers for the given data.
+ //
+ // shape_with_layout communicates the laid out shape that we want to outfeed
+ // -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error
+ // will occur.
+ void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
+ const string& outfeed_config);
+ XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
+ const Shape& shape_with_layout,
+ const string& outfeed_config);
+
+ // Enqueues a call instruction onto the computation.
+ XlaOp Call(const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<XlaOp> operands);
+
+ // Enqueues a custom call instruction onto the computation.
+ // During code generation, a call instruction is emitted which targets a
+ // symbol with the name |call_target_name|. The |operands| are passed to the
+ // call instruction. |shape| is the resultant shape.
+ XlaOp CustomCall(const string& call_target_name,
+ tensorflow::gtl::ArraySlice<XlaOp> operands,
+ const Shape& shape);
+
+ // Enqueues a pseudo-op to represent host-side computation data-dependencies.
+ // During code generation, host send and receive operations will be generated
+ // to transfer |operands| to the host and a single result of |shape| back to
+ // the device. Host send/recv operations are emitted using |channel_name|.
+ // Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO
+ // instruction scheduling.
+ XlaOp HostCompute(tensorflow::gtl::ArraySlice<XlaOp> operands,
+ const string& channel_name, int64 cost_estimate_ns,
+ const Shape& shape);
+
+ // The following methods enqueue element-wise binary arithmetic operations
+ // onto the computation. The shapes of the operands have to match unless one
+ // of the operands is a scalar, or an explicit broadcast dimension is given
+ // (see g3doc for more details).
+
+ // Enqueues a complex compose instruction onto the computation.
+ XlaOp Complex(const XlaOp& real, const XlaOp& imag,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a complex conjugate instruction onto the computation.
+ XlaOp Conj(const XlaOp& operand);
+
+ // Enqueues an add instruction onto the computation.
+ XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a subtract instruction onto the computation.
+ XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a multiply instruction onto the computation.
+ XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a divide instruction onto the computation.
+ XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a remainder instruction onto the computation.
+ XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a max instruction onto the computation.
+ XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a min instruction onto the computation.
+ XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Element-wise logical operators
+ XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ XlaOp Not(const XlaOp& operand);
+
+ XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ XlaOp ShiftRightArithmetic(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ XlaOp ShiftRightLogical(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Reduces an array among the provided dimensions, given "computation" as a
+ // reduction operator.
+ XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
+
+ // Convenience wrapper around the above that reduces all the dimensions in the
+ // operand shape.
+ XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation);
+
+ // Enqueues a windowed reduce instruction onto the computation.
+ XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ Padding padding);
+
+ // As ReduceWindow(), but the padding is given in the format
+ // returned by MakePadding().
+ XlaOp ReduceWindowWithGeneralPadding(
+ const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
+
+ // Returns the sum of the operand value within each subgroup of replicas. All
+ // replicas supply one input to the sum and all replicas receive the resulting
+ // sum for each subgroup.
+ XlaOp CrossReplicaSum(
+ const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> replica_group_ids = {});
+
+ // Enqueues an operation that do an AllReduce of the operand cross cores. Here
+ // AllReduce means doing a reduction on the input operand cross cores and then
+ // broadcasting the reduction result to those cores. The reduction function is
+ // defined by `computation`, which should be a commutative computation on
+ // scalars, e.g., add, min, or max. The way that AllReduce is applied is
+ // configured by:
+ //
+ // - `replica_group_ids`: maps replica ids to subgroup ids. If empty, all
+ // replicas belong to one group. Allreduce will be applied within subgroups.
+ // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means,
+ // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1.
+ //
+ // - `channel_id`: for Allreduce nodes from different models, if they have the
+ // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be
+ // applied cross models.
+ //
+ // TODO(b/79737069): Rename this to AllReduce when it's ready to use.
+ XlaOp CrossReplicaSum(
+ const XlaOp& operand, const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> replica_group_ids = {},
+ const tensorflow::gtl::optional<ChannelHandle>& channel_id =
+ tensorflow::gtl::nullopt);
+
+ // Enqueues an operation that scatters the `source` array to the selected
+ // indices of each window.
+ XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ Padding padding, const XlaOp& source,
+ const XlaOp& init_value,
+ const XlaComputation& scatter);
+
+ // As SelectAndScatter(), but the padding is given in the format
+ // returned by MakePadding().
+ XlaOp SelectAndScatterWithGeneralPadding(
+ const XlaOp& operand, const XlaComputation& select,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ const XlaOp& source, const XlaOp& init_value,
+ const XlaComputation& scatter);
+
+ // Enqueues an abs instruction onto the computation.
+ XlaOp Abs(const XlaOp& operand);
+
+ // Enqueues a atan2 instruction onto the computation.
+ XlaOp Atan2(const XlaOp& y, const XlaOp& x,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues an exp instruction onto the computation.
+ XlaOp Exp(const XlaOp& operand);
+
+ // Enqueues an expm1 instruction onto the computation.
+ XlaOp Expm1(const XlaOp& operand);
+
+ // Enqueues a floor instruction onto the computation.
+ XlaOp Floor(const XlaOp& operand);
+
+ // Enqueues a ceil instruction onto the computation.
+ XlaOp Ceil(const XlaOp& operand);
+
+ // Enqueues a round instruction onto the computation, rounding to nearest even
+ // with half-way cases rounding away from zero.
+ XlaOp Round(const XlaOp& operand);
+
+ // Enqueues an log instruction (natural logarithm) onto the computation.
+ XlaOp Log(const XlaOp& operand);
+
+ // Enqueues an log1p instruction (log(x+1)) onto the computation.
+ XlaOp Log1p(const XlaOp& operand);
+
+ // Enqueues a sign instruction onto the computation.
+ XlaOp Sign(const XlaOp& operand);
+
+ // Enqueues a count leading zeros instruction onto the computation.
+ XlaOp Clz(const XlaOp& operand);
+
+ // Enqueues a cosine instruction onto the computation.
+ XlaOp Cos(const XlaOp& operand);
+
+ // Enqueues a sine instruction onto the computation.
+ XlaOp Sin(const XlaOp& operand);
+
+ // Enqueues a tanh instruction onto the computation.
+ XlaOp Tanh(const XlaOp& operand);
+
+ // Enqueues a real-part instruction onto the computation.
+ XlaOp Real(const XlaOp& operand);
+
+ // Enqueues an imaginary-part instruction onto the computation.
+ XlaOp Imag(const XlaOp& operand);
+
+ // Enqueues a lhs^rhs computation onto the computation.
+ XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues an operator that tests if the operand's values are finite, i.e.,
+ // not Inf or NaN. Defined only for floating-point types. Returns an array of
+ // booleans with the same shape where entries are true iff the corresponding
+ // entry was NaN.
+ XlaOp IsFinite(const XlaOp& operand);
+
+ // Enqueues a convert instruction onto the computation that changes the
+ // element type of the operand array to primitive_type.
+ XlaOp ConvertElementType(const XlaOp& operand,
+ PrimitiveType new_element_type);
+
+ // Enqueues a no-op instruction onto the computation that changes
+ // the element type of the operand array to primitive_type. The
+ // bit-widths of the source and destination element types must be
+ // identical.
+ XlaOp BitcastConvertType(const XlaOp& operand,
+ PrimitiveType new_element_type);
+
+ // Enqueues a negate instruction onto the computation.
+ XlaOp Neg(const XlaOp& operand);
+
+ // Enqueues a transpose instruction onto the computation.
+ XlaOp Transpose(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> permutation);
+
+ // Enqueues a reverse instruction onto the computation. The order of the
+ // elements in the given dimensions is reversed (i.e., the element at index i
+ // is moved to index dimension_size - 1 - i).
+ XlaOp Rev(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions);
+
+ // Enqueues a sort (as increasing order) instruction onto the computation.
+ // If only keys are provided:
+ // * If the keys are an rank-1 tensor (an array), the result is a sorted array
+ // of keys, in ascending order.
+ // * If the keys have higher rank, the keys are sorted along the provided
+ // dimension. For example, for a rank-2 tensor (a matrix) of keys, a dimension
+ // value of 0 will indepenently sort every column, and a dimension value of 1
+ // will independently sort each row. If no dimension number is provided, then
+ // the last dimension is chosen by default.
+ //
+ // If both keys and values are provided:
+ // * The keys and the values must tensors with the same dimensions. The
+ // element types of the tensors may be different.
+ // * The result is a tuple that consists of a sorted tensor of keys (along the
+ // provided dimension, as above) as the first element, and a tensor with their
+ // corresponding values as the second element.
+ XlaOp Sort(XlaOp keys,
+ tensorflow::gtl::optional<XlaOp> values = tensorflow::gtl::nullopt,
+ int64 dimension = -1);
+
+ // Enqueues a clamp instruction onto the computation.
+ XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
+
+ // Enqueues a map instruction onto the computation.
+ XlaOp Map(tensorflow::gtl::ArraySlice<XlaOp> operands,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::gtl::ArraySlice<XlaOp> static_operands = {});
+
+ // Enqueues a N(mu, sigma) random number generation instruction onto the
+ // computation.
+ XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape);
+
+ // Enqueues a U(a, b) random number generation instruction onto the
+ // computation. Returns values in the semi-open interval [a, b).
+ XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape);
+
+ // Enqueues a while node onto the computation.
+ XlaOp While(const XlaComputation& condition, const XlaComputation& body,
+ const XlaOp& init);
+
+ // Enqueues a conditional node onto the computation.
+ XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
+ const XlaComputation& true_computation,
+ const XlaOp& false_operand,
+ const XlaComputation& false_computation);
+
+ // Enqueues a ReducePrecision node onto the computation.
+ XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
+ const int mantissa_bits);
+
+ // Enqueues a Gather node onto the computation.
+ XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
+ const GatherDimensionNumbers& dimension_numbers,
+ tensorflow::gtl::ArraySlice<int64> window_bounds);
+
+ // Enqueues a Send node onto the computation for device-to-device
+ // communication, to send the given operand to a Recv instruction that shares
+ // the same channel handle.
+ void Send(const XlaOp& operand, const ChannelHandle& handle);
+ XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
+ const ChannelHandle& handle);
+
+ // Enqueues a Send node which sends data to the host.
+ XlaOp SendToHost(const XlaOp& operand, const XlaOp& token,
+ const Shape& shape_with_layout, const ChannelHandle& handle);
+
+ // Enqueues a Recv node which receives data from the host.
+ XlaOp RecvFromHost(const XlaOp& token, const Shape& shape,
+ const ChannelHandle& handle);
+
+ // Enqueues an AfterAll operation with no operands producing a token-shaped
+ // value.
+ XlaOp CreateToken();
+
+ // Enqueues an AfterAll operation with no operands producing a token-shaped
+ // value.
+ XlaOp AfterAll(tensorflow::gtl::ArraySlice<XlaOp> tokens);
+
+ // Enqueues a Recv node onto the computation. The data comes from a Send
+ // instruction that shares the same channel handle and its shape must
+ // be the same as the given shape.
+ XlaOp Recv(const Shape& shape, const ChannelHandle& handle);
+ XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
+ const ChannelHandle& handle);
+
+ // Normalizes operand across spatial and batch dimensions for each feature.
+ //
+ // Returns a tuple (normalized, batch_mean, batch_var) where `normalized`
+ // is the normalized result and batch_mean and batch_var are the mean and
+ // variance, respectively, across batch for the operand.
+ XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
+ const XlaOp& offset, float epsilon,
+ int64 feature_index);
+
+ // Normalizes operand across spatial and batch dimensions for each feature.
+ //
+ // `BatchNormInference` is equivalent to calling `BatchNormTraining` without
+ // computing `mean` and `variance` for each batch inside the operation. It
+ // uses the input `mean` and `variance` instead as estimated values. The
+ // purpose of this op is to reduce latency in inference, hence the name
+ // `BatchNormInference`.
+ //
+ // The output has the same shape as `operand`, and contains the normalized
+ // values for each batch.
+ XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale,
+ const XlaOp& offset, const XlaOp& mean,
+ const XlaOp& variance, float epsilon,
+ int64 feature_index);
+
+ // Calculates the gradients of a batch norm op.
+ //
+ // The inputs `batch_mean` and `batch_var` represent the mean and variance
+ // across the batch.
+ //
+ // Returns a tuple of three elements:
+ // - grad_operand: Gradient with respect to input `operand`
+ // - grad_offset: Gradient with respect to input `offset`
+ // - grad_scale: Gradient with respect to input `scale`
+ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
+ const XlaOp& batch_mean, const XlaOp& batch_var,
+ const XlaOp& grad_output, float epsilon,
+ int64 feature_index);
+
+ StatusOr<XlaOp> AddInstruction(
+ HloInstructionProto&& instr, HloOpcode opcode,
+ tensorflow::gtl::ArraySlice<XlaOp> operands = {});
+
+ void AddCalledComputation(const XlaComputation& computation,
+ HloInstructionProto* instr);
+
+ StatusOr<const HloInstructionProto*> LookUpInstruction(const XlaOp& op) const;
+
+ // Internal helper method that does the building for an arbitrary unary op.
+ XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand);
+
+ // Internal helper method that does the building for an arbitrary binary op.
+ // broadcast_dimensions specifies which dimensions to use for broadcasting
+ // when the operation is between tensors of different ranks.
+ XlaOp BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+
+ // Internal helper method that does the building for an arbitrary ternary op.
+ XlaOp TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs,
+ const XlaOp& ehs);
+
+ XlaOp RngOp(RandomDistribution distribution,
+ tensorflow::gtl::ArraySlice<XlaOp> parameters,
+ const Shape& shape);
+
+ StatusOr<XlaOp> InDimBroadcast(
+ const Shape& shape, const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+
+ // Internal helper method that creates a sequence of instructions that
+ // performs an explicit broadcast of the operand to the target shape.
+ StatusOr<XlaOp> AddBroadcastSequence(const Shape& output_shape,
+ const XlaOp& operand);
+
+ // Internal helper method for creating a Reshape op with the already inferred
+ // shape.
+ StatusOr<XlaOp> Reshape(const Shape& shape, const XlaOp& operand);
+
+ // Returns the (inferred) result for the program shape for the current
+ // computation and fills the root_id in the pointer.
+ StatusOr<ProgramShape> GetProgramShape(int64* root_id) const;
+
+ // Returns shapes for the operands.
+ StatusOr<std::vector<Shape>> GetOperandShapes(
+ tensorflow::gtl::ArraySlice<XlaOp> operands) const;
+
+ // A visitor which checks whether an operation is a compile-time constant,
+ // meaning that it doesn't depend on any parameters, or on any stateful
+ // operation such as `RngNormal` or `Infeed`. The visitor walks the
+ // computation starting at a given operation and sets is_constant to false iff
+ // a parameter or stateful operation is encountered.
+ void IsConstantVisitor(const int64 op_handle, std::set<int64>* visited,
+ bool* is_constant) const;
+
+ // Checks bounds for convolution parameters.
+ Status VerifyConvolution(
+ const Shape& lhs_shape, const Shape& rhs_shape,
+ const ConvolutionDimensionNumbers& dimension_numbers) const;
+
+ // Helper function for creating a Window proto from user-supplied data.
+ // Returns error if the user-supplied data was invalid.
+ StatusOr<Window> MakeWindow(
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ tensorflow::gtl::ArraySlice<int64> lhs_dilation,
+ tensorflow::gtl::ArraySlice<int64> rhs_dilation) const;
+
+ string name_; // Name to use for the built computation.
+
+ // The first error encountered while building the computation.
+ // This is OK until the first error is encountered.
+ Status first_error_;
+
+ // The saved stack trace from the point at which the first error occurred.
+ tensorflow::SavedStackTrace first_error_backtrace_;
+
+ // The instructions of this computation.
+ std::vector<HloInstructionProto> instructions_;
+
+ // The embedded computations used by this computation. Each computation was
+ // the entry computation of some XlaComputation, the key is the unique id of
+ // that XlaComputation.
+ std::map<int64, HloComputationProto> embedded_;
+
+ // The unique parameter numbers.
+ tensorflow::gtl::FlatSet<int64> parameter_numbers_;
+
+ // The metadata to attach to each op. This is structured as a "modal"-like
+ // operation, in order to simplify client code (and not sprinkle this metadata
+ // throughout the TensorFlow op kernel implementations).
+ OpMetadata metadata_;
+
+ // Sharding for this operator. This is structured as a "model"-like operation,
+ // in order to simplify client code, similar to metadata_.
+ tensorflow::gtl::optional<OpSharding> sharding_;
+
+ // Mode bit that indicates whether to die when a first error is encountered.
+ bool die_immediately_on_error_ = false;
+
+ XlaBuilder* parent_builder_{nullptr};
+
+ friend XlaOp Parameter(XlaBuilder* builder, int64 parameter_number,
+ const Shape& shape, const string& name);
+ friend XlaOp ConstantLiteral(XlaBuilder* builder,
+ const LiteralSlice& literal);
+ template <typename NativeT>
+ friend XlaOp ConstantR0(XlaBuilder* builder, NativeT value);
+ template <typename NativeT>
+ friend XlaOp ConstantR1(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<NativeT> values);
+ friend XlaOp ConstantR1(XlaBuilder* builder,
+ const tensorflow::core::Bitmap& values);
+ template <typename NativeT>
+ friend XlaOp ConstantR2(
+ XlaBuilder* builder,
+ std::initializer_list<std::initializer_list<NativeT>> values);
+ template <typename NativeT>
+ friend XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
+ const Array<NativeT>& values,
+ const Layout& layout);
+ template <typename NativeT>
+ friend XlaOp ConstantFromArray(XlaBuilder* builder,
+ const Array<NativeT>& values);
+ template <typename NativeT>
+ friend XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
+ const Array2D<NativeT>& values,
+ const Layout& layout);
+ template <typename NativeT>
+ friend XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
+ const Array2D<NativeT>& values);
+ template <typename NativeT>
+ friend XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
+ const Array3D<NativeT>& values,
+ const Layout& layout);
+ template <typename NativeT>
+ friend XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
+ const Array3D<NativeT>& values);
+ template <typename NativeT>
+ friend XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
+ const Array4D<NativeT>& values,
+ const Layout& layout);
+ template <typename NativeT>
+ friend XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
+ const Array4D<NativeT>& values);
+
+ template <typename NativeT>
+ friend XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value);
+
+ friend XlaOp Broadcast(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
+
+ friend XlaOp BroadcastInDim(
+ const XlaOp& operand, const Shape& shape,
+ const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+
+ friend XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
+ const PaddingConfig& padding_config);
+
+ friend XlaOp Reshape(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::gtl::ArraySlice<int64> new_sizes);
+
+ friend XlaOp Reshape(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> new_sizes);
+
+ friend XlaOp Collapse(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions);
+
+ friend XlaOp Slice(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> start_indices,
+ tensorflow::gtl::ArraySlice<int64> limit_indices,
+ tensorflow::gtl::ArraySlice<int64> strides);
+
+ friend XlaOp SliceInDim(const XlaOp& operand, int64 start_index,
+ int64 limit_index, int64 stride, int64 dimno);
+
+ friend XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
+ tensorflow::gtl::ArraySlice<int64> slice_sizes);
+
+ friend XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
+ const XlaOp& start_indices);
+
+ friend XlaOp ConcatInDim(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<XlaOp> operands,
+ int64 dimension);
+
+ friend void Trace(const string& tag, const XlaOp& operand);
+
+ friend XlaOp Select(const XlaOp& pred, const XlaOp& on_true,
+ const XlaOp& on_false);
+ friend XlaOp Tuple(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<XlaOp> elements);
+ friend XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);
+ friend XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs);
+ friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
+ const DotDimensionNumbers& dimension_numbers);
+ friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ Padding padding);
+ friend XlaOp ConvWithGeneralPadding(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
+ friend XlaOp ConvWithGeneralDimensions(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
+ const ConvolutionDimensionNumbers& dimension_numbers);
+ friend XlaOp ConvGeneral(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ const ConvolutionDimensionNumbers& dimension_numbers);
+ friend XlaOp ConvGeneralDilated(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ tensorflow::gtl::ArraySlice<int64> lhs_dilation,
+ tensorflow::gtl::ArraySlice<int64> rhs_dilation,
+ const ConvolutionDimensionNumbers& dimension_numbers);
+ friend XlaOp Fft(const XlaOp& operand, FftType fft_type,
+ tensorflow::gtl::ArraySlice<int64> fft_length);
+ friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
+ const string& config);
+ friend void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
+ const string& outfeed_config);
+ friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<XlaOp> operands);
+ friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
+ tensorflow::gtl::ArraySlice<XlaOp> operands,
+ const Shape& shape);
+ friend XlaOp HostCompute(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<XlaOp> operands,
+ const string& channel_name, int64 cost_estimate_ns,
+ const Shape& shape);
+ friend XlaOp Complex(const XlaOp& real, const XlaOp& imag,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Conj(const XlaOp& operand);
+ friend XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Not(const XlaOp& operand);
+ friend XlaOp ShiftLeft(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp ShiftRightArithmetic(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp ShiftRightLogical(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
+ friend XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation);
+ friend XlaOp ReduceWindow(
+ const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding);
+ friend XlaOp ReduceWindowWithGeneralPadding(
+ const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
+ friend XlaOp CrossReplicaSum(
+ const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> replica_group_ids);
+ friend XlaOp CrossReplicaSum(
+ const XlaOp& operand, const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> replica_group_ids,
+ const tensorflow::gtl::optional<ChannelHandle>& channel_id);
+ friend XlaOp SelectAndScatter(
+ const XlaOp& operand, const XlaComputation& select,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
+ const XlaOp& source, const XlaOp& init_value,
+ const XlaComputation& scatter);
+ friend XlaOp SelectAndScatterWithGeneralPadding(
+ const XlaOp& operand, const XlaComputation& select,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ const XlaOp& source, const XlaOp& init_value,
+ const XlaComputation& scatter);
+ friend XlaOp Abs(const XlaOp& operand);
+ friend XlaOp Atan2(const XlaOp& y, const XlaOp& x,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Exp(const XlaOp& operand);
+ friend XlaOp Expm1(const XlaOp& operand);
+ friend XlaOp Floor(const XlaOp& operand);
+ friend XlaOp Ceil(const XlaOp& operand);
+ friend XlaOp Round(const XlaOp& operand);
+ friend XlaOp Log(const XlaOp& operand);
+ friend XlaOp Log1p(const XlaOp& operand);
+ friend XlaOp Sign(const XlaOp& operand);
+ friend XlaOp Clz(const XlaOp& operand);
+ friend XlaOp Cos(const XlaOp& operand);
+ friend XlaOp Sin(const XlaOp& operand);
+ friend XlaOp Tanh(const XlaOp& operand);
+ friend XlaOp Real(const XlaOp& operand);
+ friend XlaOp Imag(const XlaOp& operand);
+ friend XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp IsFinite(const XlaOp& operand);
+ // TODO(b/64798317): Finish CPU & GPU implementation, then replace xla::Iota
+ // in xla/client/lib/numeric.h with this (renamed to xla::Iota).
+ friend XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size);
+ friend XlaOp ConvertElementType(const XlaOp& operand,
+ PrimitiveType new_element_type);
+ friend XlaOp BitcastConvertType(const XlaOp& operand,
+ PrimitiveType new_element_type);
+ friend XlaOp Neg(const XlaOp& operand);
+ friend XlaOp Transpose(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> permutation);
+ friend XlaOp Rev(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions);
+ friend XlaOp Sort(XlaOp keys, tensorflow::gtl::optional<XlaOp> values,
+ int64 dimension);
+ friend XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
+ friend XlaOp Map(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<XlaOp> operands,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::gtl::ArraySlice<XlaOp> static_operands);
+ friend XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma,
+ const Shape& shape);
+ friend XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape);
+ friend XlaOp While(const XlaComputation& condition,
+ const XlaComputation& body, const XlaOp& init);
+ friend XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
+ const XlaComputation& true_computation,
+ const XlaOp& false_operand,
+ const XlaComputation& false_computation);
+ friend XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
+ const int mantissa_bits);
+ friend XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
+ const GatherDimensionNumbers& dimension_numbers,
+ tensorflow::gtl::ArraySlice<int64> window_bounds);
+ friend void Send(const XlaOp& operand, const ChannelHandle& handle);
+ friend XlaOp Recv(XlaBuilder* builder, const Shape& shape,
+ const ChannelHandle& handle);
+ friend XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
+ const XlaOp& offset, float epsilon,
+ int64 feature_index);
+ friend XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale,
+ const XlaOp& offset, const XlaOp& mean,
+ const XlaOp& variance, float epsilon,
+ int64 feature_index);
+ friend XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
+ const XlaOp& batch_mean, const XlaOp& batch_var,
+ const XlaOp& grad_output, float epsilon,
+ int64 feature_index);
+ friend XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
+ const ChannelHandle& handle);
+ friend XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
+ const ChannelHandle& handle);
+ friend XlaOp SendToHost(const XlaOp& operand, const XlaOp& token,
+ const Shape& shape_with_layout,
+ const ChannelHandle& handle);
+ friend XlaOp RecvFromHost(const XlaOp& token, const Shape& shape,
+ const ChannelHandle& handle);
+ friend XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
+ const string& config);
+ friend XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
+ const Shape& shape_with_layout,
+ const string& outfeed_config);
+ friend XlaOp CreateToken(XlaBuilder* builder);
+ friend XlaOp AfterAll(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<XlaOp> tokens);
+};
+
+// RAII-style object: sets the current sharding assignment in builder on
+// construction, and sets back to the previous assignment on destruction.
+class XlaScopedShardingAssignment {
+ public:
+ XlaScopedShardingAssignment(xla::XlaBuilder* builder,
+ tensorflow::gtl::optional<OpSharding> sharding)
+ : builder_(builder), prev_sharding_(builder->sharding()) {
+ SetSharding(sharding);
+ }
+
+ XlaScopedShardingAssignment(const XlaScopedShardingAssignment&) = delete;
+ XlaScopedShardingAssignment& operator=(const XlaScopedShardingAssignment&) =
+ delete;
+
+ ~XlaScopedShardingAssignment() { SetSharding(prev_sharding_); }
+
+ private:
+ void SetSharding(const tensorflow::gtl::optional<OpSharding>& sharding) {
+ if (sharding.has_value()) {
+ builder_->SetSharding(sharding.value());
+ } else {
+ builder_->ClearSharding();
+ }
+ }
+
+ xla::XlaBuilder* const builder_;
+ tensorflow::gtl::optional<OpSharding> prev_sharding_;
+};
+
+// Free functions for building XlaOps. The intention is that these will
+// become the public API for building XlaOps rather than calling methods on
+// XlaBuilder directly.
+
+// Enqueues a "retrieve parameter value" instruction for a parameter that was
+// passed to the computation.
+XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape,
+ const string& name);
+
+// Enqueues a constant with the value of the given literal onto the
+// computation.
+XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal);
+
+// Enqueues a constant onto the computation. Methods are templated on the
+// native host type (NativeT) which corresponds to a specific XLA
+// PrimitiveType as given in the following table:
+//
+// Native Type PrimitiveType
+// -----------------------------
+// bool PRED
+// int32 S32
+// int64 S64
+// uint32 U32
+// uint64 U64
+// float F32
+// double F64
+//
+// Note: not all primitive types defined in xla_data.proto have a
+// corresponding native type yet.
+template <typename NativeT>
+XlaOp ConstantR0(XlaBuilder* builder, NativeT value);
+template <typename NativeT>
+XlaOp ConstantR1(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<NativeT> values);
+XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values);
+template <typename NativeT>
+XlaOp ConstantR2(XlaBuilder* builder,
+ std::initializer_list<std::initializer_list<NativeT>> values);
+template <typename NativeT>
+XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
+ const Array<NativeT>& values,
+ const Layout& layout);
+template <typename NativeT>
+XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values);
+template <typename NativeT>
+XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
+ const Array2D<NativeT>& values,
+ const Layout& layout);
+template <typename NativeT>
+XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
+ const Array2D<NativeT>& values);
+template <typename NativeT>
+XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
+ const Array3D<NativeT>& values,
+ const Layout& layout);
+template <typename NativeT>
+XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
+ const Array3D<NativeT>& values);
+template <typename NativeT>
+XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
+ const Array4D<NativeT>& values,
+ const Layout& layout);
+template <typename NativeT>
+XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
+ const Array4D<NativeT>& values);
+
+// Enqueues a rank one constant (XlaBuilder* builder, vector) onto the
+// computation. The vector has size 'length' and every element has the value
+// 'value'.
+template <typename NativeT>
+XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value);
+
+// Adds dimensions to an array by duplicating the data in the array.
+//
+// The new dimensions are inserted on the left, i.e. if
+// broadcast_sizes has values {a0, ..., aN} and the operand shape
+// has dimensions {b0, ..., bM} then the shape of the output has
+// dimensions {a0, ..., aN, b0, ..., bM}.
+//
+// The new dimensions index into copies of the operand, i.e.
+//
+// output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
+XlaOp Broadcast(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
+
+// Performs in-dimension-style broadcast.
+//
+// Operand specifies the input to be broadcast. "shape" is expected output
+// shape. "broadcast_dimensions" are the dimensions to be broadcasting into.
+// Dimension numbers in broadcast_dimensions map to individual dimensions
+// of the operand, and specify what dimension of the output shape they
+// should be broadcast.
+// e.g.
+// Say operand = [1, 2], i.e., a 1D tensor with 2 elements.
+// and dimension of shape is [2,2].
+// Specifying {1} as brodcast_dimension will generate output
+// [1 , 2]
+// [1 , 2]
+// On the other hand, specifying {0} as broadcast_dimension
+// will generate output
+// [1 , 1]
+// [2 , 2]
+XlaOp BroadcastInDim(
+ const XlaOp& operand, const Shape& shape,
+ const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+
+// Enqueues a pad operation onto the computation that pads the given value on
+// the edges as well as between the elements of the input. padding_config
+// specifies the padding amount for each dimension.
+XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
+ const PaddingConfig& padding_config);
+
+// Enqueues an operation onto the computation that flattens the operand based
+// on the dimension order (major/slowest-varying to minor/fastest-varying)
+// given, followed by reshaping it into the shape with the given dimension
+// sizes (also major to minor). Conceptually, this is a limited form of
+// "shape casting".
+XlaOp Reshape(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::gtl::ArraySlice<int64> new_sizes);
+
+// Enqueues an operation onto the computation that collapses the operand, from
+// first to last dimension (C order), then reshapes it to the given dimension
+// sizes. Conceptually, this is a limited form of "shape casting".
+XlaOp Reshape(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> new_sizes);
+
+// Wrapper for Reshape.
+// Enqueues an operation to collapse the provided dimensions; e.g. an
+// operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to
+// {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must
+// be a consecutive, in-order subsequence of the operand dimensions.
+//
+// Note that collapsing a single dimension does nothing:
+//
+// {256} collapsing {0} => {256}
+// {1} collapsing {0} => {1}
+//
+// Collapsing multiple dimensions produces a single result dimension:
+//
+// {256, 2} collapsing {0,1} => {512}
+// {256, 2, 3} collapsing {0,1} => {512, 3}
+//
+// This could potentially cause data to be moved -- it provides a more
+// structured form of reshaping than an arbitrary Reshape operation.
+XlaOp Collapse(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions);
+
+// Enqueues a slice operation onto the computation that slices the operand
+// from the start indices to the limit indices; e.g.
+//
+// x
+// [ 0 1 2 3 ]
+// y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ]
+// [ 8 9 a b ]
+//
+// Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D
+// range notation.
+// The strides parameter determines the stride over the slice
+XlaOp Slice(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> start_indices,
+ tensorflow::gtl::ArraySlice<int64> limit_indices,
+ tensorflow::gtl::ArraySlice<int64> strides);
+
+// Enqueues a slice operation in a given dimension, taking all other
+// dimensions as they are; e.g. if dimno is 1 from start_index 2 to
+// limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand
+// for:
+//
+// array[:, 2:4:1, :]
+XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index,
+ int64 stride, int64 dimno);
+
+// Enqueues a slice operation onto the computation that slices the 'operand'
+// from dynamic start indices which are passed in 'start_indices'.
+// The size of the slice in each dimension is passed in 'slice_sizes',
+// which specify the end point of exclusive slice intervals in each
+// dimension [start, start + size).
+// The shape of 'start_indices' must be rank == 1, with dimension size
+// equal to the rank of the 'operand'.
+// Slice index calculations are computed modulo input dimension sizes to
+// prevent dynamic start indices from generating out-of-bound array accesses.
+XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
+ tensorflow::gtl::ArraySlice<int64> slice_sizes);
+
+// Enqueues a dynamic update slice operation onto the computation, which
+// updates a slice of 'operand' with 'update' at dynamic 'start_indices'.
+// The shape of 'update' determines the shape of the slice of 'operand'
+// which is updated.
+// The indices specified in 'start_indices' specify the offset of the slice
+// of 'operand' which is updated.
+//
+// update = {10, 11} // calculated at runtime.
+// [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ]
+// [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11]
+// [7 8 9] [7 8 9 ]
+//
+// The shape of 'start_indices' must be rank == 1, with dimension size
+// equal to the rank of the 'operand'.
+// Slice index calculations are computed modulo update dimension sizes to
+// prevent dynamic start indices from generating out-of-bound array accesses.
+XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
+ const XlaOp& start_indices);
+
+// Enqueues a concatenate instruction onto the computation. 'operands' must
+// have >= 1 entry.
+XlaOp ConcatInDim(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<XlaOp> operands, int64 dimension);
+
+// Enqueue a tracing operation onto the computation; the computation will emit
+// a logging message with the operand.
+void Trace(const string& tag, const XlaOp& operand);
+
+// Enqueues a conditional-move-like select operation onto the computation;
+// predicated on pred, selects between on_true and on_false.
+XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false);
+
+// Enqueues a tuple-creation instruction onto the computation.
+XlaOp Tuple(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> elements);
+
+// Enqueues a tuple-element-get instruction onto the computation.
+XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);
+
+// Enqueues an equal-to comparison instruction onto the computation.
+XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+// Enqueues a not-equal comparison instruction onto the computation.
+XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+// Enqueues a greater-or-equal comparison instruction onto the computation.
+XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+// Enqueues a greater-than comparison instruction onto the computation.
+XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+// Enqueues a less-than comparison instruction onto the computation.
+XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+// Enqueues a less-or-equal comparison instruction onto the computation.
+XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+// Enqueues a dot instruction onto the computation.
+XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs);
+
+// Enqueues a general dot instruction onto the computation.
+XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
+ const DotDimensionNumbers& dimension_numbers);
+
+// Enqueues a convolution instruction onto the computation, which uses the
+// default convolution dimension numbers.
+XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding);
+
+// Enqueues a convolution instruction onto the computation, with the caller
+// provided padding configuration in the format returned by MakePadding().
+XlaOp ConvWithGeneralPadding(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
+
+// Enqueues a convolution instruction onto the computation, with the caller
+// provided dimension numbers configuration.
+XlaOp ConvWithGeneralDimensions(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
+ const ConvolutionDimensionNumbers& dimension_numbers);
+
+// Enqueues a convolution instruction onto the computation, with the caller
+// provided padding configuration as well as the dimension numbers.
+XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ const ConvolutionDimensionNumbers& dimension_numbers);
+
+// Enqueues a convolution instruction onto the computation, with the caller
+// provided padding configuration, dilation factors and dimension numbers.
+XlaOp ConvGeneralDilated(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ tensorflow::gtl::ArraySlice<int64> lhs_dilation,
+ tensorflow::gtl::ArraySlice<int64> rhs_dilation,
+ const ConvolutionDimensionNumbers& dimension_numbers);
+
+// Enqueues an FFT instruction onto the computation, of the given type and
+// with the given FFT length.
+XlaOp Fft(const XlaOp& operand, FftType fft_type,
+ tensorflow::gtl::ArraySlice<int64> fft_length);
+
+// Enqueues an infeed instruction onto the computation, which writes data of
+// the given shape to the infeed buffer of the device.
+XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
+ const string& config = "");
+
+// Variant of Infeed which takes a token-shaped operand and produces a
+// two-element tuple containing the data value and a token-shaped value.
+// Tokens are used for ordering side-effecting operations.
+// TODO(b/110532604): Replace all uses of the non-token form with this variant.
+XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
+ const string& config = "");
+
+// Enqueues an outfeed instruction onto the computation. This instruction
+// generates outgoing data transfers for the given data.
+//
+// shape_with_layout communicates the laid out shape that we want to outfeed
+// -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error
+// will occur.
+void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
+ const string& outfeed_config);
+
+// Variant of Outfeed which takes a token-shaped operand and produces a
+// token-shaped value. Tokens are used for ordering side-effecting operations.
+// TODO(b/110532604): Replace all uses of the non-token form with this variant.
+XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
+ const Shape& shape_with_layout,
+ const string& outfeed_config);
+
+// Enqueues a call instruction onto the computation.
+XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<XlaOp> operands);
+
+// Enqueues a custom call instruction onto the computation.
+// During code generation, a call instruction is emitted which targets a
+// symbol with the name |call_target_name|. The |operands| are passed to the
+// call instruction. |shape| is the resultant shape.
+XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
+ tensorflow::gtl::ArraySlice<XlaOp> operands,
+ const Shape& shape);
+
+// Enqueues a pseudo-op to represent host-side computation data-dependencies.
+// During code generation, host send and receive operations will be generated
+// to transfer |operands| to the host and a single result of |shape| back to
+// the device. Host send/recv operations are emitted using |channel_name|.
+// Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO
+// instruction scheduling.
+XlaOp HostCompute(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<XlaOp> operands,
+ const string& channel_name, int64 cost_estimate_ns,
+ const Shape& shape);
+
+// The following methods enqueue element-wise binary arithmetic operations
+// onto the computation. The shapes of the operands have to match unless one
+// of the operands is a scalar, or an explicit broadcast dimension is given
+// (see g3doc for more details).
+
+// Enqueues a complex compose instruction onto the computation.
+XlaOp Complex(const XlaOp& real, const XlaOp& imag,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+// Enqueues a complex conjugate instruction onto the computation.
+XlaOp Conj(const XlaOp& operand);
+
+// Enqueues an add instruction onto the computation.
+XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+// Enqueues a subtract instruction onto the computation.
+XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+// Enqueues a multiply instruction onto the computation.
+XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+// Enqueues a divide instruction onto the computation.
+XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+// Enqueues a remainder instruction onto the computation.
+XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+// Enqueues a max instruction onto the computation.
+XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+// Enqueues a min instruction onto the computation.
+XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+// Element-wise logical operators
+XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+XlaOp Not(const XlaOp& operand);
+
+XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+XlaOp ShiftRightArithmetic(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+XlaOp ShiftRightLogical(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+// Reduces an array among the provided dimensions, given "computation" as a
+// reduction operator.
+XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
+
+// Convenience wrapper around the above that reduces all the dimensions in the
+// operand shape.
+XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation);
+
+// Enqueues a windowed reduce instruction onto the computation.
+XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ Padding padding);
+
+// As ReduceWindow(), but the padding is given in the format
+// returned by MakePadding().
+XlaOp ReduceWindowWithGeneralPadding(
+ const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
+
+// Returns the sum of the operand value within each subgroup of replicas. All
+// replicas supply one input to the sum and all replicas receive the resulting
+// sum for each subgroup.
+XlaOp CrossReplicaSum(
+ const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> replica_group_ids = {});
+
+// Enqueues an operation that do an AllReduce of the operand cross cores. Here
+// AllReduce means doing a reduction on the input operand cross cores and then
+// broadcasting the reduction result to those cores. The reduction function is
+// defined by `computation`, which should be a commutative computation on
+// scalars, e.g., add, min, or max. The way that AllReduce is applied is
+// configured by:
+//
+// - `replica_group_ids`: maps replica ids to subgroup ids. If empty, all
+// replicas belong to one group. Allreduce will be applied within subgroups.
+// For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means,
+// replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1.
+//
+// - `channel_id`: for Allreduce nodes from different models, if they have the
+// same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be
+// applied cross models.
+//
+// TODO(b/79737069): Rename this to AllReduce when it's ready to use.
+XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> replica_group_ids = {},
+ const tensorflow::gtl::optional<ChannelHandle>&
+ channel_id = tensorflow::gtl::nullopt);
+
+// Enqueues an operation that scatters the `source` array to the selected
+// indices of each window.
+XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ Padding padding, const XlaOp& source,
+ const XlaOp& init_value, const XlaComputation& scatter);
+
+// As SelectAndScatter(), but the padding is given in the format
+// returned by MakePadding().
+XlaOp SelectAndScatterWithGeneralPadding(
+ const XlaOp& operand, const XlaComputation& select,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ const XlaOp& source, const XlaOp& init_value,
+ const XlaComputation& scatter);
+
+// Enqueues an abs instruction onto the computation.
+XlaOp Abs(const XlaOp& operand);
+
+// Enqueues a atan2 instruction onto the computation.
+XlaOp Atan2(const XlaOp& y, const XlaOp& x,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+// Enqueues an exp instruction onto the computation.
+XlaOp Exp(const XlaOp& operand);
+
+// Enqueues an expm1 instruction onto the computation.
+XlaOp Expm1(const XlaOp& operand);
+
+// Enqueues a floor instruction onto the computation.
+XlaOp Floor(const XlaOp& operand);
+
+// Enqueues a ceil instruction onto the computation.
+XlaOp Ceil(const XlaOp& operand);
+
+// Enqueues a round instruction onto the computation, rounding to nearest even
+// with half-way cases rounding away from zero.
+XlaOp Round(const XlaOp& operand);
+
+// Enqueues an log instruction (natural logarithm) onto the computation.
+XlaOp Log(const XlaOp& operand);
+
+// Enqueues an log1p instruction (log(x+1)) onto the computation.
+XlaOp Log1p(const XlaOp& operand);
+
+// Enqueues a sign instruction onto the computation.
+XlaOp Sign(const XlaOp& operand);
+
+// Enqueues a count leading zeros instruction onto the computation.
+XlaOp Clz(const XlaOp& operand);
+
+// Enqueues a cosine instruction onto the computation.
+XlaOp Cos(const XlaOp& operand);
+
+// Enqueues a sine instruction onto the computation.
+XlaOp Sin(const XlaOp& operand);
+
+// Enqueues a tanh instruction onto the computation.
+XlaOp Tanh(const XlaOp& operand);
+
+// Enqueues a real-part instruction onto the computation.
+XlaOp Real(const XlaOp& operand);
+
+// Enqueues an imaginary-part instruction onto the computation.
+XlaOp Imag(const XlaOp& operand);
+
+// Enqueues a lhs^rhs computation onto the computation.
+XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+// Enqueues an operator that tests if the operand's values are finite, i.e.,
+// not Inf or NaN. Defined only for floating-point types. Returns an array of
+// booleans with the same shape where entries are true iff the corresponding
+// entry was NaN.
+XlaOp IsFinite(const XlaOp& operand);
+
+// Enqueues a convert instruction onto the computation that changes the
+// element type of the operand array to primitive_type.
+XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type);
+
+// Enqueues a no-op instruction onto the computation that changes
+// the element type of the operand array to primitive_type. The
+// bit-widths of the source and destination element types must be
+// identical.
+XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type);
+
+// Enqueues a negate instruction onto the computation.
+XlaOp Neg(const XlaOp& operand);
+
+// Enqueues a transpose instruction onto the computation.
+XlaOp Transpose(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> permutation);
+
+// Enqueues a reverse instruction onto the computation. The order of the
+// elements in the given dimensions is reversed (i.e., the element at index i
+// is moved to index dimension_size - 1 - i).
+XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions);
+
+// Enqueues a sort (as increasing order) instruction onto the computation.
+// If only keys are provided:
+// * If the keys are an rank-1 tensor (an array), the result is a sorted array
+// of keys, in ascending order.
+// * If the keys have higher rank, the keys are sorted along the provided
+// dimension. For example, for a rank-2 tensor (a matrix) of keys, a dimension
+// value of 0 will indepenently sort every column, and a dimension value of 1
+// will independently sort each row. If no dimension number is provided, then
+// the last dimension is chosen by default.
+//
+// If both keys and values are provided:
+// * The keys and the values must tensors with the same dimensions. The
+// element types of the tensors may be different.
+// * The result is a tuple that consists of a sorted tensor of keys (along the
+// provided dimension, as above) as the first element, and a tensor with their
+// corresponding values as the second element.
+XlaOp Sort(XlaOp keys,
+ tensorflow::gtl::optional<XlaOp> values = tensorflow::gtl::nullopt,
+ int64 dimension = -1);
+
+// Enqueues a clamp instruction onto the computation.
+XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
+
+// Enqueues a map instruction onto the computation.
+XlaOp Map(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> operands,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::gtl::ArraySlice<XlaOp> static_operands = {});
+
+// Enqueues a N(mu, sigma) random number generation instruction onto the
+// computation.
+XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape);
+
+// Enqueues a U(a, b) random number generation instruction onto the
+// computation. Returns values in the semi-open interval [a, b).
+XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape);
+
+// Enqueues a while node onto the computation.
+XlaOp While(const XlaComputation& condition, const XlaComputation& body,
+ const XlaOp& init);
+
+// Enqueues a conditional node onto the computation.
+XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
+ const XlaComputation& true_computation,
+ const XlaOp& false_operand,
+ const XlaComputation& false_computation);
+
+// Enqueues a ReducePrecision node onto the computation.
+XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
+ const int mantissa_bits);
+
+// Enqueues a Gather node onto the computation.
+XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
+ const GatherDimensionNumbers& dimension_numbers,
+ tensorflow::gtl::ArraySlice<int64> window_bounds);
+
+// Enqueues a Send node onto the computation for device-to-device
+// communication. This operation sends the given operand to
+// a Recv instruction in a different computation that shares the same channel
+// handle.
+void Send(const XlaOp& operand, const ChannelHandle& handle);
+
+// Variant of Send which takes a token-shaped operand and produces a
+// token-shaped value. Tokens are used for ordering side-effecting operations.
+// TODO(b/110532604): Replace all uses of the non-token form with this variant.
+XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
+ const ChannelHandle& handle);
+
+// Enqueues a Recv node onto the computation for device-to-device
+// communication. The data comes from a Send instruction in a different
+// computation that shares the same channel handle and its shape must be the
+// same as the given shape.
+XlaOp Recv(XlaBuilder* builder, const Shape& shape,
+ const ChannelHandle& handle);
+
+// Variant of Recv which takes a token-shaped operand and produces a two-element
+// tuple containing the data value and a token-shaped value. Tokens are used
+// for ordering side-effecting operations.
+// TODO(b/110532604): Replace all uses of the non-token form with this variant.
+XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
+ const ChannelHandle& handle);
+
+// Enqueues a Send node which transfers data from the device to the host. The
+// 'shape_with_layout' argument defines the layout of the data transferred; its
+// shape must be compatible with the shape of the operand. The operand must be
+// array-shaped.
+// TODO(b/111544877): Support tuple shapes.
+XlaOp SendToHost(const XlaOp& operand, const XlaOp& token,
+ const Shape& shape_with_layout, const ChannelHandle& handle);
+
+// Enqueues a Recv node which transfers data from the host to the device. The
+// given shape must contain a layout and must be an array.
+// TODO(b/111544877): Support tuple shapes.
+XlaOp RecvFromHost(const XlaOp& token, const Shape& shape,
+ const ChannelHandle& handle);
+
+// Enqueues an operation (AfterAll) with no operands that produces a
+// token-shaped value. Tokens are used for ordering side-effecting operations.
+// This is a separate method from AfterAll to facility the removal of
+// operand-less AfterAll instructions.
+// TODO(b/110532604): Remove this function when all tokens are derived from a
+// single token generated or passed into the entry computation.
+XlaOp CreateToken(XlaBuilder* builder);
+
+// Enqueues an AfterAll instruction which produces a token-shaped value and
+// takes a variadic number of token-shaped operands. The number of operands must
+// be greater than zero. Used for joining tokens.
+XlaOp AfterAll(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> tokens);
+
+// Normalizes operand across spatial and batch dimensions for each feature.
+//
+// Returns a tuple (normalized, batch_mean, batch_var) where `normalized`
+// is the normalized result and batch_mean and batch_var are the mean and
+// variance, respectively, across batch for the operand.
+XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
+ const XlaOp& offset, float epsilon,
+ int64 feature_index);
+
+// Normalizes operand across spatial and batch dimensions for each feature.
+//
+// `BatchNormInference` is equivalent to calling `BatchNormTraining` without
+// computing `mean` and `variance` for each batch inside the operation. It
+// uses the input `mean` and `variance` instead as estimated values. The
+// purpose of this op is to reduce latency in inference, hence the name
+// `BatchNormInference`.
+//
+// The output has the same shape as `operand`, and contains the normalized
+// values for each batch.
+XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale,
+ const XlaOp& offset, const XlaOp& mean,
+ const XlaOp& variance, float epsilon,
+ int64 feature_index);
+
+// Calculates the gradients of a batch norm op.
+//
+// The inputs `batch_mean` and `batch_var` represent the mean and variance
+// across the batch.
+//
+// Returns a tuple of three elements:
+// - grad_operand: Gradient with respect to input `operand`
+// - grad_offset: Gradient with respect to input `offset`
+// - grad_scale: Gradient with respect to input `scale`
+XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
+ const XlaOp& batch_mean, const XlaOp& batch_var,
+ const XlaOp& grad_output, float epsilon,
+ int64 feature_index);
+
+// Implementation details below this point.
+
+template <typename NativeT>
+XlaOp XlaBuilder::ConstantR0(NativeT value) {
+ return ConstantLiteral(*LiteralUtil::CreateR0<NativeT>(value));
+}
+
+template <typename NativeT>
+XlaOp XlaBuilder::ConstantR1(tensorflow::gtl::ArraySlice<NativeT> values) {
+ return ConstantLiteral(*LiteralUtil::CreateR1<NativeT>(values));
+}
+
+template <typename NativeT>
+XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) {
+ Literal literal(ShapeUtil::MakeShape(
+ primitive_util::NativeToPrimitiveType<NativeT>(), {length}));
+ literal.PopulateWithValue(value);
+ return ConstantLiteral(literal);
+}
+
+inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) {
+ return ConstantLiteral(*LiteralUtil::CreateR1(values));
+}
+
+template <typename NativeT>
+XlaOp XlaBuilder::ConstantR2(
+ std::initializer_list<std::initializer_list<NativeT>> values) {
+ return ConstantLiteral(*LiteralUtil::CreateR2<NativeT>(values));
+}
+
+template <typename NativeT>
+XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array<NativeT>& values,
+ const Layout& layout) {
+ return ConstantLiteral(
+ *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
+}
+
+template <typename NativeT>
+XlaOp XlaBuilder::ConstantFromArray(const Array<NativeT>& values) {
+ return ConstantLiteral(*LiteralUtil::CreateFromArray<NativeT>(values));
+}
+
+template <typename NativeT>
+XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout(
+ const Array2D<NativeT>& values, const Layout& layout) {
+ return ConstantLiteral(
+ *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
+}
+
+template <typename NativeT>
+XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D<NativeT>& values) {
+ return ConstantLiteral(*LiteralUtil::CreateR2FromArray2D<NativeT>(values));
+}
+
+template <typename NativeT>
+XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout(
+ const Array3D<NativeT>& values, const Layout& layout) {
+ return ConstantLiteral(
+ *LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
+}
+
+template <typename NativeT>
+XlaOp XlaBuilder::ConstantR3FromArray3D(const Array3D<NativeT>& values) {
+ return ConstantFromArray(values);
+}
+
+template <typename NativeT>
+XlaOp XlaBuilder::ConstantR4FromArray4DWithLayout(
+ const Array4D<NativeT>& values, const Layout& layout) {
+ return ConstantFromArrayWithLayout(values, layout);
+}
+
+template <typename NativeT>
+XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D<NativeT>& values) {
+ return ConstantFromArray(values);
+}
+
+// Free function template implementations.
+
+template <typename NativeT>
+XlaOp ConstantR0(XlaBuilder* builder, NativeT value) {
+ return ConstantLiteral(builder, *LiteralUtil::CreateR0<NativeT>(value));
+}
+
+template <typename NativeT>
+XlaOp ConstantR1(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<NativeT> values) {
+ return ConstantLiteral(builder, *LiteralUtil::CreateR1<NativeT>(values));
+}
+
+template <typename NativeT>
+XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) {
+ Literal literal(ShapeUtil::MakeShape(
+ primitive_util::NativeToPrimitiveType<NativeT>(), {length}));
+ literal.PopulateWithValue(value);
+ return ConstantLiteral(builder, literal);
+}
+
+inline XlaOp ConstantR1(XlaBuilder* builder,
+ const tensorflow::core::Bitmap& values) {
+ return ConstantLiteral(builder, *LiteralUtil::CreateR1(values));
+}
+
+template <typename NativeT>
+XlaOp ConstantR2(XlaBuilder* builder,
+ std::initializer_list<std::initializer_list<NativeT>> values) {
+ return ConstantLiteral(builder, *LiteralUtil::CreateR2<NativeT>(values));
+}
+
+template <typename NativeT>
+XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
+ const Array<NativeT>& values,
+ const Layout& layout) {
+ return ConstantLiteral(
+ builder,
+ *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
+}
+
+template <typename NativeT>
+XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values) {
+ return ConstantLiteral(builder,
+ *LiteralUtil::CreateFromArray<NativeT>(values));
+}
+
+template <typename NativeT>
+XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
+ const Array2D<NativeT>& values,
+ const Layout& layout) {
+ return ConstantLiteral(
+ builder,
+ *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
+}
+
+template <typename NativeT>
+XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
+ const Array2D<NativeT>& values) {
+ return ConstantLiteral(builder,
+ *LiteralUtil::CreateR2FromArray2D<NativeT>(values));
+}
+
+template <typename NativeT>
+XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
+ const Array3D<NativeT>& values,
+ const Layout& layout) {
+ return ConstantLiteral(
+ builder,
+ *LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
+}
+
+template <typename NativeT>
+XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
+ const Array3D<NativeT>& values) {
+ return ConstantFromArray(builder, values);
+}
+
+template <typename NativeT>
+XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
+ const Array4D<NativeT>& values,
+ const Layout& layout) {
+ return ConstantFromArrayWithLayout(builder, values, layout);
+}
+
+template <typename NativeT>
+XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
+ const Array4D<NativeT>& values) {
+ return ConstantFromArray(builder, values);
+}
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc
index b4a5aedfb1..28a207b137 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc
+++ b/tensorflow/compiler/xla/client/xla_builder_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include <string>
diff --git a/tensorflow/compiler/xla/client/xla_client/BUILD b/tensorflow/compiler/xla/client/xla_client/BUILD
index a7168e731b..2e131dbad2 100644
--- a/tensorflow/compiler/xla/client/xla_client/BUILD
+++ b/tensorflow/compiler/xla/client/xla_client/BUILD
@@ -25,44 +25,9 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test")
cc_library(
name = "xla_builder",
- srcs = ["xla_builder.cc"],
hdrs = ["xla_builder.h"],
visibility = ["//visibility:public"],
deps = [
- "//tensorflow/compiler/xla:execution_options_util",
- "//tensorflow/compiler/xla:literal",
- "//tensorflow/compiler/xla:literal_util",
- "//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla:status_macros",
- "//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:padding",
- "//tensorflow/compiler/xla/client:sharding_builder",
- "//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/service:hlo",
- "//tensorflow/compiler/xla/service:hlo_proto",
- "//tensorflow/compiler/xla/service:shape_inference",
- "//tensorflow/core:lib",
- ],
-)
-
-tf_cc_test(
- name = "xla_builder_test",
- srcs = ["xla_builder_test.cc"],
- deps = [
- ":xla_builder",
- "//tensorflow/compiler/xla:literal",
- "//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla:status_macros",
- "//tensorflow/compiler/xla:test",
- "//tensorflow/compiler/xla:test_helpers",
- "//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
- "//tensorflow/compiler/xla/service:hlo",
- "//tensorflow/compiler/xla/service:hlo_matchers",
- "//tensorflow/core:test",
+ "//tensorflow/compiler/xla/client:xla_builder",
],
)
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
index 980e84e40c..ce2a8afd4c 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
@@ -16,2226 +16,6 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_
-#include <map>
-#include <string>
-#include <type_traits>
-#include <utility>
-
-#include "tensorflow/compiler/xla/client/padding.h"
-#include "tensorflow/compiler/xla/client/xla_computation.h"
-#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/service/hlo.pb.h"
-#include "tensorflow/compiler/xla/service/hlo_opcode.h"
-#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/stacktrace.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace xla {
-
-class XlaBuilder;
-
-// This represents an instruction that has been enqueued using the XlaBuilder.
-// This is used to pass to subsequent computations that depends upon the
-// instruction as an operand.
-class XlaOp {
- public:
- XlaOp() : handle_(-1), builder_(nullptr) {
- static_assert(std::is_trivially_destructible<XlaOp>::value,
- "XlaOp should be trivially destructible");
- }
- ~XlaOp() = default;
-
- // Precondition: !IsUninitialized().
- //
- // It's very common to do foo.builder()->bar(). Without this precondition, if
- // foo.builder() is null, the call to bar will segfault at some point possibly
- // deep in the callstack when we finally dereference `this`. The precondition
- // lets us avoid this tricky-to-debug problem.
- XlaBuilder* builder() const {
- CHECK(builder_ != nullptr);
- return builder_;
- }
-
- // Returns true if the XlaOp represents valid, non-erroneous value.
- bool valid() const { return handle_ >= 0; }
-
- // Returns true if the XlaOp was created by the XlaOp() constructor and
- // not returned by a builder.
- bool IsUninitialized() const { return builder_ == nullptr; }
-
- bool IsIdenticalTo(const XlaOp& rhs) const {
- return handle_ == rhs.handle_ && builder_ == rhs.builder_;
- }
-
- friend std::ostream& operator<<(std::ostream& out, const XlaOp& op) {
- out << op.handle();
- return out;
- }
-
- private:
- explicit XlaOp(XlaBuilder* builder) : handle_(-1), builder_(builder) {}
- XlaOp(int64 handle, XlaBuilder* builder)
- : handle_(handle), builder_(builder) {}
-
- int64 handle() const { return handle_; }
-
- friend class XlaBuilder;
-
- // < 0 means "invalid handle".
- int64 handle_;
-
- // Not owned. Non-null for any handle returned by XlaBuilder, even if the
- // handle is invalid.
- XlaBuilder* builder_;
-};
-
-// Arithmetic operator overloads for the XlaOp type.
-XlaOp operator-(const XlaOp& x);
-XlaOp operator+(const XlaOp& x, const XlaOp& y);
-XlaOp operator-(const XlaOp& x, const XlaOp& y);
-XlaOp operator*(const XlaOp& x, const XlaOp& y);
-XlaOp operator/(const XlaOp& x, const XlaOp& y);
-XlaOp operator%(const XlaOp& x, const XlaOp& y);
-
-// Bitwise operator overloads for the XlaOp type.
-XlaOp operator~(const XlaOp& x);
-XlaOp operator&(const XlaOp& x, const XlaOp& y);
-XlaOp operator|(const XlaOp& x, const XlaOp& y);
-XlaOp operator^(const XlaOp& x, const XlaOp& y);
-XlaOp operator<<(const XlaOp& x, const XlaOp& y);
-// Performs a right arithmetic shift if 'x' is a signed type, otherwise performs
-// a right logical shift.
-XlaOp operator>>(const XlaOp& x, const XlaOp& y);
-
-// We don't overload the relational operators (==, !=, <, <=, >, >=) because the
-// semantics might be surprising since their result types are usually 'bool'.
-// Further programmers may expect == to be a structural equality.
-// We also choose not to overload any of the mutating operators (e.g., +=, -=)
-// because the semantics might be misleading — XLA computations are immutable.
-
-// A convenient interface for building up computations.
-//
-// Thread-compatible.
-class XlaBuilder {
- public:
- // computation_name: name to use for the built computation.
- XlaBuilder(const string& computation_name);
-
- XlaBuilder(const XlaBuilder&) = delete;
- XlaBuilder& operator=(const XlaBuilder&) = delete;
-
- ~XlaBuilder();
-
- // Returns the computation name.
- const string& name() const { return name_; }
-
- // Sets OpMetadata that will be added to all instructions until cleared.
- //
- // OpMetadata is often applied to a series of XLA HLO instructions. As a
- // result, OpMetadata is set on the Computation Builder. All subsequent
- // instructions generated via this Computation Builder will have the same
- // OpMetadata attached until a call to ClearOpMetadata.
- void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; }
-
- // Clears the HloMetadata state.
- void ClearOpMetadata() { metadata_.Clear(); }
-
- // Sets an OpSharding that will be attached to all instructions until cleared.
- void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }
-
- // Clears the sharding. Ops will be sharded according to the default placement
- // policy.
- void ClearSharding() { sharding_ = tensorflow::gtl::nullopt; }
-
- // Returns the OpSharding that will be attached to all instructions.
- const tensorflow::gtl::optional<OpSharding>& sharding() const {
- return sharding_;
- }
-
- // Sets the builder to a mode where it will die immediately when an error is
- // encountered, rather than producing it in a deferred fashion when Build() is
- // called (which is the default).
- void set_die_immediately_on_error(bool enabled) {
- die_immediately_on_error_ = enabled;
- }
-
- // Default dimension numbers used for a 2D convolution.
- static constexpr int64 kConvBatchDimension = 0;
- static constexpr int64 kConvFeatureDimension = 1;
- static constexpr int64 kConvFirstSpatialDimension = 2;
- static constexpr int64 kConvSecondSpatialDimension = 3;
- static constexpr int64 kConvKernelOutputDimension = 0;
- static constexpr int64 kConvKernelInputDimension = 1;
- static constexpr int64 kConvKernelFirstSpatialDimension = 2;
- static constexpr int64 kConvKernelSecondSpatialDimension = 3;
-
- // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for
- // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for
- // the kernel operand
- // {output_feature, input_feature, height, width} = {0, 1, 2, 3}.
- static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers(
- int num_spatial_dims = 2);
-
- // Returns an error if the convolution dimension numbers have conflicts.
- static Status Validate(const ConvolutionDimensionNumbers& dnum);
-
- // Returns a new XlaBuilder whose resultant Computation is used only by this
- // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error
- // behavior as the parent.
- std::unique_ptr<XlaBuilder> CreateSubBuilder(const string& computation_name);
-
- // Builds the computation with the requested operations, or returns a non-ok
- // status. Note that all ops that have been enqueued will be moved to the
- // computation being returned.
- StatusOr<XlaComputation> Build();
-
- // Builds the computation with the requested operations, or notes an error in
- // the parent XlaBuilder and returns an empty computation if building failed.
- // This function is intended to be used where the returned XlaComputation is
- // only used by the parent XlaBuilder and hence further operation on the
- // returned XlaComputation will simply be error'ed out if an error occurred
- // while building this computation. If the built computation is to be used by
- // a XlaBuilder other than the parent XlaBuilder then Build() should be used
- // instead.
- XlaComputation BuildAndNoteError();
-
- // Returns a subgraph that roots on the given root. If the root is not a
- // compile-time constant (see `IsConstant`), returns an error.
- //
- // This will copy the needed ops/computations to the subgraph.
- StatusOr<XlaComputation> BuildConstantSubGraph(const XlaOp& root_op) const;
-
- // Returns the first error that was encountered while building the
- // computation. When an error is encountered, by default we return a vacuous
- // XlaOp and inform the user of the error that occurred while
- // building the computation when they make a final call to Build().
- //
- // See also set_die_immediately_on_error().
- Status first_error() const { return first_error_; }
-
- // Returns the shape of the given op.
- StatusOr<Shape> GetShape(const XlaOp& op) const;
-
- // Returns the (inferred) result for the current computation's shape.
- StatusOr<ProgramShape> GetProgramShape() const;
-
- // Reports an error to the builder, by
- // * storing it internally and capturing a backtrace if it's the first error
- // (this deferred value will be produced on the call to
- // Build()/GetShape()/...)
- // * dying if die_immediately_on_error_ is true.
- // Returns an XlaOp with an invalid handle but a valid builder. This value can
- // be returned in place of a value in APIs that return an XlaOp.
- XlaOp ReportError(const Status& error);
-
- // A helper function that converts a StatusOr<XlaOp> into an XlaOp.
- // If the Status was an error, reports the error to builder and returns an
- // invalid XlaOp handle.
- XlaOp ReportErrorOrReturn(const StatusOr<XlaOp>& op);
-
- // A helper function that runs a function that returns a StatusOr<XlaOp> and
- // returns an XlaOp.
- XlaOp ReportErrorOrReturn(const std::function<StatusOr<XlaOp>()>& op_creator);
-
- // Returns true if 'operand' is a compile-time constant. A compile-time
- // constant does not depend on any parameters, or on stateful operators such
- // as `RngNormal` or `Infeed`.
- //
- // This tests whether a computation is a compile-time constant without
- // evaluating the computation.
- StatusOr<bool> IsConstant(const XlaOp& operand) const;
-
- private:
- // Enqueues a "retrieve parameter value" instruction for a parameter that was
- // passed to the computation.
- XlaOp Parameter(int64 parameter_number, const Shape& shape,
- const string& name);
-
- // Enqueues a constant with the value of the given literal onto the
- // computation.
- XlaOp ConstantLiteral(const LiteralSlice& literal);
-
- // Enqueues a constant onto the computation. Methods are templated on the
- // native host type (NativeT) which corresponds to a specific XLA
- // PrimitiveType as given in the following table:
- //
- // Native Type PrimitiveType
- // -----------------------------
- // bool PRED
- // int32 S32
- // int64 S64
- // uint32 U32
- // uint64 U64
- // float F32
- // double F64
- //
- // Note: not all primitive types defined in xla_data.proto have a
- // corresponding native type yet.
- template <typename NativeT>
- XlaOp ConstantR0(NativeT value);
- template <typename NativeT>
- XlaOp ConstantR1(tensorflow::gtl::ArraySlice<NativeT> values);
- XlaOp ConstantR1(const tensorflow::core::Bitmap& values);
- template <typename NativeT>
- XlaOp ConstantR2(
- std::initializer_list<std::initializer_list<NativeT>> values);
- template <typename NativeT>
- XlaOp ConstantFromArrayWithLayout(const Array<NativeT>& values,
- const Layout& layout);
- template <typename NativeT>
- XlaOp ConstantFromArray(const Array<NativeT>& values);
- template <typename NativeT>
- XlaOp ConstantR2FromArray2DWithLayout(const Array2D<NativeT>& values,
- const Layout& layout);
- template <typename NativeT>
- XlaOp ConstantR2FromArray2D(const Array2D<NativeT>& values);
- template <typename NativeT>
- XlaOp ConstantR3FromArray3DWithLayout(const Array3D<NativeT>& values,
- const Layout& layout);
- template <typename NativeT>
- XlaOp ConstantR3FromArray3D(const Array3D<NativeT>& values);
- template <typename NativeT>
- XlaOp ConstantR4FromArray4DWithLayout(const Array4D<NativeT>& values,
- const Layout& layout);
- template <typename NativeT>
- XlaOp ConstantR4FromArray4D(const Array4D<NativeT>& values);
-
- // Enqueues a rank one constant (vector) onto the computation. The vector has
- // size 'length' and every element has the value 'value'.
- template <typename NativeT>
- XlaOp ConstantR1(int64 length, NativeT value);
-
- // Adds dimensions to an array by duplicating the data in the array.
- //
- // The new dimensions are inserted on the left, i.e. if
- // broadcast_sizes has values {a0, ..., aN} and the operand shape
- // has dimensions {b0, ..., bM} then the shape of the output has
- // dimensions {a0, ..., aN, b0, ..., bM}.
- //
- // The new dimensions index into copies of the operand, i.e.
- //
- // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
- XlaOp Broadcast(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
-
- // Performs in-dimension-style broadcast.
- //
- // Operand specifies the input to be broadcast. "shape" is expected output
- // shape. "broadcast_dimensions" are the dimensions to be broadcasting into.
- // Dimension numbers in broadcast_dimensions map to individual dimensions
- // of the operand, and specify what dimension of the output shape they
- // should be broadcast.
- // e.g.
- // Say operand = [1, 2], i.e., a 1D tensor with 2 elements.
- // and dimension of shape is [2,2].
- // Specifying {1} as brodcast_dimension will generate output
- // [1 , 2]
- // [1 , 2]
- // On the other hand, specifying {0} as broadcast_dimension
- // will generate output
- // [1 , 1]
- // [2 , 2]
- XlaOp BroadcastInDim(
- const XlaOp& operand, const Shape& shape,
- const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
-
- // Enqueues a pad operation onto the computation that pads the given value on
- // the edges as well as between the elements of the input. padding_config
- // specifies the padding amount for each dimension.
- XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
- const PaddingConfig& padding_config);
-
- // Enqueues an operation onto the computation that flattens the operand based
- // on the dimension order (major/slowest-varying to minor/fastest-varying)
- // given, followed by reshaping it into the shape with the given dimension
- // sizes (also major to minor). Conceptually, this is a limited form of
- // "shape casting".
- XlaOp Reshape(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> new_sizes);
-
- // Enqueues an operation onto the computation that collapses the operand, from
- // first to last dimension (C order), then reshapes it to the given dimension
- // sizes. Conceptually, this is a limited form of "shape casting".
- XlaOp Reshape(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> new_sizes);
-
- // Wrapper for Reshape.
- // Enqueues an operation to collapse the provided dimensions; e.g. an
- // operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to
- // {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must
- // be a consecutive, in-order subsequence of the operand dimensions.
- //
- // Note that collapsing a single dimension does nothing:
- //
- // {256} collapsing {0} => {256}
- // {1} collapsing {0} => {1}
- //
- // Collapsing multiple dimensions produces a single result dimension:
- //
- // {256, 2} collapsing {0,1} => {512}
- // {256, 2, 3} collapsing {0,1} => {512, 3}
- //
- // This could potentially cause data to be moved -- it provides a more
- // structured form of reshaping than an arbitrary Reshape operation.
- XlaOp Collapse(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
-
- // Enqueues a slice operation onto the computation that slices the operand
- // from the start indices to the limit indices; e.g.
- //
- // x
- // [ 0 1 2 3 ]
- // y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ]
- // [ 8 9 a b ]
- //
- // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D
- // range notation.
- // The strides parameter determines the stride over the slice
- XlaOp Slice(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides);
-
- // Enqueues a slice operation in a given dimension, taking all other
- // dimensions as they are; e.g. if dimno is 1 from start_index 2 to
- // limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand
- // for:
- //
- // array[:, 2:4:1, :]
- XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index,
- int64 stride, int64 dimno);
-
- // Enqueues a slice operation onto the computation that slices the 'operand'
- // from dynamic start indices which are passed in 'start_indices'.
- // The size of the slice in each dimension is passed in 'slice_sizes',
- // which specify the end point of exclusive slice intervals in each
- // dimension [start, start + size).
- // The shape of 'start_indices' must be rank == 1, with dimension size
- // equal to the rank of the 'operand'.
- // Slice index calculations are computed modulo input dimension sizes to
- // prevent dynamic start indices from generating out-of-bound array accesses.
- XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
-
- // Enqueues a dynamic update slice operation onto the computation, which
- // updates a slice of 'operand' with 'update' at dynamic 'start_indices'.
- // The shape of 'update' determines the shape of the slice of 'operand'
- // which is updated.
- // The indices specified in 'start_indices' specify the offset of the slice
- // of 'operand' which is updated.
- //
- // update = {10, 11} // calculated at runtime.
- // [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ]
- // [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11]
- // [7 8 9] [7 8 9 ]
- //
- // The shape of 'start_indices' must be rank == 1, with dimension size
- // equal to the rank of the 'operand'.
- // Slice index calculations are computed modulo update dimension sizes to
- // prevent dynamic start indices from generating out-of-bound array accesses.
- XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
- const XlaOp& start_indices);
-
- // Enqueues a concatenate instruction onto the computation. 'operands' must
- // have >= 1 entry.
- XlaOp ConcatInDim(tensorflow::gtl::ArraySlice<XlaOp> operands,
- int64 dimension);
-
- // Enqueue a tracing operation onto the computation; the computation will emit
- // a logging message with the operand.
- void Trace(const string& tag, const XlaOp& operand);
-
- // Enqueues a conditional-move-like select operation onto the computation;
- // predicated on pred, selects between on_true and on_false.
- XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false);
-
- // Enqueues a tuple-creation instruction onto the computation.
- XlaOp Tuple(tensorflow::gtl::ArraySlice<XlaOp> elements);
-
- // Enqueues a tuple-element-get instruction onto the computation.
- XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);
-
- // Enqueues an equal-to comparison instruction onto the computation.
- XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a not-equal comparison instruction onto the computation.
- XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a greater-or-equal comparison instruction onto the computation.
- XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a greater-than comparison instruction onto the computation.
- XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a less-than comparison instruction onto the computation.
- XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a less-or-equal comparison instruction onto the computation.
- XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a dot instruction onto the computation.
- XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs);
-
- // Enqueues a general dot instruction onto the computation.
- XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
- const DotDimensionNumbers& dimension_numbers);
-
- // Enqueues a convolution instruction onto the computation, which uses the
- // default convolution dimension numbers.
- XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- Padding padding);
-
- // Enqueues a convolution instruction onto the computation, with the caller
- // provided padding configuration in the format returned by MakePadding().
- XlaOp ConvWithGeneralPadding(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
-
- // Enqueues a convolution instruction onto the computation, with the caller
- // provided dimension numbers configuration.
- XlaOp ConvWithGeneralDimensions(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- const ConvolutionDimensionNumbers& dimension_numbers);
-
- // Enqueues a convolution instruction onto the computation, with the caller
- // provided padding configuration as well as the dimension numbers.
- XlaOp ConvGeneral(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const ConvolutionDimensionNumbers& dimension_numbers);
-
- // Enqueues a convolution instruction onto the computation, with the caller
- // provided padding configuration, dilation factors and dimension numbers.
- XlaOp ConvGeneralDilated(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- tensorflow::gtl::ArraySlice<int64> lhs_dilation,
- tensorflow::gtl::ArraySlice<int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers);
-
- // Enqueues an FFT instruction onto the computation, of the given type and
- // with the given FFT length.
- XlaOp Fft(const XlaOp& operand, FftType fft_type,
- tensorflow::gtl::ArraySlice<int64> fft_length);
-
- // Enqueues an infeed instruction onto the computation, which writes data of
- // the given shape to the infeed buffer of the device.
- XlaOp Infeed(const Shape& shape, const string& config = "");
- XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
- const string& config = "");
-
- // Enqueues an outfeed instruction onto the computation. This instruction
- // generates outgoing data transfers for the given data.
- //
- // shape_with_layout communicates the laid out shape that we want to outfeed
- // -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error
- // will occur.
- void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
- const string& outfeed_config);
- XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
- const Shape& shape_with_layout,
- const string& outfeed_config);
-
- // Enqueues a call instruction onto the computation.
- XlaOp Call(const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<XlaOp> operands);
-
- // Enqueues a custom call instruction onto the computation.
- // During code generation, a call instruction is emitted which targets a
- // symbol with the name |call_target_name|. The |operands| are passed to the
- // call instruction. |shape| is the resultant shape.
- XlaOp CustomCall(const string& call_target_name,
- tensorflow::gtl::ArraySlice<XlaOp> operands,
- const Shape& shape);
-
- // Enqueues a pseudo-op to represent host-side computation data-dependencies.
- // During code generation, host send and receive operations will be generated
- // to transfer |operands| to the host and a single result of |shape| back to
- // the device. Host send/recv operations are emitted using |channel_name|.
- // Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO
- // instruction scheduling.
- XlaOp HostCompute(tensorflow::gtl::ArraySlice<XlaOp> operands,
- const string& channel_name, int64 cost_estimate_ns,
- const Shape& shape);
-
- // The following methods enqueue element-wise binary arithmetic operations
- // onto the computation. The shapes of the operands have to match unless one
- // of the operands is a scalar, or an explicit broadcast dimension is given
- // (see g3doc for more details).
-
- // Enqueues a complex compose instruction onto the computation.
- XlaOp Complex(const XlaOp& real, const XlaOp& imag,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a complex conjugate instruction onto the computation.
- XlaOp Conj(const XlaOp& operand);
-
- // Enqueues an add instruction onto the computation.
- XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a subtract instruction onto the computation.
- XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a multiply instruction onto the computation.
- XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a divide instruction onto the computation.
- XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a remainder instruction onto the computation.
- XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a max instruction onto the computation.
- XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a min instruction onto the computation.
- XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Element-wise logical operators
- XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- XlaOp Not(const XlaOp& operand);
-
- XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
- XlaOp ShiftRightArithmetic(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
- XlaOp ShiftRightLogical(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Reduces an array among the provided dimensions, given "computation" as a
- // reduction operator.
- XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
- const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
-
- // Convenience wrapper around the above that reduces all the dimensions in the
- // operand shape.
- XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
- const XlaComputation& computation);
-
- // Enqueues a windowed reduce instruction onto the computation.
- XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
- const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- Padding padding);
-
- // As ReduceWindow(), but the padding is given in the format
- // returned by MakePadding().
- XlaOp ReduceWindowWithGeneralPadding(
- const XlaOp& operand, const XlaOp& init_value,
- const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
-
- // Returns the sum of the operand value within each subgroup of replicas. All
- // replicas supply one input to the sum and all replicas receive the resulting
- // sum for each subgroup.
- XlaOp CrossReplicaSum(
- const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> replica_group_ids = {});
-
- // Enqueues an operation that do an AllReduce of the operand cross cores. Here
- // AllReduce means doing a reduction on the input operand cross cores and then
- // broadcasting the reduction result to those cores. The reduction function is
- // defined by `computation`, which should be a commutative computation on
- // scalars, e.g., add, min, or max. The way that AllReduce is applied is
- // configured by:
- //
- // - `replica_group_ids`: maps replica ids to subgroup ids. If empty, all
- // replicas belong to one group. Allreduce will be applied within subgroups.
- // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means,
- // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1.
- //
- // - `channel_id`: for Allreduce nodes from different models, if they have the
- // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be
- // applied cross models.
- //
- // TODO(b/79737069): Rename this to AllReduce when it's ready to use.
- XlaOp CrossReplicaSum(
- const XlaOp& operand, const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> replica_group_ids = {},
- const tensorflow::gtl::optional<ChannelHandle>& channel_id =
- tensorflow::gtl::nullopt);
-
- // Enqueues an operation that scatters the `source` array to the selected
- // indices of each window.
- XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- Padding padding, const XlaOp& source,
- const XlaOp& init_value,
- const XlaComputation& scatter);
-
- // As SelectAndScatter(), but the padding is given in the format
- // returned by MakePadding().
- XlaOp SelectAndScatterWithGeneralPadding(
- const XlaOp& operand, const XlaComputation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const XlaOp& source, const XlaOp& init_value,
- const XlaComputation& scatter);
-
- // Enqueues an abs instruction onto the computation.
- XlaOp Abs(const XlaOp& operand);
-
- // Enqueues a atan2 instruction onto the computation.
- XlaOp Atan2(const XlaOp& y, const XlaOp& x,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues an exp instruction onto the computation.
- XlaOp Exp(const XlaOp& operand);
-
- // Enqueues an expm1 instruction onto the computation.
- XlaOp Expm1(const XlaOp& operand);
-
- // Enqueues a floor instruction onto the computation.
- XlaOp Floor(const XlaOp& operand);
-
- // Enqueues a ceil instruction onto the computation.
- XlaOp Ceil(const XlaOp& operand);
-
- // Enqueues a round instruction onto the computation, rounding to nearest even
- // with half-way cases rounding away from zero.
- XlaOp Round(const XlaOp& operand);
-
- // Enqueues an log instruction (natural logarithm) onto the computation.
- XlaOp Log(const XlaOp& operand);
-
- // Enqueues an log1p instruction (log(x+1)) onto the computation.
- XlaOp Log1p(const XlaOp& operand);
-
- // Enqueues a sign instruction onto the computation.
- XlaOp Sign(const XlaOp& operand);
-
- // Enqueues a count leading zeros instruction onto the computation.
- XlaOp Clz(const XlaOp& operand);
-
- // Enqueues a cosine instruction onto the computation.
- XlaOp Cos(const XlaOp& operand);
-
- // Enqueues a sine instruction onto the computation.
- XlaOp Sin(const XlaOp& operand);
-
- // Enqueues a tanh instruction onto the computation.
- XlaOp Tanh(const XlaOp& operand);
-
- // Enqueues a real-part instruction onto the computation.
- XlaOp Real(const XlaOp& operand);
-
- // Enqueues an imaginary-part instruction onto the computation.
- XlaOp Imag(const XlaOp& operand);
-
- // Enqueues a lhs^rhs computation onto the computation.
- XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues an operator that tests if the operand's values are finite, i.e.,
- // not Inf or NaN. Defined only for floating-point types. Returns an array of
- // booleans with the same shape where entries are true iff the corresponding
- // entry was NaN.
- XlaOp IsFinite(const XlaOp& operand);
-
- // Enqueues a convert instruction onto the computation that changes the
- // element type of the operand array to primitive_type.
- XlaOp ConvertElementType(const XlaOp& operand,
- PrimitiveType new_element_type);
-
- // Enqueues a no-op instruction onto the computation that changes
- // the element type of the operand array to primitive_type. The
- // bit-widths of the source and destination element types must be
- // identical.
- XlaOp BitcastConvertType(const XlaOp& operand,
- PrimitiveType new_element_type);
-
- // Enqueues a negate instruction onto the computation.
- XlaOp Neg(const XlaOp& operand);
-
- // Enqueues a transpose instruction onto the computation.
- XlaOp Transpose(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> permutation);
-
- // Enqueues a reverse instruction onto the computation. The order of the
- // elements in the given dimensions is reversed (i.e., the element at index i
- // is moved to index dimension_size - 1 - i).
- XlaOp Rev(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
-
- // Enqueues a sort (as increasing order) instruction onto the computation.
- // If only keys are provided:
- // * If the keys are an rank-1 tensor (an array), the result is a sorted array
- // of keys, in ascending order.
- // * If the keys have higher rank, the keys are sorted along the provided
- // dimension. For example, for a rank-2 tensor (a matrix) of keys, a dimension
- // value of 0 will indepenently sort every column, and a dimension value of 1
- // will independently sort each row. If no dimension number is provided, then
- // the last dimension is chosen by default.
- //
- // If both keys and values are provided:
- // * The keys and the values must tensors with the same dimensions. The
- // element types of the tensors may be different.
- // * The result is a tuple that consists of a sorted tensor of keys (along the
- // provided dimension, as above) as the first element, and a tensor with their
- // corresponding values as the second element.
- XlaOp Sort(XlaOp keys,
- tensorflow::gtl::optional<XlaOp> values = tensorflow::gtl::nullopt,
- int64 dimension = -1);
-
- // Enqueues a clamp instruction onto the computation.
- XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
-
- // Enqueues a map instruction onto the computation.
- XlaOp Map(tensorflow::gtl::ArraySlice<XlaOp> operands,
- const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<XlaOp> static_operands = {});
-
- // Enqueues a N(mu, sigma) random number generation instruction onto the
- // computation.
- XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape);
-
- // Enqueues a U(a, b) random number generation instruction onto the
- // computation. Returns values in the semi-open interval [a, b).
- XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape);
-
- // Enqueues a while node onto the computation.
- XlaOp While(const XlaComputation& condition, const XlaComputation& body,
- const XlaOp& init);
-
- // Enqueues a conditional node onto the computation.
- XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
- const XlaComputation& true_computation,
- const XlaOp& false_operand,
- const XlaComputation& false_computation);
-
- // Enqueues a ReducePrecision node onto the computation.
- XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
- const int mantissa_bits);
-
- // Enqueues a Gather node onto the computation.
- XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
- const GatherDimensionNumbers& dimension_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds);
-
- // Enqueues a Send node onto the computation for device-to-device
- // communication, to send the given operand to a Recv instruction that shares
- // the same channel handle.
- void Send(const XlaOp& operand, const ChannelHandle& handle);
- XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
- const ChannelHandle& handle);
-
- // Enqueues a Send node which sends data to the host.
- XlaOp SendToHost(const XlaOp& operand, const XlaOp& token,
- const Shape& shape_with_layout, const ChannelHandle& handle);
-
- // Enqueues a Recv node which receives data from the host.
- XlaOp RecvFromHost(const XlaOp& token, const Shape& shape,
- const ChannelHandle& handle);
-
- // Enqueues an AfterAll operation with no operands producing a token-shaped
- // value.
- XlaOp CreateToken();
-
- // Enqueues an AfterAll operation with no operands producing a token-shaped
- // value.
- XlaOp AfterAll(tensorflow::gtl::ArraySlice<XlaOp> tokens);
-
- // Enqueues a Recv node onto the computation. The data comes from a Send
- // instruction that shares the same channel handle and its shape must
- // be the same as the given shape.
- XlaOp Recv(const Shape& shape, const ChannelHandle& handle);
- XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
- const ChannelHandle& handle);
-
- // Normalizes operand across spatial and batch dimensions for each feature.
- //
- // Returns a tuple (normalized, batch_mean, batch_var) where `normalized`
- // is the normalized result and batch_mean and batch_var are the mean and
- // variance, respectively, across batch for the operand.
- XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
- const XlaOp& offset, float epsilon,
- int64 feature_index);
-
- // Normalizes operand across spatial and batch dimensions for each feature.
- //
- // `BatchNormInference` is equivalent to calling `BatchNormTraining` without
- // computing `mean` and `variance` for each batch inside the operation. It
- // uses the input `mean` and `variance` instead as estimated values. The
- // purpose of this op is to reduce latency in inference, hence the name
- // `BatchNormInference`.
- //
- // The output has the same shape as `operand`, and contains the normalized
- // values for each batch.
- XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale,
- const XlaOp& offset, const XlaOp& mean,
- const XlaOp& variance, float epsilon,
- int64 feature_index);
-
- // Calculates the gradients of a batch norm op.
- //
- // The inputs `batch_mean` and `batch_var` represent the mean and variance
- // across the batch.
- //
- // Returns a tuple of three elements:
- // - grad_operand: Gradient with respect to input `operand`
- // - grad_offset: Gradient with respect to input `offset`
- // - grad_scale: Gradient with respect to input `scale`
- XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
- const XlaOp& batch_mean, const XlaOp& batch_var,
- const XlaOp& grad_output, float epsilon,
- int64 feature_index);
-
- StatusOr<XlaOp> AddInstruction(
- HloInstructionProto&& instr, HloOpcode opcode,
- tensorflow::gtl::ArraySlice<XlaOp> operands = {});
-
- void AddCalledComputation(const XlaComputation& computation,
- HloInstructionProto* instr);
-
- StatusOr<const HloInstructionProto*> LookUpInstruction(const XlaOp& op) const;
-
- // Internal helper method that does the building for an arbitrary unary op.
- XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand);
-
- // Internal helper method that does the building for an arbitrary binary op.
- // broadcast_dimensions specifies which dimensions to use for broadcasting
- // when the operation is between tensors of different ranks.
- XlaOp BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
-
- // Internal helper method that does the building for an arbitrary ternary op.
- XlaOp TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs,
- const XlaOp& ehs);
-
- XlaOp RngOp(RandomDistribution distribution,
- tensorflow::gtl::ArraySlice<XlaOp> parameters,
- const Shape& shape);
-
- StatusOr<XlaOp> InDimBroadcast(
- const Shape& shape, const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
-
- // Internal helper method that creates a sequence of instructions that
- // performs an explicit broadcast of the operand to the target shape.
- StatusOr<XlaOp> AddBroadcastSequence(const Shape& output_shape,
- const XlaOp& operand);
-
- // Internal helper method for creating a Reshape op with the already inferred
- // shape.
- StatusOr<XlaOp> Reshape(const Shape& shape, const XlaOp& operand);
-
- // Returns the (inferred) result for the program shape for the current
- // computation and fills the root_id in the pointer.
- StatusOr<ProgramShape> GetProgramShape(int64* root_id) const;
-
- // Returns shapes for the operands.
- StatusOr<std::vector<Shape>> GetOperandShapes(
- tensorflow::gtl::ArraySlice<XlaOp> operands) const;
-
- // A visitor which checks whether an operation is a compile-time constant,
- // meaning that it doesn't depend on any parameters, or on any stateful
- // operation such as `RngNormal` or `Infeed`. The visitor walks the
- // computation starting at a given operation and sets is_constant to false iff
- // a parameter or stateful operation is encountered.
- void IsConstantVisitor(const int64 op_handle, std::set<int64>* visited,
- bool* is_constant) const;
-
- // Checks bounds for convolution parameters.
- Status VerifyConvolution(
- const Shape& lhs_shape, const Shape& rhs_shape,
- const ConvolutionDimensionNumbers& dimension_numbers) const;
-
- // Helper function for creating a Window proto from user-supplied data.
- // Returns error if the user-supplied data was invalid.
- StatusOr<Window> MakeWindow(
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- tensorflow::gtl::ArraySlice<int64> lhs_dilation,
- tensorflow::gtl::ArraySlice<int64> rhs_dilation) const;
-
- string name_; // Name to use for the built computation.
-
- // The first error encountered while building the computation.
- // This is OK until the first error is encountered.
- Status first_error_;
-
- // The saved stack trace from the point at which the first error occurred.
- tensorflow::SavedStackTrace first_error_backtrace_;
-
- // The instructions of this computation.
- std::vector<HloInstructionProto> instructions_;
-
- // The embedded computations used by this computation. Each computation was
- // the entry computation of some XlaComputation, the key is the unique id of
- // that XlaComputation.
- std::map<int64, HloComputationProto> embedded_;
-
- // The unique parameter numbers.
- tensorflow::gtl::FlatSet<int64> parameter_numbers_;
-
- // The metadata to attach to each op. This is structured as a "modal"-like
- // operation, in order to simplify client code (and not sprinkle this metadata
- // throughout the TensorFlow op kernel implementations).
- OpMetadata metadata_;
-
- // Sharding for this operator. This is structured as a "model"-like operation,
- // in order to simplify client code, similar to metadata_.
- tensorflow::gtl::optional<OpSharding> sharding_;
-
- // Mode bit that indicates whether to die when a first error is encountered.
- bool die_immediately_on_error_ = false;
-
- XlaBuilder* parent_builder_{nullptr};
-
- friend XlaOp Parameter(XlaBuilder* builder, int64 parameter_number,
- const Shape& shape, const string& name);
- friend XlaOp ConstantLiteral(XlaBuilder* builder,
- const LiteralSlice& literal);
- template <typename NativeT>
- friend XlaOp ConstantR0(XlaBuilder* builder, NativeT value);
- template <typename NativeT>
- friend XlaOp ConstantR1(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<NativeT> values);
- friend XlaOp ConstantR1(XlaBuilder* builder,
- const tensorflow::core::Bitmap& values);
- template <typename NativeT>
- friend XlaOp ConstantR2(
- XlaBuilder* builder,
- std::initializer_list<std::initializer_list<NativeT>> values);
- template <typename NativeT>
- friend XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
- const Array<NativeT>& values,
- const Layout& layout);
- template <typename NativeT>
- friend XlaOp ConstantFromArray(XlaBuilder* builder,
- const Array<NativeT>& values);
- template <typename NativeT>
- friend XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
- const Array2D<NativeT>& values,
- const Layout& layout);
- template <typename NativeT>
- friend XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
- const Array2D<NativeT>& values);
- template <typename NativeT>
- friend XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
- const Array3D<NativeT>& values,
- const Layout& layout);
- template <typename NativeT>
- friend XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
- const Array3D<NativeT>& values);
- template <typename NativeT>
- friend XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
- const Array4D<NativeT>& values,
- const Layout& layout);
- template <typename NativeT>
- friend XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
- const Array4D<NativeT>& values);
-
- template <typename NativeT>
- friend XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value);
-
- friend XlaOp Broadcast(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
-
- friend XlaOp BroadcastInDim(
- const XlaOp& operand, const Shape& shape,
- const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
-
- friend XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
- const PaddingConfig& padding_config);
-
- friend XlaOp Reshape(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> new_sizes);
-
- friend XlaOp Reshape(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> new_sizes);
-
- friend XlaOp Collapse(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
-
- friend XlaOp Slice(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides);
-
- friend XlaOp SliceInDim(const XlaOp& operand, int64 start_index,
- int64 limit_index, int64 stride, int64 dimno);
-
- friend XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
-
- friend XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
- const XlaOp& start_indices);
-
- friend XlaOp ConcatInDim(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<XlaOp> operands,
- int64 dimension);
-
- friend void Trace(const string& tag, const XlaOp& operand);
-
- friend XlaOp Select(const XlaOp& pred, const XlaOp& on_true,
- const XlaOp& on_false);
- friend XlaOp Tuple(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<XlaOp> elements);
- friend XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);
- friend XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
- friend XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
- friend XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
- friend XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
- friend XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
- friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
- friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs);
- friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
- const DotDimensionNumbers& dimension_numbers);
- friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- Padding padding);
- friend XlaOp ConvWithGeneralPadding(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
- friend XlaOp ConvWithGeneralDimensions(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- const ConvolutionDimensionNumbers& dimension_numbers);
- friend XlaOp ConvGeneral(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const ConvolutionDimensionNumbers& dimension_numbers);
- friend XlaOp ConvGeneralDilated(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- tensorflow::gtl::ArraySlice<int64> lhs_dilation,
- tensorflow::gtl::ArraySlice<int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers);
- friend XlaOp Fft(const XlaOp& operand, FftType fft_type,
- tensorflow::gtl::ArraySlice<int64> fft_length);
- friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
- const string& config);
- friend void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
- const string& outfeed_config);
- friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<XlaOp> operands);
- friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
- tensorflow::gtl::ArraySlice<XlaOp> operands,
- const Shape& shape);
- friend XlaOp HostCompute(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<XlaOp> operands,
- const string& channel_name, int64 cost_estimate_ns,
- const Shape& shape);
- friend XlaOp Complex(const XlaOp& real, const XlaOp& imag,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
- friend XlaOp Conj(const XlaOp& operand);
- friend XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
- friend XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
- friend XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
- friend XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
- friend XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
- friend XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
- friend XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
- friend XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
- friend XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
- friend XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
- friend XlaOp Not(const XlaOp& operand);
- friend XlaOp ShiftLeft(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
- friend XlaOp ShiftRightArithmetic(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
- friend XlaOp ShiftRightLogical(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
- friend XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
- const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
- friend XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
- const XlaComputation& computation);
- friend XlaOp ReduceWindow(
- const XlaOp& operand, const XlaOp& init_value,
- const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding);
- friend XlaOp ReduceWindowWithGeneralPadding(
- const XlaOp& operand, const XlaOp& init_value,
- const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
- friend XlaOp CrossReplicaSum(
- const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> replica_group_ids);
- friend XlaOp CrossReplicaSum(
- const XlaOp& operand, const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> replica_group_ids,
- const tensorflow::gtl::optional<ChannelHandle>& channel_id);
- friend XlaOp SelectAndScatter(
- const XlaOp& operand, const XlaComputation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- const XlaOp& source, const XlaOp& init_value,
- const XlaComputation& scatter);
- friend XlaOp SelectAndScatterWithGeneralPadding(
- const XlaOp& operand, const XlaComputation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const XlaOp& source, const XlaOp& init_value,
- const XlaComputation& scatter);
- friend XlaOp Abs(const XlaOp& operand);
- friend XlaOp Atan2(const XlaOp& y, const XlaOp& x,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
- friend XlaOp Exp(const XlaOp& operand);
- friend XlaOp Expm1(const XlaOp& operand);
- friend XlaOp Floor(const XlaOp& operand);
- friend XlaOp Ceil(const XlaOp& operand);
- friend XlaOp Round(const XlaOp& operand);
- friend XlaOp Log(const XlaOp& operand);
- friend XlaOp Log1p(const XlaOp& operand);
- friend XlaOp Sign(const XlaOp& operand);
- friend XlaOp Clz(const XlaOp& operand);
- friend XlaOp Cos(const XlaOp& operand);
- friend XlaOp Sin(const XlaOp& operand);
- friend XlaOp Tanh(const XlaOp& operand);
- friend XlaOp Real(const XlaOp& operand);
- friend XlaOp Imag(const XlaOp& operand);
- friend XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
- friend XlaOp IsFinite(const XlaOp& operand);
- // TODO(b/64798317): Finish CPU & GPU implementation, then replace xla::Iota
- // in xla/client/lib/numeric.h with this (renamed to xla::Iota).
- friend XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size);
- friend XlaOp ConvertElementType(const XlaOp& operand,
- PrimitiveType new_element_type);
- friend XlaOp BitcastConvertType(const XlaOp& operand,
- PrimitiveType new_element_type);
- friend XlaOp Neg(const XlaOp& operand);
- friend XlaOp Transpose(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> permutation);
- friend XlaOp Rev(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
- friend XlaOp Sort(XlaOp keys, tensorflow::gtl::optional<XlaOp> values,
- int64 dimension);
- friend XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
- friend XlaOp Map(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<XlaOp> operands,
- const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<XlaOp> static_operands);
- friend XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma,
- const Shape& shape);
- friend XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape);
- friend XlaOp While(const XlaComputation& condition,
- const XlaComputation& body, const XlaOp& init);
- friend XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
- const XlaComputation& true_computation,
- const XlaOp& false_operand,
- const XlaComputation& false_computation);
- friend XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
- const int mantissa_bits);
- friend XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
- const GatherDimensionNumbers& dimension_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds);
- friend void Send(const XlaOp& operand, const ChannelHandle& handle);
- friend XlaOp Recv(XlaBuilder* builder, const Shape& shape,
- const ChannelHandle& handle);
- friend XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
- const XlaOp& offset, float epsilon,
- int64 feature_index);
- friend XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale,
- const XlaOp& offset, const XlaOp& mean,
- const XlaOp& variance, float epsilon,
- int64 feature_index);
- friend XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
- const XlaOp& batch_mean, const XlaOp& batch_var,
- const XlaOp& grad_output, float epsilon,
- int64 feature_index);
- friend XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
- const ChannelHandle& handle);
- friend XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
- const ChannelHandle& handle);
- friend XlaOp SendToHost(const XlaOp& operand, const XlaOp& token,
- const Shape& shape_with_layout,
- const ChannelHandle& handle);
- friend XlaOp RecvFromHost(const XlaOp& token, const Shape& shape,
- const ChannelHandle& handle);
- friend XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
- const string& config);
- friend XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
- const Shape& shape_with_layout,
- const string& outfeed_config);
- friend XlaOp CreateToken(XlaBuilder* builder);
- friend XlaOp AfterAll(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<XlaOp> tokens);
-};
-
-// RAII-style object: sets the current sharding assignment in builder on
-// construction, and sets back to the previous assignment on destruction.
-class XlaScopedShardingAssignment {
- public:
- XlaScopedShardingAssignment(xla::XlaBuilder* builder,
- tensorflow::gtl::optional<OpSharding> sharding)
- : builder_(builder), prev_sharding_(builder->sharding()) {
- SetSharding(sharding);
- }
-
- XlaScopedShardingAssignment(const XlaScopedShardingAssignment&) = delete;
- XlaScopedShardingAssignment& operator=(const XlaScopedShardingAssignment&) =
- delete;
-
- ~XlaScopedShardingAssignment() { SetSharding(prev_sharding_); }
-
- private:
- void SetSharding(const tensorflow::gtl::optional<OpSharding>& sharding) {
- if (sharding.has_value()) {
- builder_->SetSharding(sharding.value());
- } else {
- builder_->ClearSharding();
- }
- }
-
- xla::XlaBuilder* const builder_;
- tensorflow::gtl::optional<OpSharding> prev_sharding_;
-};
-
-// Free functions for building XlaOps. The intention is that these will
-// become the public API for building XlaOps rather than calling methods on
-// XlaBuilder directly.
-
-// Enqueues a "retrieve parameter value" instruction for a parameter that was
-// passed to the computation.
-XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape,
- const string& name);
-
-// Enqueues a constant with the value of the given literal onto the
-// computation.
-XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal);
-
-// Enqueues a constant onto the computation. Methods are templated on the
-// native host type (NativeT) which corresponds to a specific XLA
-// PrimitiveType as given in the following table:
-//
-// Native Type PrimitiveType
-// -----------------------------
-// bool PRED
-// int32 S32
-// int64 S64
-// uint32 U32
-// uint64 U64
-// float F32
-// double F64
-//
-// Note: not all primitive types defined in xla_data.proto have a
-// corresponding native type yet.
-template <typename NativeT>
-XlaOp ConstantR0(XlaBuilder* builder, NativeT value);
-template <typename NativeT>
-XlaOp ConstantR1(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<NativeT> values);
-XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values);
-template <typename NativeT>
-XlaOp ConstantR2(XlaBuilder* builder,
- std::initializer_list<std::initializer_list<NativeT>> values);
-template <typename NativeT>
-XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
- const Array<NativeT>& values,
- const Layout& layout);
-template <typename NativeT>
-XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values);
-template <typename NativeT>
-XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
- const Array2D<NativeT>& values,
- const Layout& layout);
-template <typename NativeT>
-XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
- const Array2D<NativeT>& values);
-template <typename NativeT>
-XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
- const Array3D<NativeT>& values,
- const Layout& layout);
-template <typename NativeT>
-XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
- const Array3D<NativeT>& values);
-template <typename NativeT>
-XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
- const Array4D<NativeT>& values,
- const Layout& layout);
-template <typename NativeT>
-XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
- const Array4D<NativeT>& values);
-
-// Enqueues a rank one constant (XlaBuilder* builder, vector) onto the
-// computation. The vector has size 'length' and every element has the value
-// 'value'.
-template <typename NativeT>
-XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value);
-
-// Adds dimensions to an array by duplicating the data in the array.
-//
-// The new dimensions are inserted on the left, i.e. if
-// broadcast_sizes has values {a0, ..., aN} and the operand shape
-// has dimensions {b0, ..., bM} then the shape of the output has
-// dimensions {a0, ..., aN, b0, ..., bM}.
-//
-// The new dimensions index into copies of the operand, i.e.
-//
-// output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
-XlaOp Broadcast(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
-
-// Performs in-dimension-style broadcast.
-//
-// Operand specifies the input to be broadcast. "shape" is expected output
-// shape. "broadcast_dimensions" are the dimensions to be broadcasting into.
-// Dimension numbers in broadcast_dimensions map to individual dimensions
-// of the operand, and specify what dimension of the output shape they
-// should be broadcast.
-// e.g.
-// Say operand = [1, 2], i.e., a 1D tensor with 2 elements.
-// and dimension of shape is [2,2].
-// Specifying {1} as brodcast_dimension will generate output
-// [1 , 2]
-// [1 , 2]
-// On the other hand, specifying {0} as broadcast_dimension
-// will generate output
-// [1 , 1]
-// [2 , 2]
-XlaOp BroadcastInDim(
- const XlaOp& operand, const Shape& shape,
- const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
-
-// Enqueues a pad operation onto the computation that pads the given value on
-// the edges as well as between the elements of the input. padding_config
-// specifies the padding amount for each dimension.
-XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
- const PaddingConfig& padding_config);
-
-// Enqueues an operation onto the computation that flattens the operand based
-// on the dimension order (major/slowest-varying to minor/fastest-varying)
-// given, followed by reshaping it into the shape with the given dimension
-// sizes (also major to minor). Conceptually, this is a limited form of
-// "shape casting".
-XlaOp Reshape(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> new_sizes);
-
-// Enqueues an operation onto the computation that collapses the operand, from
-// first to last dimension (C order), then reshapes it to the given dimension
-// sizes. Conceptually, this is a limited form of "shape casting".
-XlaOp Reshape(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> new_sizes);
-
-// Wrapper for Reshape.
-// Enqueues an operation to collapse the provided dimensions; e.g. an
-// operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to
-// {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must
-// be a consecutive, in-order subsequence of the operand dimensions.
-//
-// Note that collapsing a single dimension does nothing:
-//
-// {256} collapsing {0} => {256}
-// {1} collapsing {0} => {1}
-//
-// Collapsing multiple dimensions produces a single result dimension:
-//
-// {256, 2} collapsing {0,1} => {512}
-// {256, 2, 3} collapsing {0,1} => {512, 3}
-//
-// This could potentially cause data to be moved -- it provides a more
-// structured form of reshaping than an arbitrary Reshape operation.
-XlaOp Collapse(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
-
-// Enqueues a slice operation onto the computation that slices the operand
-// from the start indices to the limit indices; e.g.
-//
-// x
-// [ 0 1 2 3 ]
-// y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ]
-// [ 8 9 a b ]
-//
-// Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D
-// range notation.
-// The strides parameter determines the stride over the slice
-XlaOp Slice(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides);
-
-// Enqueues a slice operation in a given dimension, taking all other
-// dimensions as they are; e.g. if dimno is 1 from start_index 2 to
-// limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand
-// for:
-//
-// array[:, 2:4:1, :]
-XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index,
- int64 stride, int64 dimno);
-
-// Enqueues a slice operation onto the computation that slices the 'operand'
-// from dynamic start indices which are passed in 'start_indices'.
-// The size of the slice in each dimension is passed in 'slice_sizes',
-// which specify the end point of exclusive slice intervals in each
-// dimension [start, start + size).
-// The shape of 'start_indices' must be rank == 1, with dimension size
-// equal to the rank of the 'operand'.
-// Slice index calculations are computed modulo input dimension sizes to
-// prevent dynamic start indices from generating out-of-bound array accesses.
-XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
-
-// Enqueues a dynamic update slice operation onto the computation, which
-// updates a slice of 'operand' with 'update' at dynamic 'start_indices'.
-// The shape of 'update' determines the shape of the slice of 'operand'
-// which is updated.
-// The indices specified in 'start_indices' specify the offset of the slice
-// of 'operand' which is updated.
-//
-// update = {10, 11} // calculated at runtime.
-// [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ]
-// [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11]
-// [7 8 9] [7 8 9 ]
-//
-// The shape of 'start_indices' must be rank == 1, with dimension size
-// equal to the rank of the 'operand'.
-// Slice index calculations are computed modulo update dimension sizes to
-// prevent dynamic start indices from generating out-of-bound array accesses.
-XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
- const XlaOp& start_indices);
-
-// Enqueues a concatenate instruction onto the computation. 'operands' must
-// have >= 1 entry.
-XlaOp ConcatInDim(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<XlaOp> operands, int64 dimension);
-
-// Enqueue a tracing operation onto the computation; the computation will emit
-// a logging message with the operand.
-void Trace(const string& tag, const XlaOp& operand);
-
-// Enqueues a conditional-move-like select operation onto the computation;
-// predicated on pred, selects between on_true and on_false.
-XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false);
-
-// Enqueues a tuple-creation instruction onto the computation.
-XlaOp Tuple(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> elements);
-
-// Enqueues a tuple-element-get instruction onto the computation.
-XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);
-
-// Enqueues an equal-to comparison instruction onto the computation.
-XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
-// Enqueues a not-equal comparison instruction onto the computation.
-XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
-// Enqueues a greater-or-equal comparison instruction onto the computation.
-XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
-// Enqueues a greater-than comparison instruction onto the computation.
-XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
-// Enqueues a less-than comparison instruction onto the computation.
-XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
-// Enqueues a less-or-equal comparison instruction onto the computation.
-XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
-// Enqueues a dot instruction onto the computation.
-XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs);
-
-// Enqueues a general dot instruction onto the computation.
-XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
- const DotDimensionNumbers& dimension_numbers);
-
-// Enqueues a convolution instruction onto the computation, which uses the
-// default convolution dimension numbers.
-XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding);
-
-// Enqueues a convolution instruction onto the computation, with the caller
-// provided padding configuration in the format returned by MakePadding().
-XlaOp ConvWithGeneralPadding(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
-
-// Enqueues a convolution instruction onto the computation, with the caller
-// provided dimension numbers configuration.
-XlaOp ConvWithGeneralDimensions(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- const ConvolutionDimensionNumbers& dimension_numbers);
-
-// Enqueues a convolution instruction onto the computation, with the caller
-// provided padding configuration as well as the dimension numbers.
-XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const ConvolutionDimensionNumbers& dimension_numbers);
-
-// Enqueues a convolution instruction onto the computation, with the caller
-// provided padding configuration, dilation factors and dimension numbers.
-XlaOp ConvGeneralDilated(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- tensorflow::gtl::ArraySlice<int64> lhs_dilation,
- tensorflow::gtl::ArraySlice<int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers);
-
-// Enqueues an FFT instruction onto the computation, of the given type and
-// with the given FFT length.
-XlaOp Fft(const XlaOp& operand, FftType fft_type,
- tensorflow::gtl::ArraySlice<int64> fft_length);
-
-// Enqueues an infeed instruction onto the computation, which writes data of
-// the given shape to the infeed buffer of the device.
-XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
- const string& config = "");
-
-// Variant of Infeed which takes a token-shaped operand and produces a
-// two-element tuple containing the data value and a token-shaped value.
-// Tokens are used for ordering side-effecting operations.
-// TODO(b/110532604): Replace all uses of the non-token form with this variant.
-XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
- const string& config = "");
-
-// Enqueues an outfeed instruction onto the computation. This instruction
-// generates outgoing data transfers for the given data.
-//
-// shape_with_layout communicates the laid out shape that we want to outfeed
-// -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error
-// will occur.
-void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
- const string& outfeed_config);
-
-// Variant of Outfeed which takes a token-shaped operand and produces a
-// token-shaped value. Tokens are used for ordering side-effecting operations.
-// TODO(b/110532604): Replace all uses of the non-token form with this variant.
-XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
- const Shape& shape_with_layout,
- const string& outfeed_config);
-
-// Enqueues a call instruction onto the computation.
-XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<XlaOp> operands);
-
-// Enqueues a custom call instruction onto the computation.
-// During code generation, a call instruction is emitted which targets a
-// symbol with the name |call_target_name|. The |operands| are passed to the
-// call instruction. |shape| is the resultant shape.
-XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
- tensorflow::gtl::ArraySlice<XlaOp> operands,
- const Shape& shape);
-
-// Enqueues a pseudo-op to represent host-side computation data-dependencies.
-// During code generation, host send and receive operations will be generated
-// to transfer |operands| to the host and a single result of |shape| back to
-// the device. Host send/recv operations are emitted using |channel_name|.
-// Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO
-// instruction scheduling.
-XlaOp HostCompute(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<XlaOp> operands,
- const string& channel_name, int64 cost_estimate_ns,
- const Shape& shape);
-
-// The following methods enqueue element-wise binary arithmetic operations
-// onto the computation. The shapes of the operands have to match unless one
-// of the operands is a scalar, or an explicit broadcast dimension is given
-// (see g3doc for more details).
-
-// Enqueues a complex compose instruction onto the computation.
-XlaOp Complex(const XlaOp& real, const XlaOp& imag,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
-// Enqueues a complex conjugate instruction onto the computation.
-XlaOp Conj(const XlaOp& operand);
-
-// Enqueues an add instruction onto the computation.
-XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
-// Enqueues a subtract instruction onto the computation.
-XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
-// Enqueues a multiply instruction onto the computation.
-XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
-// Enqueues a divide instruction onto the computation.
-XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
-// Enqueues a remainder instruction onto the computation.
-XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
-// Enqueues a max instruction onto the computation.
-XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
-// Enqueues a min instruction onto the computation.
-XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
-// Element-wise logical operators
-XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
-XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
-XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
-XlaOp Not(const XlaOp& operand);
-
-XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-XlaOp ShiftRightArithmetic(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-XlaOp ShiftRightLogical(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
-// Reduces an array among the provided dimensions, given "computation" as a
-// reduction operator.
-XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
- const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
-
-// Convenience wrapper around the above that reduces all the dimensions in the
-// operand shape.
-XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
- const XlaComputation& computation);
-
-// Enqueues a windowed reduce instruction onto the computation.
-XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
- const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- Padding padding);
-
-// As ReduceWindow(), but the padding is given in the format
-// returned by MakePadding().
-XlaOp ReduceWindowWithGeneralPadding(
- const XlaOp& operand, const XlaOp& init_value,
- const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
-
-// Returns the sum of the operand value within each subgroup of replicas. All
-// replicas supply one input to the sum and all replicas receive the resulting
-// sum for each subgroup.
-XlaOp CrossReplicaSum(
- const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> replica_group_ids = {});
-
-// Enqueues an operation that do an AllReduce of the operand cross cores. Here
-// AllReduce means doing a reduction on the input operand cross cores and then
-// broadcasting the reduction result to those cores. The reduction function is
-// defined by `computation`, which should be a commutative computation on
-// scalars, e.g., add, min, or max. The way that AllReduce is applied is
-// configured by:
-//
-// - `replica_group_ids`: maps replica ids to subgroup ids. If empty, all
-// replicas belong to one group. Allreduce will be applied within subgroups.
-// For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means,
-// replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1.
-//
-// - `channel_id`: for Allreduce nodes from different models, if they have the
-// same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be
-// applied cross models.
-//
-// TODO(b/79737069): Rename this to AllReduce when it's ready to use.
-XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> replica_group_ids = {},
- const tensorflow::gtl::optional<ChannelHandle>&
- channel_id = tensorflow::gtl::nullopt);
-
-// Enqueues an operation that scatters the `source` array to the selected
-// indices of each window.
-XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- Padding padding, const XlaOp& source,
- const XlaOp& init_value, const XlaComputation& scatter);
-
-// As SelectAndScatter(), but the padding is given in the format
-// returned by MakePadding().
-XlaOp SelectAndScatterWithGeneralPadding(
- const XlaOp& operand, const XlaComputation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const XlaOp& source, const XlaOp& init_value,
- const XlaComputation& scatter);
-
-// Enqueues an abs instruction onto the computation.
-XlaOp Abs(const XlaOp& operand);
-
-// Enqueues a atan2 instruction onto the computation.
-XlaOp Atan2(const XlaOp& y, const XlaOp& x,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
-// Enqueues an exp instruction onto the computation.
-XlaOp Exp(const XlaOp& operand);
-
-// Enqueues an expm1 instruction onto the computation.
-XlaOp Expm1(const XlaOp& operand);
-
-// Enqueues a floor instruction onto the computation.
-XlaOp Floor(const XlaOp& operand);
-
-// Enqueues a ceil instruction onto the computation.
-XlaOp Ceil(const XlaOp& operand);
-
-// Enqueues a round instruction onto the computation, rounding to nearest even
-// with half-way cases rounding away from zero.
-XlaOp Round(const XlaOp& operand);
-
-// Enqueues an log instruction (natural logarithm) onto the computation.
-XlaOp Log(const XlaOp& operand);
-
-// Enqueues an log1p instruction (log(x+1)) onto the computation.
-XlaOp Log1p(const XlaOp& operand);
-
-// Enqueues a sign instruction onto the computation.
-XlaOp Sign(const XlaOp& operand);
-
-// Enqueues a count leading zeros instruction onto the computation.
-XlaOp Clz(const XlaOp& operand);
-
-// Enqueues a cosine instruction onto the computation.
-XlaOp Cos(const XlaOp& operand);
-
-// Enqueues a sine instruction onto the computation.
-XlaOp Sin(const XlaOp& operand);
-
-// Enqueues a tanh instruction onto the computation.
-XlaOp Tanh(const XlaOp& operand);
-
-// Enqueues a real-part instruction onto the computation.
-XlaOp Real(const XlaOp& operand);
-
-// Enqueues an imaginary-part instruction onto the computation.
-XlaOp Imag(const XlaOp& operand);
-
-// Enqueues a lhs^rhs computation onto the computation.
-XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
-// Enqueues an operator that tests if the operand's values are finite, i.e.,
-// not Inf or NaN. Defined only for floating-point types. Returns an array of
-// booleans with the same shape where entries are true iff the corresponding
-// entry was NaN.
-XlaOp IsFinite(const XlaOp& operand);
-
-// Enqueues a convert instruction onto the computation that changes the
-// element type of the operand array to primitive_type.
-XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type);
-
-// Enqueues a no-op instruction onto the computation that changes
-// the element type of the operand array to primitive_type. The
-// bit-widths of the source and destination element types must be
-// identical.
-XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type);
-
-// Enqueues a negate instruction onto the computation.
-XlaOp Neg(const XlaOp& operand);
-
-// Enqueues a transpose instruction onto the computation.
-XlaOp Transpose(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> permutation);
-
-// Enqueues a reverse instruction onto the computation. The order of the
-// elements in the given dimensions is reversed (i.e., the element at index i
-// is moved to index dimension_size - 1 - i).
-XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions);
-
-// Enqueues a sort (as increasing order) instruction onto the computation.
-// If only keys are provided:
-// * If the keys are an rank-1 tensor (an array), the result is a sorted array
-// of keys, in ascending order.
-// * If the keys have higher rank, the keys are sorted along the provided
-// dimension. For example, for a rank-2 tensor (a matrix) of keys, a dimension
-// value of 0 will indepenently sort every column, and a dimension value of 1
-// will independently sort each row. If no dimension number is provided, then
-// the last dimension is chosen by default.
-//
-// If both keys and values are provided:
-// * The keys and the values must tensors with the same dimensions. The
-// element types of the tensors may be different.
-// * The result is a tuple that consists of a sorted tensor of keys (along the
-// provided dimension, as above) as the first element, and a tensor with their
-// corresponding values as the second element.
-XlaOp Sort(XlaOp keys,
- tensorflow::gtl::optional<XlaOp> values = tensorflow::gtl::nullopt,
- int64 dimension = -1);
-
-// Enqueues a clamp instruction onto the computation.
-XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
-
-// Enqueues a map instruction onto the computation.
-XlaOp Map(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> operands,
- const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<XlaOp> static_operands = {});
-
-// Enqueues a N(mu, sigma) random number generation instruction onto the
-// computation.
-XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape);
-
-// Enqueues a U(a, b) random number generation instruction onto the
-// computation. Returns values in the semi-open interval [a, b).
-XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape);
-
-// Enqueues a while node onto the computation.
-XlaOp While(const XlaComputation& condition, const XlaComputation& body,
- const XlaOp& init);
-
-// Enqueues a conditional node onto the computation.
-XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
- const XlaComputation& true_computation,
- const XlaOp& false_operand,
- const XlaComputation& false_computation);
-
-// Enqueues a ReducePrecision node onto the computation.
-XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
- const int mantissa_bits);
-
-// Enqueues a Gather node onto the computation.
-XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
- const GatherDimensionNumbers& dimension_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds);
-
-// Enqueues a Send node onto the computation for device-to-device
-// communication. This operation sends the given operand to
-// a Recv instruction in a different computation that shares the same channel
-// handle.
-void Send(const XlaOp& operand, const ChannelHandle& handle);
-
-// Variant of Send which takes a token-shaped operand and produces a
-// token-shaped value. Tokens are used for ordering side-effecting operations.
-// TODO(b/110532604): Replace all uses of the non-token form with this variant.
-XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
- const ChannelHandle& handle);
-
-// Enqueues a Recv node onto the computation for device-to-device
-// communication. The data comes from a Send instruction in a different
-// computation that shares the same channel handle and its shape must be the
-// same as the given shape.
-XlaOp Recv(XlaBuilder* builder, const Shape& shape,
- const ChannelHandle& handle);
-
-// Variant of Recv which takes a token-shaped operand and produces a two-element
-// tuple containing the data value and a token-shaped value. Tokens are used
-// for ordering side-effecting operations.
-// TODO(b/110532604): Replace all uses of the non-token form with this variant.
-XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
- const ChannelHandle& handle);
-
-// Enqueues a Send node which transfers data from the device to the host. The
-// 'shape_with_layout' argument defines the layout of the data transferred; its
-// shape must be compatible with the shape of the operand. The operand must be
-// array-shaped.
-// TODO(b/111544877): Support tuple shapes.
-XlaOp SendToHost(const XlaOp& operand, const XlaOp& token,
- const Shape& shape_with_layout, const ChannelHandle& handle);
-
-// Enqueues a Recv node which transfers data from the host to the device. The
-// given shape must contain a layout and must be an array.
-// TODO(b/111544877): Support tuple shapes.
-XlaOp RecvFromHost(const XlaOp& token, const Shape& shape,
- const ChannelHandle& handle);
-
-// Enqueues an operation (AfterAll) with no operands that produces a
-// token-shaped value. Tokens are used for ordering side-effecting operations.
-// This is a separate method from AfterAll to facility the removal of
-// operand-less AfterAll instructions.
-// TODO(b/110532604): Remove this function when all tokens are derived from a
-// single token generated or passed into the entry computation.
-XlaOp CreateToken(XlaBuilder* builder);
-
-// Enqueues an AfterAll instruction which produces a token-shaped value and
-// takes a variadic number of token-shaped operands. The number of operands must
-// be greater than zero. Used for joining tokens.
-XlaOp AfterAll(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> tokens);
-
-// Normalizes operand across spatial and batch dimensions for each feature.
-//
-// Returns a tuple (normalized, batch_mean, batch_var) where `normalized`
-// is the normalized result and batch_mean and batch_var are the mean and
-// variance, respectively, across batch for the operand.
-XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
- const XlaOp& offset, float epsilon,
- int64 feature_index);
-
-// Normalizes operand across spatial and batch dimensions for each feature.
-//
-// `BatchNormInference` is equivalent to calling `BatchNormTraining` without
-// computing `mean` and `variance` for each batch inside the operation. It
-// uses the input `mean` and `variance` instead as estimated values. The
-// purpose of this op is to reduce latency in inference, hence the name
-// `BatchNormInference`.
-//
-// The output has the same shape as `operand`, and contains the normalized
-// values for each batch.
-XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale,
- const XlaOp& offset, const XlaOp& mean,
- const XlaOp& variance, float epsilon,
- int64 feature_index);
-
-// Calculates the gradients of a batch norm op.
-//
-// The inputs `batch_mean` and `batch_var` represent the mean and variance
-// across the batch.
-//
-// Returns a tuple of three elements:
-// - grad_operand: Gradient with respect to input `operand`
-// - grad_offset: Gradient with respect to input `offset`
-// - grad_scale: Gradient with respect to input `scale`
-XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
- const XlaOp& batch_mean, const XlaOp& batch_var,
- const XlaOp& grad_output, float epsilon,
- int64 feature_index);
-
-// Implementation details below this point.
-
-template <typename NativeT>
-XlaOp XlaBuilder::ConstantR0(NativeT value) {
- return ConstantLiteral(*LiteralUtil::CreateR0<NativeT>(value));
-}
-
-template <typename NativeT>
-XlaOp XlaBuilder::ConstantR1(tensorflow::gtl::ArraySlice<NativeT> values) {
- return ConstantLiteral(*LiteralUtil::CreateR1<NativeT>(values));
-}
-
-template <typename NativeT>
-XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) {
- Literal literal(ShapeUtil::MakeShape(
- primitive_util::NativeToPrimitiveType<NativeT>(), {length}));
- literal.PopulateWithValue(value);
- return ConstantLiteral(literal);
-}
-
-inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) {
- return ConstantLiteral(*LiteralUtil::CreateR1(values));
-}
-
-template <typename NativeT>
-XlaOp XlaBuilder::ConstantR2(
- std::initializer_list<std::initializer_list<NativeT>> values) {
- return ConstantLiteral(*LiteralUtil::CreateR2<NativeT>(values));
-}
-
-template <typename NativeT>
-XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array<NativeT>& values,
- const Layout& layout) {
- return ConstantLiteral(
- *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
-}
-
-template <typename NativeT>
-XlaOp XlaBuilder::ConstantFromArray(const Array<NativeT>& values) {
- return ConstantLiteral(*LiteralUtil::CreateFromArray<NativeT>(values));
-}
-
-template <typename NativeT>
-XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout(
- const Array2D<NativeT>& values, const Layout& layout) {
- return ConstantLiteral(
- *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
-}
-
-template <typename NativeT>
-XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D<NativeT>& values) {
- return ConstantLiteral(*LiteralUtil::CreateR2FromArray2D<NativeT>(values));
-}
-
-template <typename NativeT>
-XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout(
- const Array3D<NativeT>& values, const Layout& layout) {
- return ConstantLiteral(
- *LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
-}
-
-template <typename NativeT>
-XlaOp XlaBuilder::ConstantR3FromArray3D(const Array3D<NativeT>& values) {
- return ConstantFromArray(values);
-}
-
-template <typename NativeT>
-XlaOp XlaBuilder::ConstantR4FromArray4DWithLayout(
- const Array4D<NativeT>& values, const Layout& layout) {
- return ConstantFromArrayWithLayout(values, layout);
-}
-
-template <typename NativeT>
-XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D<NativeT>& values) {
- return ConstantFromArray(values);
-}
-
-// Free function template implementations.
-
-template <typename NativeT>
-XlaOp ConstantR0(XlaBuilder* builder, NativeT value) {
- return ConstantLiteral(builder, *LiteralUtil::CreateR0<NativeT>(value));
-}
-
-template <typename NativeT>
-XlaOp ConstantR1(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<NativeT> values) {
- return ConstantLiteral(builder, *LiteralUtil::CreateR1<NativeT>(values));
-}
-
-template <typename NativeT>
-XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) {
- Literal literal(ShapeUtil::MakeShape(
- primitive_util::NativeToPrimitiveType<NativeT>(), {length}));
- literal.PopulateWithValue(value);
- return ConstantLiteral(builder, literal);
-}
-
-inline XlaOp ConstantR1(XlaBuilder* builder,
- const tensorflow::core::Bitmap& values) {
- return ConstantLiteral(builder, *LiteralUtil::CreateR1(values));
-}
-
-template <typename NativeT>
-XlaOp ConstantR2(XlaBuilder* builder,
- std::initializer_list<std::initializer_list<NativeT>> values) {
- return ConstantLiteral(builder, *LiteralUtil::CreateR2<NativeT>(values));
-}
-
-template <typename NativeT>
-XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
- const Array<NativeT>& values,
- const Layout& layout) {
- return ConstantLiteral(
- builder,
- *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
-}
-
-template <typename NativeT>
-XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values) {
- return ConstantLiteral(builder,
- *LiteralUtil::CreateFromArray<NativeT>(values));
-}
-
-template <typename NativeT>
-XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
- const Array2D<NativeT>& values,
- const Layout& layout) {
- return ConstantLiteral(
- builder,
- *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
-}
-
-template <typename NativeT>
-XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
- const Array2D<NativeT>& values) {
- return ConstantLiteral(builder,
- *LiteralUtil::CreateR2FromArray2D<NativeT>(values));
-}
-
-template <typename NativeT>
-XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
- const Array3D<NativeT>& values,
- const Layout& layout) {
- return ConstantLiteral(
- builder,
- *LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
-}
-
-template <typename NativeT>
-XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
- const Array3D<NativeT>& values) {
- return ConstantFromArray(builder, values);
-}
-
-template <typename NativeT>
-XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
- const Array4D<NativeT>& values,
- const Layout& layout) {
- return ConstantFromArrayWithLayout(builder, values, layout);
-}
-
-template <typename NativeT>
-XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
- const Array4D<NativeT>& values) {
- return ConstantFromArray(builder, values);
-}
-
-} // namespace xla
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_
diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD
index e26e35eb11..c8f2d65c22 100644
--- a/tensorflow/compiler/xla/python/BUILD
+++ b/tensorflow/compiler/xla/python/BUILD
@@ -53,9 +53,9 @@ cc_library(
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:executable_build_options",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:math",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/core:framework_lite",
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
index fbcf0f1969..434d78d78d 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.cc
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -15,8 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/python/local_computation_builder.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/util.h"
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h
index 57da7e53d5..545aa63f9d 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.h
+++ b/tensorflow/compiler/xla/python/local_computation_builder.h
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc
index 6397f1f479..a803520876 100644
--- a/tensorflow/compiler/xla/reference_util.cc
+++ b/tensorflow/compiler/xla/reference_util.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <array>
#include <utility>
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD
index 0b1cec1925..44b22a5586 100644
--- a/tensorflow/compiler/xla/rpc/BUILD
+++ b/tensorflow/compiler/xla/rpc/BUILD
@@ -56,7 +56,7 @@ tf_cc_test(
":grpc_stub",
"//tensorflow:grpc++",
"//tensorflow/compiler/xla/client",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc
index 90efee50b4..6788676181 100644
--- a/tensorflow/compiler/xla/rpc/grpc_client_test.cc
+++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc
@@ -24,7 +24,7 @@ limitations under the License.
#include "grpcpp/security/credentials.h"
#include "tensorflow/compiler/xla/client/client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/rpc/grpc_stub.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/core/lib/io/path.h"
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 2305dd4318..528b7fdfd3 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -256,7 +256,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/service:hlo_element_type_converter",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
@@ -564,7 +564,7 @@ cc_library(
":computation_placer",
":device_memory_allocator",
":platform_util",
- ":pool",
+ ":stream_pool",
":transfer_manager",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -598,6 +598,7 @@ cc_library(
":hlo_proto_util",
":platform_util",
":source_map_util",
+ ":stream_pool",
":transfer_manager",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:execution_options_util",
@@ -751,8 +752,8 @@ cc_library(
":hlo_execution_profile",
":hlo_graph_dumper",
":hlo_proto",
- ":pool",
":shaped_buffer",
+ ":stream_pool",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
@@ -838,7 +839,7 @@ cc_library(
hdrs = ["execution_tracker.h"],
deps = [
":backend",
- ":pool",
+ ":stream_pool",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
@@ -946,7 +947,6 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
@@ -1663,8 +1663,8 @@ tf_cc_test(
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:padding",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
@@ -2670,7 +2670,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service/gpu:ir_emission_utils",
"//tensorflow/compiler/xla/tests:hlo_test_base",
@@ -2707,7 +2707,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
@@ -2715,21 +2715,25 @@ tf_cc_test(
)
cc_library(
- name = "pool",
- hdrs = ["pool.h"],
+ name = "stream_pool",
+ srcs = ["stream_pool.cc"],
+ hdrs = ["stream_pool.h"],
deps = [
+ "//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
],
)
tf_cc_test(
- name = "pool_test",
- srcs = ["pool_test.cc"],
+ name = "stream_pool_test",
+ srcs = ["stream_pool_test.cc"],
deps = [
- ":pool",
+ ":stream_pool",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:stream_executor_no_cuda",
],
)
diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc
index 349b32451a..d12be3e007 100644
--- a/tensorflow/compiler/xla/service/backend.cc
+++ b/tensorflow/compiler/xla/service/backend.cc
@@ -96,24 +96,19 @@ Backend::CreateDefaultBackend() {
return CreateBackend(backend_options);
}
-StatusOr<Backend::StreamPtr> Backend::BorrowStream(int device_ordinal) {
- TF_ASSIGN_OR_RETURN(auto exec, stream_executor(device_ordinal));
- return BorrowStream(exec);
+StatusOr<StreamPool::Ptr> Backend::BorrowStream(int device_ordinal) {
+ TF_ASSIGN_OR_RETURN(auto executor, stream_executor(device_ordinal));
+ return BorrowStream(executor);
}
-StatusOr<Backend::StreamPtr> Backend::BorrowStream(
- se::StreamExecutor* executor) {
+StatusOr<StreamPool::Ptr> Backend::BorrowStream(se::StreamExecutor* executor) {
tensorflow::mutex_lock l(mu_);
if (0 == stream_pools_.count(executor)) {
stream_pools_.emplace(std::piecewise_construct,
std::forward_as_tuple(executor),
- std::forward_as_tuple([executor]() {
- auto stream = MakeUnique<se::Stream>(executor);
- stream->Init();
- return stream;
- }));
+ std::forward_as_tuple());
}
- return stream_pools_.at(executor).Allocate();
+ return stream_pools_.at(executor).BorrowStream(executor);
}
Backend::Backend(
diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h
index 6546602473..1bc3796fa4 100644
--- a/tensorflow/compiler/xla/service/backend.h
+++ b/tensorflow/compiler/xla/service/backend.h
@@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
-#include "tensorflow/compiler/xla/service/pool.h"
+#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
@@ -63,11 +63,9 @@ class BackendOptions {
//
// It also offers a pooling API for creation/use of initialized streams:
//
-// StreamPtr stream = backend->BorrowStream().ConsumeValueOrDie();
+// StreamPool::Ptr stream = backend->BorrowStream().ConsumeValueOrDie();
class Backend {
public:
- using StreamPtr = Pool<se::Stream>::SmartPtr;
-
// Creates a new backend.
static StatusOr<std::unique_ptr<Backend>> CreateBackend(
const BackendOptions& options);
@@ -114,13 +112,13 @@ class Backend {
// Borrows a stream for use by the caller, either by grabbing it from an
// internal pool, or by constructing/initializating it, and returns the result
// to the caller.
- StatusOr<StreamPtr> BorrowStream(int device_ordinal);
- StatusOr<StreamPtr> BorrowStream(se::StreamExecutor* executor);
+ StatusOr<StreamPool::Ptr> BorrowStream(int device_ordinal);
+ StatusOr<StreamPool::Ptr> BorrowStream(se::StreamExecutor* executor);
// Returns a function to borrow a stream, as `BorrowStream` above does.
// Purely for convenience, the caller could rather make this anonymous
// function itself.
- std::function<StatusOr<StreamPtr>(int)> StreamBorrower() {
+ std::function<StatusOr<StreamPool::Ptr>(int)> StreamBorrower() {
return [this](int device_ordinal) { return BorrowStream(device_ordinal); };
}
@@ -169,7 +167,7 @@ class Backend {
tensorflow::mutex mu_;
// Mapping from stream executor to stream pools, used by `BorrowStream` above.
- std::map<se::StreamExecutor*, Pool<se::Stream>> stream_pools_ GUARDED_BY(mu_);
+ std::map<se::StreamExecutor*, StreamPool> stream_pools_ GUARDED_BY(mu_);
// The default memory allocator to use.
std::unique_ptr<StreamExecutorMemoryAllocator> memory_allocator_;
diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
index f7b4c1405d..7cf05ca443 100644
--- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
@@ -235,7 +235,8 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) {
HloInstruction* crs =
builder.AddInstruction(HloInstruction::CreateCrossReplicaSum(
ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b},
- sum, /*replica_group_ids=*/{}, /*barrier=*/""));
+ sum, /*replica_group_ids=*/{}, /*barrier=*/"",
+ /*all_reduce_id=*/tensorflow::gtl::nullopt));
HloInstruction* gte_a = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(f32_shape, crs, 0));
HloInstruction* gte_b = builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc
index 14c54ddd13..16e99b5722 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc
@@ -34,8 +34,10 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault {
Status DefaultAction(HloInstruction* hlo) override;
- // Special handling for cross-replica-sum which can have a tuple output.
+ // Special handling for cross-replica-sum and sort which can have a tuple
+ // output.
Status HandleCrossReplicaSum(HloInstruction* crs) override;
+ Status HandleSort(HloInstruction* sort) override;
static bool Run(HloComputation* computation,
const BFloat16Support* bfloat16_support) {
@@ -49,6 +51,10 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault {
// conversions between F32 and BF16 to make it supported.
Status HandleInstruction(HloInstruction* hlo);
+ // Handle instructions with tuple outputs by examining each output
+ // independently.
+ Status HandleMultipleOutputs(HloInstruction* hlo);
+
// Inserts a conversion HLO that changes the given HLO's output type.
Status InsertConvertAfterOutput(HloInstruction* hlo, PrimitiveType to,
HloComputation* computation);
@@ -148,22 +154,35 @@ Status BFloat16NormalizationVisitor::HandleCrossReplicaSum(
HloInstruction* crs) {
if (!ShapeUtil::IsTuple(crs->shape())) {
return HandleInstruction(crs);
+ } else {
+ return HandleMultipleOutputs(crs);
}
+}
+
+Status BFloat16NormalizationVisitor::HandleSort(HloInstruction* sort) {
+ if (!ShapeUtil::IsTuple(sort->shape())) {
+ return HandleInstruction(sort);
+ } else {
+ return HandleMultipleOutputs(sort);
+ }
+}
- std::vector<PrimitiveType> operand_types(crs->operand_count());
- std::vector<PrimitiveType> output_types(crs->operand_count());
+Status BFloat16NormalizationVisitor::HandleMultipleOutputs(
+ HloInstruction* hlo) {
+ std::vector<PrimitiveType> operand_types(hlo->operand_count());
+ std::vector<PrimitiveType> output_types(hlo->operand_count());
int64 f32_count = 0;
int64 bf16_count = 0;
bool has_unsupported_bf16_operand = false;
bool has_unsupported_bf16_output = false;
- for (int64 i = 0; i < crs->operand_count(); ++i) {
- operand_types[i] = crs->operand(i)->shape().element_type();
- output_types[i] = ShapeUtil::GetSubshape(crs->shape(), {i}).element_type();
+ for (int64 i = 0; i < hlo->operand_count(); ++i) {
+ operand_types[i] = hlo->operand(i)->shape().element_type();
+ output_types[i] = ShapeUtil::GetSubshape(hlo->shape(), {i}).element_type();
if (operand_types[i] == F32) {
f32_count += 1;
} else if (operand_types[i] == BF16) {
bf16_count += 1;
- if (!bfloat16_support_->SupportsBF16Operand(*crs, i)) {
+ if (!bfloat16_support_->SupportsBF16Operand(*hlo, i)) {
has_unsupported_bf16_operand = true;
}
}
@@ -171,7 +190,7 @@ Status BFloat16NormalizationVisitor::HandleCrossReplicaSum(
f32_count += 1;
} else if (output_types[i] == BF16) {
bf16_count += 1;
- if (!bfloat16_support_->SupportsBF16Output(*crs)) {
+ if (!bfloat16_support_->SupportsBF16Output(*hlo)) {
has_unsupported_bf16_output = true;
}
}
@@ -185,43 +204,43 @@ Status BFloat16NormalizationVisitor::HandleCrossReplicaSum(
if (operand_types[i] != BF16) {
return false;
}
- if (!bfloat16_support_->SupportsBF16Operand(*crs, i)) {
+ if (!bfloat16_support_->SupportsBF16Operand(*hlo, i)) {
return true;
}
- if (bfloat16_support_->SupportsMixedPrecisions(*crs)) {
+ if (bfloat16_support_->SupportsMixedPrecisions(*hlo)) {
return false;
}
return has_unsupported_bf16_operand || has_unsupported_bf16_output ||
f32_count > 0;
};
- for (int64 i = 0; i < crs->operand_count(); ++i) {
+ for (int64 i = 0; i < hlo->operand_count(); ++i) {
if (should_convert_operand(i)) {
- TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(crs, i, F32, computation_));
+ TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_));
f32_count += 1;
bf16_count -= 1;
}
}
if (!has_unsupported_bf16_output &&
- (bfloat16_support_->SupportsMixedPrecisions(*crs) || f32_count == 0 ||
+ (bfloat16_support_->SupportsMixedPrecisions(*hlo) || f32_count == 0 ||
bf16_count == 0)) {
return Status::OK();
}
- std::vector<HloInstruction*> materialized_users = crs->users();
- std::vector<HloInstruction*> output_elements(crs->operand_count());
- auto original_shape = crs->shape();
- for (int64 i = 0; i < crs->operand_count(); ++i) {
- auto subshape = ShapeUtil::GetMutableSubshape(crs->mutable_shape(), {i});
+ std::vector<HloInstruction*> materialized_users = hlo->users();
+ std::vector<HloInstruction*> output_elements(hlo->operand_count());
+ auto original_shape = hlo->shape();
+ for (int64 i = 0; i < hlo->operand_count(); ++i) {
+ auto subshape = ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), {i});
if (output_types[i] != BF16) {
output_elements[i] = computation_->AddInstruction(
- HloInstruction::CreateGetTupleElement(*subshape, crs, i));
+ HloInstruction::CreateGetTupleElement(*subshape, hlo, i));
continue;
}
subshape->set_element_type(F32);
auto gte = computation_->AddInstruction(
- HloInstruction::CreateGetTupleElement(*subshape, crs, i));
+ HloInstruction::CreateGetTupleElement(*subshape, hlo, i));
output_elements[i] =
computation_->AddInstruction(HloInstruction::CreateConvert(
ShapeUtil::ChangeElementType(*subshape, BF16), gte));
@@ -229,11 +248,11 @@ Status BFloat16NormalizationVisitor::HandleCrossReplicaSum(
auto tuple = computation_->AddInstruction(
HloInstruction::CreateTuple(output_elements));
- // Use the crs' shape temporarily, in order to pass checks in
+ // Use the hlo' shape temporarily, in order to pass checks in
// ReplaceUseWith.
- *tuple->mutable_shape() = crs->shape();
+ *tuple->mutable_shape() = hlo->shape();
for (auto* user : materialized_users) {
- TF_RETURN_IF_ERROR(crs->ReplaceUseWith(user, tuple));
+ TF_RETURN_IF_ERROR(hlo->ReplaceUseWith(user, tuple));
}
*tuple->mutable_shape() = original_shape;
return Status::OK();
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
index 830f26422b..f9f1f64998 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
@@ -251,7 +251,8 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) {
HloInstruction* crs =
builder.AddInstruction(HloInstruction::CreateCrossReplicaSum(
ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction,
- /*replica_group_ids=*/{}, /*barrier=*/""));
+ /*replica_group_ids=*/{}, /*barrier=*/"",
+ /*all_reduce_id=*/tensorflow::gtl::nullopt));
HloInstruction* gte = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1));
@@ -265,6 +266,33 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) {
EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {1}).element_type(), F32);
}
+TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) {
+ auto module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ Shape f32_shape = ShapeUtil::MakeShape(F32, {1024});
+ Shape bf16_shape = ShapeUtil::MakeShape(BF16, {1024});
+ Shape s32_shape = ShapeUtil::MakeShape(BF16, {1024});
+
+ HloInstruction* key = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, f32_shape, "key"));
+ HloInstruction* value = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, s32_shape, "value"));
+
+ HloInstruction* sort = builder.AddInstruction(HloInstruction::CreateSort(
+ ShapeUtil::MakeTupleShape({bf16_shape, s32_shape}), 0, key, value));
+ HloInstruction* gte = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(bf16_shape, sort, 0));
+
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_TRUE(Normalize(module.get()));
+
+ EXPECT_EQ(computation->root_instruction(), gte);
+ EXPECT_EQ(gte->shape().element_type(), BF16);
+ EXPECT_EQ(sort->operand(0)->shape().element_type(), F32);
+ EXPECT_EQ(ShapeUtil::GetSubshape(sort->shape(), {0}).element_type(), F32);
+}
+
// Tests that the normalization should not cause unsupported mixed precision due
// to resolving unsupported BF16 operand.
TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) {
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
index b21c83a07f..2fb401c428 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
@@ -215,7 +215,12 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo,
if (ContainsKey(values_that_must_be_kept_as_f32_, value)) {
return false;
}
- if (ValueTypeAfterChange(value) == BF16) {
+ // We use the original type for the value because we are going to examine
+ // the uses of it, instead of the value itself. If ValueTypeAfterChange()
+ // were used, it would cause problems when there are aliasing buffers, i.e.,
+ // ResolveInconsistencyOfAliasingBuffers() would fail to revert the
+ // tentative change to BF16 even if the uses require F32.
+ if (value->shape().element_type() == BF16) {
continue;
}
for (const HloUse& use : value->uses()) {
@@ -566,6 +571,9 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
}
visited_computations->insert(visited_in_while.begin(),
visited_in_while.end());
+ } else if (hlo->opcode() == HloOpcode::kFusion) {
+ ResolveInconsistencyOfAliasingBuffersHelper(
+ hlo->fused_instructions_computation(), visited_computations);
}
}
// Now adjust parameters of called computations.
@@ -769,8 +777,7 @@ StatusOr<bool> BFloat16Propagation::Run(HloModule* module) {
// propagation in reverse topological order.
for (auto comp_it = computations_topological_order.rbegin();
comp_it != computations_topological_order.rend(); ++comp_it) {
- if ((*comp_it)->IsFusionComputation()) {
- // Fusion computations are handled when visiting the fusion instruction.
+ if (ContainsKey(computations_visited_in_backward_pass_, *comp_it)) {
continue;
}
auto insts = (*comp_it)->MakeInstructionPostOrder();
@@ -778,6 +785,7 @@ StatusOr<bool> BFloat16Propagation::Run(HloModule* module) {
DetermineInstructionPrecision(*inst_it,
/*skip_parameters=*/true);
}
+ computations_visited_in_backward_pass_.insert(*comp_it);
}
// It's possible that an instruction does not define a buffer, but the
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
index aeafb25ad7..69b654d30e 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
@@ -508,6 +508,63 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) {
EXPECT_FALSE(OutputsBF16(dot));
}
+// Tests that if the while condition prevents using BF16, no changes should be
+// made to the while body and thus the fusion node inside it.
+TEST_F(BFloat16PropagationTest,
+ ConditionPreventsPropagationForFusionInsideWhile) {
+ auto module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
+
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, shape, "param0"));
+ HloInstruction* param1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, shape, "param1"));
+ HloInstruction* add = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
+
+ auto builder_cond = HloComputation::Builder("cond");
+ auto cond_param = builder_cond.AddInstruction(
+ HloInstruction::CreateParameter(0, shape, "cond_param"));
+ builder_cond.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
+ builder_cond.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {}), cond_param, {0, 0}, {1, 1}, {1, 1})),
+ builder_cond.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {}), cond_param, {1, 1}, {2, 2}, {1, 1}))));
+ auto cond = module->AddEmbeddedComputation(builder_cond.Build());
+
+ auto builder_body = HloComputation::Builder("body");
+ auto body_param = builder_body.AddInstruction(
+ HloInstruction::CreateParameter(0, shape, "body_param"));
+ auto body_transpose = builder_body.AddInstruction(
+ HloInstruction::CreateTranspose(shape, body_param, {0, 1}));
+
+ auto builder_f = HloComputation::Builder("fusion");
+ HloInstruction* a_f =
+ builder_f.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
+ builder_f.AddInstruction(HloInstruction::CreateTranspose(shape, a_f, {0, 1}));
+ auto comp_f = module->AddEmbeddedComputation(builder_f.Build());
+ auto body_fusion = builder_body.AddInstruction(HloInstruction::CreateFusion(
+ shape, HloInstruction::FusionKind::kCustom, {body_transpose}, comp_f));
+ auto body = module->AddEmbeddedComputation(builder_body.Build());
+
+ auto while_hlo = builder.AddInstruction(
+ HloInstruction::CreateWhile(shape, cond, body, add));
+
+ auto dot = builder.AddInstruction(HloInstruction::CreateBinary(
+ shape, HloOpcode::kDot, while_hlo, while_hlo));
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_FALSE(PropagatePrecision(module.get()));
+ EXPECT_EQ(computation->root_instruction(), dot);
+ EXPECT_FALSE(OutputsBF16(add));
+ EXPECT_FALSE(OutputsBF16(body_fusion));
+ EXPECT_FALSE(OutputsBF16(body_param));
+ EXPECT_FALSE(OutputsBF16(body_transpose));
+ EXPECT_FALSE(OutputsBF16(a_f));
+}
+
// Tests that BF16 is propagated properly through while computations with
// tuple-shaped input/output.
TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) {
@@ -553,10 +610,14 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) {
HloInstruction::CreateGetTupleElement(shape, body_param, 0));
auto body_rhs = builder_body.AddInstruction(
HloInstruction::CreateGetTupleElement(shape, body_param, 1));
- auto body_dot = builder_body.AddInstruction(
+ auto body_dot1 = builder_body.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_lhs, body_rhs));
+ auto body_dot2 = builder_body.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_rhs, body_lhs));
+ auto body_transpose = builder_body.AddInstruction(
+ HloInstruction::CreateTranspose(shape, body_dot2, {0, 1}));
builder_body.AddInstruction(
- HloInstruction::CreateTuple({body_dot, body_rhs}));
+ HloInstruction::CreateTuple({body_dot1, body_transpose}));
auto body = module->AddEmbeddedComputation(builder_body.Build());
auto while_hlo = builder.AddInstruction(
@@ -575,9 +636,11 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) {
EXPECT_EQ(computation->root_instruction(), dot);
EXPECT_TRUE(OutputsBF16(lhs));
EXPECT_FALSE(OutputsBF16(rhs));
- EXPECT_TRUE(OutputsBF16(body_dot));
+ EXPECT_TRUE(OutputsBF16(body_dot1));
EXPECT_TRUE(OutputsBF16(body_lhs));
EXPECT_FALSE(OutputsBF16(body_rhs));
+ EXPECT_FALSE(OutputsBF16(body_dot2));
+ EXPECT_FALSE(OutputsBF16(body_transpose));
EXPECT_TRUE(OutputsBF16(cond_lhs));
EXPECT_FALSE(OutputsBF16(cond_rhs));
EXPECT_TRUE(OutputsBF16(add0));
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index b4c7cf0dd8..e4d2e73b99 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -817,8 +817,7 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
}
Status BufferAssigner::AssignBuffersForComputation(
- const HloComputation* computation, const DebugOptions& debug_options,
- bool is_thread_local,
+ const HloComputation* computation, bool is_thread_local,
const FlatSet<const LogicalBuffer*>& colocated_buffers,
const FlatSet<BufferAllocation::Index>& colocated_allocations,
FlatMap<const HloComputation*, FlatSet<const LogicalBuffer*>>*
@@ -1342,11 +1341,25 @@ BufferAssigner::MergeColocatedBufferSets(
auto cannot_merge_buffer_sets = [&colocated_buffer_sets, &buffer_liveness,
&buffer_size,
&is_entry_parameter](int64 i, int64 j) {
- // Do not merge if one of the sets includes live outs or entry parameters.
+ // Do not merge if one of the sets includes live outs, entry parameters or
+ // constants.
+ //
+ // Buffer liveness does not report the correct live range for entry
+ // parameter and live out buffers so we have to special case them here. On
+ // backends that support constant buffer allocations, constant buffers are
+ // assigned globals in readonly storage so we can't merge colocated buffer
+ // sets containing constants with colocated buffer sets containing writing
+ // instructions or other constants.
+ //
+ // Moreover (on the CPU/GPU backends) the entry parameter buffers belong to
+ // the caller of the executable so we can't write to entry parameters
+ // either, and the argument for not merging constants also applies to entry
+ // parameters.
for (int64 key : {i, j}) {
for (auto& buffer : colocated_buffer_sets[key]) {
if (buffer_liveness.MaybeLiveOut(*buffer) ||
- is_entry_parameter(*buffer)) {
+ is_entry_parameter(*buffer) ||
+ buffer->instruction()->opcode() == HloOpcode::kConstant) {
return true;
}
}
@@ -1664,7 +1677,7 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
buffers_to_assign_sequentially;
for (auto* computation : global_computations) {
TF_RETURN_IF_ERROR(AssignBuffersForComputation(
- computation, module->config().debug_options(),
+ computation,
/*is_thread_local=*/false, colocated_buffers, colocated_allocations,
&buffers_to_assign_sequentially, assignment.get()));
}
@@ -1685,7 +1698,7 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
continue;
}
TF_RETURN_IF_ERROR(AssignBuffersForComputation(
- computation, module->config().debug_options(),
+ computation,
/*is_thread_local=*/true, colocated_buffers, colocated_allocations,
/*buffers_to_assign_sequentially=*/nullptr, assignment.get()));
}
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h
index 4fcf1fc73d..94495290c1 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.h
+++ b/tensorflow/compiler/xla/service/buffer_assignment.h
@@ -32,7 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
@@ -543,8 +542,7 @@ class BufferAssigner {
// true, then all assigned buffers have the is_thread_local flag set to
// true.
Status AssignBuffersForComputation(
- const HloComputation* computation, const DebugOptions& debug_options,
- bool is_thread_local,
+ const HloComputation* computation, bool is_thread_local,
const tensorflow::gtl::FlatSet<const LogicalBuffer*>& colocated_buffers,
const tensorflow::gtl::FlatSet<BufferAllocation::Index>&
colocated_allocations,
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index dea855d39a..eccb146a0d 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -1923,6 +1923,74 @@ ENTRY %test_module {
EXPECT_NE(slice_param, slice_while1);
}
+TEST_F(WhileBufferAssignmentTest, ColocatedBufferWithConstant) {
+ const Shape r0s32 = ShapeUtil::MakeShape(S32, {});
+
+ const char* module_str = R"(
+HloModule test_module
+
+%cond.v0 {
+ %param = s32[] parameter(0)
+ ROOT %constant = pred[] constant(true)
+}
+
+%cond.v1 {
+ %param.0 = s32[] parameter(0)
+ ROOT %constant.0 = pred[] constant(true)
+}
+
+%body.v0 {
+ ROOT %param.1 = s32[] parameter(0)
+}
+
+%body.v1 {
+ %param.2 = s32[] parameter(0)
+ ROOT add = s32[] add(%param.2, %param.2)
+}
+
+ENTRY %test_module {
+ %constant.42 = s32[] constant(42)
+ %while.0 = s32[] while(%constant.42), condition=%cond.v0, body=%body.v0
+ %mul = s32[] multiply(%while.0, %while.0)
+ %while.1 = s32[] while(%mul), condition=%cond.v1, body=%body.v1
+ ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={}
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+
+ // Run CopyInsertion and check if the graph constructed above doesn't need
+ // any copies inserted for BufferAssignment to run.
+ int64 instruction_count = module->instruction_count();
+ CopyInsertion copy_insertion;
+ ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
+ ASSERT_EQ(instruction_count, module->instruction_count());
+
+ // Get the instructions in the module.
+ const HloInstruction* bcast = module->entry_computation()->root_instruction();
+ const HloInstruction* constant =
+ module->entry_computation()->GetInstructionWithName("constant.42");
+ ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast);
+ const HloInstruction* while1 = bcast->operand(0);
+ ASSERT_EQ(while1->opcode(), HloOpcode::kWhile);
+ const HloInstruction* while0 = while1->operand(0)->operand(0);
+ ASSERT_EQ(while0->opcode(), HloOpcode::kWhile);
+
+ // Run buffer assignment.
+ auto assignment = RunBufferAssignment(module.get());
+ TF_ASSERT_OK_AND_ASSIGN(auto slice_constant,
+ assignment->GetUniqueSlice(constant, {}));
+ TF_ASSERT_OK_AND_ASSIGN(auto slice_while0,
+ assignment->GetUniqueSlice(while0, {}));
+ TF_ASSERT_OK_AND_ASSIGN(auto slice_while1,
+ assignment->GetUniqueSlice(while1, {}));
+
+ // The constant slice is part of the while0's colocation set (init value), but
+ // not merged into the while1's colocation set.
+ EXPECT_EQ(slice_constant, slice_while0);
+ EXPECT_NE(slice_constant, slice_while1);
+}
+
// Tests that the colocated buffers for while instructions are properly assigned
// during buffer assignment such that the result tuple elements are not assigned
// to the same buffer.
diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc
index d26486fcfe..187ce568cb 100644
--- a/tensorflow/compiler/xla/service/computation_placer.cc
+++ b/tensorflow/compiler/xla/service/computation_placer.cc
@@ -29,9 +29,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+using tensorflow::strings::StrAppend;
+using tensorflow::strings::StrCat;
+
namespace xla {
Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const {
@@ -71,6 +75,19 @@ DeviceAssignment::Deserialize(const DeviceAssignmentProto& proto) {
return std::move(assignment);
}
+string DeviceAssignment::ToString() const {
+ string output = StrCat("Computations: ", computation_count(),
+ " Replicas: ", replica_count(), "\n");
+ for (int computation = 0; computation < computation_count(); ++computation) {
+ StrAppend(&output, "Computation ", computation, ": ");
+ for (int replica = 0; replica < replica_count(); ++replica) {
+ StrAppend(&output, operator()(replica, computation), " ");
+ }
+ StrAppend(&output, "\n");
+ }
+ return output;
+}
+
StatusOr<int> ComputationPlacer::DeviceId(int replica, int computation,
int replica_count,
int computation_count) {
diff --git a/tensorflow/compiler/xla/service/computation_placer.h b/tensorflow/compiler/xla/service/computation_placer.h
index 737d00e93e..c899ffb9dc 100644
--- a/tensorflow/compiler/xla/service/computation_placer.h
+++ b/tensorflow/compiler/xla/service/computation_placer.h
@@ -55,6 +55,8 @@ class DeviceAssignment : public Array2D<int> {
// due to a StatusOr of an incomplete type (DeviceAssignment).
static StatusOr<std::unique_ptr<DeviceAssignment>> Deserialize(
const DeviceAssignmentProto& proto);
+
+ string ToString() const;
};
// A generic implementation of the XLA computation placer, which assigns device
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index bcac65ecda..504b61d134 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -252,6 +252,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service:name_uniquer",
"//tensorflow/compiler/xla/service/llvm_ir:alias_analysis",
+ "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util",
"//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util",
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
"//tensorflow/compiler/xla/service/llvm_ir:ir_array",
@@ -363,8 +364,8 @@ tf_cc_binary(
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:lib",
],
)
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 29fa29d33a..b49ea89896 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -562,7 +562,9 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
BufferAssigner::Run(
module.get(),
xla::MakeUnique<SequentialHloOrdering>(module.get(), module_sequence),
- BufferSizeBytesFunction(), memory_alignment));
+ BufferSizeBytesFunction(), memory_alignment,
+ /*allow_input_output_aliasing=*/false,
+ /*allocate_buffers_for_constants=*/true));
// BufferAssignment::ToString() includes a header, so no need for us to
// print one ourselves.
XLA_VLOG_LINES(2, assignment->ToString());
@@ -584,6 +586,8 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
std::move(computation_to_profile_idx),
&target_machine_features);
+ TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());
+
for (auto embedded_computation :
entry_computation->MakeEmbeddedComputationsList()) {
if (embedded_computation->IsFusionComputation()) {
@@ -747,7 +751,9 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
BufferAssigner::Run(
module,
xla::MakeUnique<SequentialHloOrdering>(module, module_sequence),
- BufferSizeBytesFunction(), memory_alignment));
+ BufferSizeBytesFunction(), memory_alignment,
+ /*allow_input_output_aliasing=*/false,
+ /*allocate_buffers_for_constants=*/true));
// BufferAssignment::ToString() includes a header, so no need for us to
// print one ourselves.
XLA_VLOG_LINES(2, assignment->ToString());
@@ -776,6 +782,9 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
std::move(instruction_to_profile_idx),
std::move(computation_to_profile_idx),
&target_machine_features);
+
+ TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());
+
HloComputation* computation = module->entry_computation();
for (auto embedded_computation :
computation->MakeEmbeddedComputationsList()) {
@@ -832,7 +841,8 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
BufferSizes buffer_sizes;
for (const BufferAllocation& allocation : assignment->Allocations()) {
// Callers don't need to allocate temporary buffers for parameters.
- if (allocation.is_entry_computation_parameter()) {
+ if (allocation.is_entry_computation_parameter() ||
+ allocation.is_constant()) {
buffer_sizes.push_back(-1);
continue;
}
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
index 1093559892..81e17a5cd4 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -88,6 +88,11 @@ Status CpuExecutable::AllocateBuffers(
continue;
}
+ if (allocation.is_constant()) {
+ VLOG(3) << "allocation #" << i << " is a constant";
+ continue;
+ }
+
if (allocation.is_thread_local()) {
VLOG(3) << "buffer #" << i << " is thread-local";
continue;
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 9d9d3e04a9..a6d8551841 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -51,6 +51,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
@@ -175,25 +176,36 @@ llvm::Constant* IrEmitter::EmitGlobalForLiteral(const Literal& literal) {
result_global, IrShapeType(literal.shape())->getPointerTo());
}
-Status IrEmitter::HandleConstant(HloInstruction* constant) {
- VLOG(2) << "HandleConstant: " << constant->ToString();
- const Literal& literal = constant->literal();
- llvm::Constant* global_for_const;
+Status IrEmitter::EmitConstantGlobals() {
+ for (const BufferAllocation& allocation : assignment_.Allocations()) {
+ if (!allocation.is_constant()) {
+ continue;
+ }
- auto it = emitted_literals_.find(&literal);
- if (it != emitted_literals_.end()) {
- global_for_const = it->second;
- } else {
- global_for_const = EmitGlobalForLiteral(literal);
- emitted_literals_[&literal] = global_for_const;
+ const Literal& literal = llvm_ir::LiteralForConstantAllocation(allocation);
+ llvm::Constant* global_for_const;
+ auto it = emitted_literals_.find(&literal);
+ if (it != emitted_literals_.end()) {
+ global_for_const = it->second;
+ } else {
+ global_for_const = EmitGlobalForLiteral(literal);
+ InsertOrDie(&emitted_literals_, &literal, global_for_const);
+ }
+
+ InsertOrDie(&constant_buffer_to_global_, allocation.index(),
+ global_for_const);
}
- emitted_value_[constant] = global_for_const;
- VLOG(2) << " emitted value: " << llvm_ir::DumpToString(*global_for_const);
- VLOG(2) << " its type: "
- << llvm_ir::DumpToString(*global_for_const->getType());
+
return Status::OK();
}
+Status IrEmitter::HandleConstant(HloInstruction* constant) {
+ VLOG(2) << "HandleConstant: " << constant->ToString();
+ // IrEmitter::EmitConstantGlobals has already taken care of emitting the body
+ // of the constant.
+ return EmitTargetAddressForOp(constant);
+}
+
Status IrEmitter::HandleCopy(HloInstruction* copy) {
if (ShapeUtil::IsTuple(copy->shape())) {
// kCopy shallow copies a tuple so just memcpy the top-level buffer.
@@ -2712,6 +2724,10 @@ llvm::Value* IrEmitter::EmitTempBufferPointer(
return b_.CreateBitCast(tempbuf_address, element_type->getPointerTo());
}
+ if (allocation.is_constant()) {
+ return FindOrDie(constant_buffer_to_global_, allocation.index());
+ }
+
llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP(
GetTempBuffersArgument(), slice.index(), &b_);
llvm::LoadInst* tempbuf_address_base = b_.CreateLoad(tempbuf_address_ptr);
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index cf7fa05b20..03bbb2afb5 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -105,6 +105,9 @@ class IrEmitter : public DfsHloVisitorWithDefault {
PrimitiveType return_type, HloComputation* computation,
const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name);
+ // Emit an LLVM global variable for every constant buffer allocation.
+ Status EmitConstantGlobals();
+
protected:
//
// The following methods implement the DfsHloVisitor interface.
@@ -560,6 +563,9 @@ class IrEmitter : public DfsHloVisitorWithDefault {
LiteralPtrHashFunctor, LiteralPtrEqualityFunctor>
emitted_literals_;
+ tensorflow::gtl::FlatMap<BufferAllocation::Index, llvm::Constant*>
+ constant_buffer_to_global_;
+
TF_DISALLOW_COPY_AND_ASSIGN(IrEmitter);
};
diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc
index eb83432f57..f227e4ae13 100644
--- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc
+++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD
index e6d25680b5..181cec3cdd 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD
@@ -135,9 +135,9 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/service:cpu_plugin",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
index be3fae5161..c433bddc84 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
diff --git a/tensorflow/compiler/xla/service/execution_tracker.cc b/tensorflow/compiler/xla/service/execution_tracker.cc
index 6794cfe297..228c3fac95 100644
--- a/tensorflow/compiler/xla/service/execution_tracker.cc
+++ b/tensorflow/compiler/xla/service/execution_tracker.cc
@@ -25,7 +25,7 @@ limitations under the License.
namespace xla {
AsyncExecution::AsyncExecution(Backend* backend,
- std::vector<Backend::StreamPtr> streams,
+ std::vector<StreamPool::Ptr> streams,
const ExecutionProfile& profile,
GlobalDataHandle result)
: backend_(CHECK_NOTNULL(backend)),
@@ -46,9 +46,10 @@ Status AsyncExecution::BlockUntilDone() const {
ExecutionTracker::ExecutionTracker() : next_handle_(1) {}
-ExecutionHandle ExecutionTracker::Register(
- Backend* backend, std::vector<Backend::StreamPtr> streams,
- const ExecutionProfile& profile, GlobalDataHandle result) {
+ExecutionHandle ExecutionTracker::Register(Backend* backend,
+ std::vector<StreamPool::Ptr> streams,
+ const ExecutionProfile& profile,
+ GlobalDataHandle result) {
tensorflow::mutex_lock lock(execution_mutex_);
int64 handle = next_handle_++;
auto inserted = handle_to_execution_.emplace(
diff --git a/tensorflow/compiler/xla/service/execution_tracker.h b/tensorflow/compiler/xla/service/execution_tracker.h
index 4458152dd9..4e9b9f883e 100644
--- a/tensorflow/compiler/xla/service/execution_tracker.h
+++ b/tensorflow/compiler/xla/service/execution_tracker.h
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/backend.h"
-#include "tensorflow/compiler/xla/service/pool.h"
+#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -40,7 +40,7 @@ namespace xla {
// the stream when destructed.
class AsyncExecution {
public:
- AsyncExecution(Backend* backend, std::vector<Backend::StreamPtr> streams,
+ AsyncExecution(Backend* backend, std::vector<StreamPool::Ptr> streams,
const ExecutionProfile& profile, GlobalDataHandle result);
Status BlockUntilDone() const;
@@ -54,7 +54,7 @@ class AsyncExecution {
Backend* backend_;
// Stream on which the execution is launched.
- std::vector<Backend::StreamPtr> streams_;
+ std::vector<StreamPool::Ptr> streams_;
// Profile object of the execution to be returned to the user.
ExecutionProfile profile_;
@@ -72,7 +72,7 @@ class ExecutionTracker {
// Registers an execution with its backend, streams, and data handle to the
// execution result. Returns a handle for the registered execution.
ExecutionHandle Register(Backend* backend,
- std::vector<Backend::StreamPtr> stream,
+ std::vector<StreamPool::Ptr> stream,
const ExecutionProfile& profile,
GlobalDataHandle data);
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 6f1e766d1c..e0aae3866b 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -114,11 +114,13 @@ cc_library(
srcs = ["hlo_to_ir_bindings.cc"],
hdrs = ["hlo_to_ir_bindings.h"],
deps = [
+ ":buffer_allocations",
":ir_emission_utils",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service/llvm_ir:alias_analysis",
+ "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util",
"//tensorflow/compiler/xla/service/llvm_ir:ir_array",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/compiler/xla/service/llvm_ir:tuple_ops",
@@ -142,6 +144,7 @@ cc_library(
],
deps = [
":backend_configs",
+ ":buffer_allocations",
":cudnn_convolution_runner",
":elemental_ir_emitter",
":gpu_constants",
@@ -163,6 +166,7 @@ cc_library(
"//tensorflow/compiler/xla/service:elemental_ir_emitter",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:name_uniquer",
+ "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util",
"//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util",
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
"//tensorflow/compiler/xla/service/llvm_ir:ir_array",
@@ -248,7 +252,7 @@ cc_library(
deps = [
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_execution_profile",
- "//tensorflow/compiler/xla/service:pool",
+ "//tensorflow/compiler/xla/service:stream_pool",
"//tensorflow/core:lib",
"//tensorflow/core:ptr_util",
"//tensorflow/core:stream_executor_no_cuda",
@@ -323,6 +327,7 @@ cc_library(
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/compiler/xla/service:tuple_points_to_analysis",
+ "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:stream_executor_no_cuda",
@@ -540,6 +545,38 @@ cc_library(
)
cc_library(
+ name = "pad_for_tensor_cores",
+ srcs = ["pad_for_tensor_cores.cc"],
+ hdrs = ["pad_for_tensor_cores.h"],
+ deps = [
+ ":ir_emission_utils",
+ "//tensorflow/compiler/xla:literal",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:window_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo_creation_utils",
+ "//tensorflow/compiler/xla/service:hlo_pass",
+ "//tensorflow/compiler/xla/service:shape_inference",
+ ],
+)
+
+tf_cc_test(
+ name = "pad_for_tensor_cores_test",
+ srcs = ["pad_for_tensor_cores_test.cc"],
+ deps = [
+ ":ir_emission_utils",
+ ":pad_for_tensor_cores",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/service:hlo_matchers",
+ "//tensorflow/compiler/xla/service:hlo_parser",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main", # build_cleaner: keep
+ ],
+)
+
+cc_library(
name = "gpu_transfer_manager",
srcs = ["gpu_transfer_manager.cc"],
hdrs = ["gpu_transfer_manager.h"],
@@ -583,9 +620,11 @@ cc_library(
":ir_emission_utils",
":ir_emitter",
":multi_output_fusion",
+ ":pad_for_tensor_cores",
":pad_insertion",
":partition_assignment",
":stream_assignment",
+ ":stream_executor_util",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -809,6 +848,7 @@ cc_library(
deps = [
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:stream_executor_no_cuda",
],
diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc
index b095d4cd73..537295292b 100644
--- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc
+++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc
@@ -44,12 +44,22 @@ StatusOr<std::unique_ptr<BufferAllocations>> BufferAllocations::Builder::Build(
num_buffers, device_ordinal, memory_allocator, buffer_assignment));
for (BufferAllocation::Index i = 0; i < num_buffers; ++i) {
+ const BufferAllocation& allocation = buffer_assignment->GetAllocation(i);
+ const int64 expected_alignment = [&] {
+ if (allocation.is_entry_computation_parameter()) {
+ return kEntryParameterAlignBytes;
+ } else if (allocation.is_constant()) {
+ return kConstantBufferAlignBytes;
+ } else {
+ return kXlaAllocatedBufferAlignBytes;
+ }
+ }();
+
// If buffer #i's address is already registered (e.g. external arguments or
// result buffers), use that registered buffer.
if (registered_buffers_.count(i)) {
se::DeviceMemoryBase address = FindOrDie(registered_buffers_, i);
- if (reinterpret_cast<uintptr_t>(address.opaque()) %
- kEntryParameterAlignBytes !=
+ if (reinterpret_cast<uintptr_t>(address.opaque()) % expected_alignment !=
0) {
return InternalError(
"Address of registered buffer %lld must be a multiple of %llx, but "
@@ -62,7 +72,6 @@ StatusOr<std::unique_ptr<BufferAllocations>> BufferAllocations::Builder::Build(
// Allocate each allocation that might escape, or is the temp buffer.
bool seen_temp_buffer = false;
- const BufferAllocation& allocation = buffer_assignment->GetAllocation(i);
if (allocation.maybe_live_out() || allocation.IsPreallocatedTempBuffer()) {
const int64 buffer_size = allocation.size();
se::DeviceMemoryBase buffer_address;
@@ -70,8 +79,7 @@ StatusOr<std::unique_ptr<BufferAllocations>> BufferAllocations::Builder::Build(
OwningDeviceMemory buffer;
TF_ASSIGN_OR_RETURN(
buffer, memory_allocator->Allocate(device_ordinal, buffer_size));
- if (reinterpret_cast<uintptr_t>(buffer.opaque()) %
- kXlaAllocatedBufferAlignBytes !=
+ if (reinterpret_cast<uintptr_t>(buffer.opaque()) % expected_alignment !=
0) {
return InternalError(
"Address returned by memory_allocator->Allocate must be a "
@@ -165,5 +173,10 @@ void BufferAllocations::SetBuffer(BufferAllocation::Index buffer_index,
buffers_[buffer_index] = buffer;
}
+bool ShouldEmitLiteralInLlvmIr(const Literal& literal) {
+ // LLVM can sometimes do interesting optimizations using scalar constants.
+ return ShapeUtil::IsScalar(literal.shape());
+}
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h
index 6366235025..f13eab0dd7 100644
--- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h
+++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h
@@ -107,6 +107,12 @@ class BufferAllocations {
bool torn_down_ = false;
};
+// LLVM and PTXAS don't deal well with large constants, so we only emit very
+// small constants directly in LLVM IR. Larger constants are emitted with zero
+// initializers in LLVM IR and are later overwritten when the PTX/CUBIN is
+// loaded.
+bool ShouldEmitLiteralInLlvmIr(const Literal& literal);
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_constants.cc b/tensorflow/compiler/xla/service/gpu/gpu_constants.cc
index e6ddea6d25..7f0b030fec 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_constants.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_constants.cc
@@ -30,5 +30,7 @@ const int64 kEntryParameterAlignBytes = 16;
const int64 kXlaAllocatedBufferAlignBytes =
tensorflow::Allocator::kAllocatorAlignment;
+const int64 kConstantBufferAlignBytes = kXlaAllocatedBufferAlignBytes;
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_constants.h b/tensorflow/compiler/xla/service/gpu/gpu_constants.h
index 925e6927b6..6f5f1fa09c 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_constants.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_constants.h
@@ -28,6 +28,9 @@ extern const int64 kEntryParameterAlignBytes;
// out (result) buffers.
extern const int64 kXlaAllocatedBufferAlignBytes;
+// Minimum alignment for constant buffers.
+extern const int64 kConstantBufferAlignBytes;
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
index fbc1303085..75f414e47f 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
@@ -48,80 +48,17 @@ StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) {
TF_ASSIGN_OR_RETURN(bool changed, generic_copy_insertion.Run(module));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> dataflow,
- HloDataflowAnalysis::Run(*module));
-
- // Make sure all operands of a library call are in memory instead of constants
- // in IR. Also, init values of while and conditional nodes cannot be
- // constants. Insert copies for any constants found at the operands of these
- // nodes.
- tensorflow::gtl::FlatSet<HloInstruction*> inserted_copies;
+ // Check the assumption that the epsilon and feature_index constants of the
+ // CUDNN batchnorm op are not shared with other ops where we would replace
+ // them with a copy. These custom op calls are generated with the
+ // CudnnBatchNormRewriter, so this would only happen if HloCSE merges them.
for (HloComputation* computation : module->computations()) {
for (HloInstruction* hlo : computation->instructions()) {
- // Inserts a copy of hlo->operand(n) if it's a constant.
- auto copy_operand_if_constant = [&](int64 n) -> Status {
- HloInstruction* operand = hlo->mutable_operand(n);
- // Skip the operands that have already been replaced with a copy in a
- // previous iteration (which is possible when a constant is used as an
- // operand in multiple places).
- if (ContainsKey(inserted_copies, operand)) {
- return Status::OK();
- }
- for (auto& pair : dataflow->GetInstructionValueSet(operand)) {
- const HloValueSet& value_set = pair.second;
- for (const HloValue* value : value_set.values()) {
- if (value->defining_instruction()->IsConstant() &&
- !ContainsKey(hlo_to_copy_map_, value->defining_instruction())) {
- HloInstruction* constant = value->defining_instruction();
- TF_ASSIGN_OR_RETURN(HloInstruction * copy,
- FindOrInsertCopy(constant));
- TF_RETURN_IF_ERROR(constant->ReplaceAllUsesWith(copy));
- inserted_copies.insert(copy);
- changed = true;
- }
- }
- }
- return Status::OK();
- };
-
- if (IsCustomCallToDnnBatchNorm(*hlo)) {
- // The epsilon and feature_index operands to a CUDNN batchnorm op don't
- // need to be materialized in memory -- in fact, they must be constants.
- // These are the last two operands of all three batchnorm ops.
- for (int64 i = 0; i < hlo->operand_count() - 2; ++i) {
- TF_RETURN_IF_ERROR(copy_operand_if_constant(i));
- }
- } else if (ImplementedAsLibraryCall(*hlo) ||
- hlo->opcode() == HloOpcode::kCrossReplicaSum ||
- hlo->opcode() == HloOpcode::kWhile ||
- hlo->opcode() == HloOpcode::kConditional) {
- // For all other library calls, cross-replica-sum, while and conditional
- // ops materialize all the operands into memory. (Cross-replica-sum
- // gets its constant args materialized even if it's not implemented as a
- // libcall to simplify the implementation. It's slower, but we can
- // constant fold away constant args *anyway*, so we just need to make it
- // work.)
- for (int64 i = 0; i < hlo->operand_count(); ++i) {
- TF_RETURN_IF_ERROR(copy_operand_if_constant(i));
- }
+ if (!IsCustomCallToDnnBatchNorm(*hlo)) {
+ continue;
}
- }
- }
-
- if (changed) {
- // Check the assumption that the epsilon and feature_index constants of the
- // CUDNN batchnorm op are not shared with other ops where we would replace
- // them with a copy. These custom op calls are generated with the
- // CudnnBatchNormRewriter, so this would only happen if HloCSE merges them.
- for (HloComputation* computation : module->computations()) {
- for (HloInstruction* hlo : computation->instructions()) {
- if (!IsCustomCallToDnnBatchNorm(*hlo)) {
- continue;
- }
- for (int64 i = hlo->operand_count() - 2; i < hlo->operand_count();
- ++i) {
- CHECK_EQ(hlo->operand(i)->opcode(), HloOpcode::kConstant);
- }
+ for (int64 i = hlo->operand_count() - 2; i < hlo->operand_count(); ++i) {
+ CHECK_EQ(hlo->operand(i)->opcode(), HloOpcode::kConstant);
}
}
}
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index 0cad2958c7..bb71c79fd7 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
@@ -84,7 +85,7 @@ Status GpuExecutable::ExecuteThunks(
}
// Stream 0 indicates `main_stream` and substreams start from stream 1.
- std::vector<Pool<se::Stream>::SmartPtr> sub_streams;
+ std::vector<StreamPool::Ptr> sub_streams;
sub_streams.reserve(thunk_schedule_->StreamCount() - 1);
while (sub_streams.size() + 1 < thunk_schedule_->StreamCount()) {
sub_streams.emplace_back();
@@ -181,6 +182,55 @@ Status GpuExecutable::ExecuteThunks(
return Status::OK();
}
+StatusOr<const GpuExecutable::BufferAllocToDeviceMemoryMap*>
+GpuExecutable::ResolveConstantGlobals(se::StreamExecutor* executor) {
+ tensorflow::mutex_lock lock(module_handle_mutex_);
+ auto it = module_globals_.find(executor);
+ if (it != module_globals_.end()) {
+ return &it->second;
+ }
+
+ se::MultiModuleLoaderSpec module_spec;
+ if (!cubin().empty()) {
+ module_spec.AddCudaCubinInMemory(cubin());
+ }
+ module_spec.AddCudaPtxInMemory(ptx().c_str());
+
+ tensorflow::gtl::FlatMap<int64, se::DeviceMemoryBase> globals;
+ se::ModuleHandle module_handle;
+ executor->LoadModule(module_spec, &module_handle);
+
+ for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size();
+ ++i) {
+ const BufferAllocation& allocation = assignment_->GetAllocation(i);
+ if (allocation.is_constant()) {
+ TF_ASSIGN_OR_RETURN(
+ se::DeviceMemoryBase global,
+ executor->GetUntypedSymbol(
+ llvm_ir::ConstantBufferAllocationToGlobalName(allocation),
+ module_handle));
+ VLOG(3) << "Resolved global "
+ << llvm_ir::ConstantBufferAllocationToGlobalName(allocation)
+ << " to " << global.opaque();
+ InsertOrDie(&globals, i, global);
+
+ const Literal& literal =
+ llvm_ir::LiteralForConstantAllocation(allocation);
+ CHECK(ShapeUtil::IsArray(literal.shape()));
+ if (!ShouldEmitLiteralInLlvmIr(literal)) {
+ VLOG(3) << "H2D memcpy for constant with shape "
+ << ShapeUtil::HumanString(literal.shape());
+ TF_RETURN_IF_ERROR(executor->SynchronousMemcpyH2D(
+ literal.untyped_data(), allocation.size(), &global));
+ }
+ }
+ }
+
+ module_handles_.emplace(executor,
+ se::ScopedModuleHandle(executor, module_handle));
+ return &module_globals_.emplace(executor, std::move(globals)).first->second;
+}
+
StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
@@ -192,6 +242,10 @@ StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteOnStream(
}
BufferAllocations::Builder buffer_allocations_builder;
+ se::StreamExecutor* executor = run_options->stream()->parent();
+
+ TF_ASSIGN_OR_RETURN(auto* const globals, ResolveConstantGlobals(executor));
+
for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size();
++i) {
const BufferAllocation& allocation = assignment_->GetAllocation(i);
@@ -213,8 +267,12 @@ StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteOnStream(
buffer_allocations_builder.RegisterBuffer(i, buffer);
}
+
+ if (allocation.is_constant()) {
+ buffer_allocations_builder.RegisterBuffer(i, FindOrDie(*globals, i));
+ }
}
- se::StreamExecutor* executor = run_options->stream()->parent();
+
TF_ASSIGN_OR_RETURN(
auto buffer_allocations,
buffer_allocations_builder.Build(
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
index 80ec38c3ac..c7ce6d0acb 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
@@ -34,6 +34,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -66,7 +68,7 @@ class GpuExecutable : public Executable {
}
// Returns the compiled PTX for the computation.
- tensorflow::StringPiece ptx() const { return ptx_; }
+ const string& ptx() const { return ptx_; }
// Returns the cubin (compiled PTX) stored in this GpuExecutable. May be
// empty, in which case compilation is left up to the GPU driver.
@@ -98,6 +100,15 @@ class GpuExecutable : public Executable {
// computation. Uses points-to analysis from buffer assignment.
const PointsToSet& GetRootPointsToSet() const;
+ using BufferAllocToDeviceMemoryMap =
+ tensorflow::gtl::FlatMap<BufferAllocation::Index, se::DeviceMemoryBase>;
+
+ // Loads the PTX or CUBIN for this executable into `executor` and resolves the
+ // globals corresponding to constant buffers. Returns a map mapping buffer
+ // allocation indices to GPU pointers.
+ StatusOr<const BufferAllocToDeviceMemoryMap*> ResolveConstantGlobals(
+ stream_executor::StreamExecutor* executor);
+
// The LLVM IR, in string format, of the unoptimized module generated for this
// GpuExecutable. We save a string instead of an llvm::Module* because leaving
// llvm::Module* in a singleton can cause the heap checker to emit false
@@ -126,6 +137,14 @@ class GpuExecutable : public Executable {
// memory for every output/temp buffers.
const std::unique_ptr<const BufferAssignment> assignment_;
+ // Cache of module handles and constant buffer allocation maps used by
+ // `ResolveConstantGlobals`.
+ tensorflow::mutex module_handle_mutex_;
+ std::map<stream_executor::StreamExecutor*, se::ScopedModuleHandle>
+ module_handles_ GUARDED_BY(module_handle_mutex_);
+ std::map<stream_executor::StreamExecutor*, BufferAllocToDeviceMemoryMap>
+ module_globals_ GUARDED_BY(module_handle_mutex_);
+
TF_DISALLOW_COPY_AND_ASSIGN(GpuExecutable);
};
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
index 09ef62c87f..6ac5dfbcd5 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
@@ -31,20 +31,13 @@ limitations under the License.
namespace xla {
namespace gpu {
-using stream_executor::dnn::DataLayout;
-using stream_executor::dnn::FilterLayout;
-
-static bool IsVoltaOrLater(const se::StreamExecutor& stream_executor) {
- int major, minor;
- CHECK(stream_executor.GetDeviceDescription().cuda_compute_capability(&major,
- &minor));
- return major >= 7;
-}
+using se::dnn::DataLayout;
+using se::dnn::FilterLayout;
// Returns (input, filter, output) layouts.
static std::tuple<DataLayout, FilterLayout, DataLayout>
HeuristicLayoutAssignment(const HloInstruction* instr,
- stream_executor::StreamExecutor* stream_executor) {
+ se::StreamExecutor* stream_executor) {
// DataLayout and FilterLayout uses weird enum names. Translations:
// N <=> Batch or Output
// C <=> Depth or Input
@@ -52,31 +45,44 @@ HeuristicLayoutAssignment(const HloInstruction* instr,
// W <=> X
//
// Therefore kOutputInputYX and kBatchDepthYX mean NCHW.
+ //
+ // If you have trouble keeping these straight, consider that all that matters
+ // is the location of the channel dim: Is it major (NCHW), or minor (NHWC)?
+
+ constexpr auto kAllNCHW =
+ std::make_tuple(DataLayout::kBatchDepthYX, FilterLayout::kOutputInputYX,
+ DataLayout::kBatchDepthYX);
+ constexpr auto kAllNHWC =
+ std::make_tuple(DataLayout::kBatchYXDepth, FilterLayout::kOutputYXInput,
+ DataLayout::kBatchYXDepth);
- // As of today, our empirical evidence is that cudnn 7.0 is faster on V100 x
- // fp16 with the mostly-NHWC layout. The heuristic may change as cudnn version
- // changes, as well as the hardware updates.
+ // If we're not Volta or not fp16, the decision is easy: Use NCHW.
if (!(instr->operand(0)->shape().element_type() == xla::PrimitiveType::F16 &&
IsVoltaOrLater(*stream_executor))) {
- return std::make_tuple(DataLayout::kBatchDepthYX,
- FilterLayout::kOutputInputYX,
- DataLayout::kBatchDepthYX);
+ return kAllNCHW;
}
+
VLOG(2) << "Using heuristic to figure out layouts for " << instr->ToString();
- // For BackwardInput that has stride, full NHWC layouts run significantly
- // slower than (NHWC, NCHW, NCHW) or (NHWC, NCHW, NHWC).
+
+ // Empirically we've found with Volta and cudnn 7 that backward-input convs
+ // with stride are significantly faster with NCHW layouts.
//
- // TODO(timshen): more closely compare (NHWC, NCHW, NCHW) and (NHWC, NCHW,
- // NHWC).
+ // We could have used a mixed layout combination, e.g. (NHWC, NCHW, NCHW),
+ // which on paper gives good performance. However, there are two observations:
+ // * a mixed layout combination is more cuDNN-bug prone, based on empirical
+ // envidence.
+ // * we've also observed that for mixed layouts, cuDNN transposes data back
+ // and forth from a different layout combination. If we end up with
+ // transposes anyway, we prefer to have them in XLA, as they can be fused.
+ // TODO(timshen): Figure out the exact condition. This may be achieved by
+ // auto-tuning layouts offline.
if (instr->custom_call_target() == kCudnnConvBackwardInputCallTarget &&
window_util::HasStride(instr->window())) {
- return std::make_tuple(DataLayout::kBatchYXDepth,
- FilterLayout::kOutputInputYX,
- DataLayout::kBatchDepthYX);
+ return kAllNCHW;
}
- return std::make_tuple(DataLayout::kBatchYXDepth,
- FilterLayout::kOutputYXInput,
- DataLayout::kBatchYXDepth);
+
+ // For other Volta f16 convolutions, use NHWC.
+ return kAllNHWC;
}
// Adds layout constraints on the cudnn custom-call instruction. The layout
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc
index 19420e590d..1722676930 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-#include "tensorflow/compiler/xla/service/pool.h"
+#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/util/ptr_util.h"
@@ -37,10 +37,9 @@ void InitAndStartTimer(std::stack<std::unique_ptr<se::Timer>>* timers,
stream->InitTimer(timers->top().get()).ThenStartTimer(timers->top().get());
}
-uint64 GetCyclesTaken(
- std::stack<std::unique_ptr<se::Timer>>* timers,
- const std::vector<Pool<se::Stream>::SmartPtr>& sub_streams,
- se::Stream* stream, double clock_rate_ghz) {
+uint64 GetCyclesTaken(std::stack<std::unique_ptr<se::Timer>>* timers,
+ const std::vector<StreamPool::Ptr>& sub_streams,
+ se::Stream* stream, double clock_rate_ghz) {
CHECK_GT(timers->size(), 0);
stream->ThenWaitFor(&sub_streams);
stream->ThenStopTimer(timers->top().get());
@@ -53,7 +52,7 @@ uint64 GetCyclesTaken(
HloExecutionProfiler::HloExecutionProfiler(
bool do_profile, HloExecutionProfile* profile, se::Stream* stream,
- const std::vector<Pool<se::Stream>::SmartPtr>& sub_streams,
+ const std::vector<StreamPool::Ptr>& sub_streams,
const HloComputation* computation)
: do_profile_(do_profile),
profile_(profile),
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h
index 6654850bef..80cde75f2b 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h
+++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h
@@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-#include "tensorflow/compiler/xla/service/pool.h"
+#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
@@ -38,10 +38,10 @@ class ScopedInstructionProfiler;
class HloExecutionProfiler {
public:
// If profiling is enabled, start an execution timer running.
- explicit HloExecutionProfiler(
- bool do_profile, HloExecutionProfile* profile, se::Stream* stream,
- const std::vector<Pool<se::Stream>::SmartPtr>& sub_streams,
- const HloComputation* computation);
+ explicit HloExecutionProfiler(bool do_profile, HloExecutionProfile* profile,
+ se::Stream* stream,
+ const std::vector<StreamPool::Ptr>& sub_streams,
+ const HloComputation* computation);
// If profiling is enabled, sets the total cycle count on the profile from the
// execution timer.
@@ -72,7 +72,7 @@ class HloExecutionProfiler {
double clock_rate_ghz_;
HloExecutionProfile* profile_;
se::Stream* stream_;
- const std::vector<Pool<se::Stream>::SmartPtr>& sub_streams_;
+ const std::vector<StreamPool::Ptr>& sub_streams_;
const HloComputation* computation_;
std::stack<std::unique_ptr<se::Timer>> timers_;
// Contains the HLO instructions for which we are currently measuring the
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
index 1b6315ec03..8c11cd0541 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
+++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
@@ -18,8 +18,10 @@ limitations under the License.
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
+#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -110,6 +112,12 @@ void HloToIrBindings::EmitBasePointersForHlos(
llvm_ir::ShapeToIrType(non_io_hlo->shape(), module_);
BindHloToIrValue(*non_io_hlo, b_->CreateAlloca(pointee_type),
index);
+ } else if (slice.allocation()->is_constant()) {
+ llvm::Value* global_for_constant =
+ module_->getGlobalVariable(llvm_ir::AsStringRef(
+ llvm_ir::ConstantBufferAllocationToGlobalName(
+ *slice.allocation())));
+ BindHloToIrValue(*non_io_hlo, global_for_constant);
} else {
const int64 offset = slice.offset();
CHECK_NE(nullptr, temp_buffer_base_);
@@ -135,6 +143,14 @@ llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte,
EmitGetTupleElement(gte->operand(0), base_ptr), b_, module_);
}
+// Returns true if `value` has a name that should not be changed.
+static bool HasMeaningfulName(llvm::Value* value) {
+ if (auto* global = llvm::dyn_cast<llvm::GlobalValue>(value)) {
+ return global->getLinkage() != llvm::GlobalValue::PrivateLinkage;
+ }
+ return false;
+}
+
llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo,
ShapeIndexView shape_index,
llvm::Value* ir_value) {
@@ -149,8 +165,13 @@ llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo,
} else {
typed_ir_value = b_->CreateBitCast(ir_value, pointee_type->getPointerTo());
}
- ir_value->setName(llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "raw")));
- typed_ir_value->setName(llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "typed")));
+ if (!HasMeaningfulName(ir_value)) {
+ ir_value->setName(llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "raw")));
+ }
+ if (!HasMeaningfulName(typed_ir_value)) {
+ typed_ir_value->setName(
+ llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "typed")));
+ }
return typed_ir_value;
}
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
index af6259ae83..0f2c83aeb2 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
@@ -202,6 +202,7 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
IsIEEEFloatingPointScalarConstant(producer->operand(0)) &&
fused_parameter_users[0]->opcode() == HloOpcode::kMultiply;
}
+ return false;
}
// Other output fusions are not currently supported on GPUs.
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index 973848c336..1295e83c0c 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -81,19 +81,6 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) {
}
Status IrEmitter::HandleConstant(HloInstruction* constant) {
- const Literal& literal = constant->literal();
- llvm::Constant* initializer =
- llvm_ir::ConvertLiteralToIrConstant(literal, module_);
- llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable(
- *module_, initializer->getType(),
- /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, initializer,
- /*Name=*/"");
- VLOG(2) << "HandleConstant: " << constant->ToString() << std::endl
- << " emitted_value: " << llvm_ir::DumpToString(*global_for_const)
- << std::endl
- << " its type: "
- << llvm_ir::DumpToString(*global_for_const->getType());
- bindings_.BindHloToIrValue(*constant, global_for_const);
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index db6a4e6f30..874c7cfb8a 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
+#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
@@ -59,6 +60,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
@@ -75,6 +77,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -230,11 +233,20 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
++arg_it;
kernel->addDereferenceableAttr(arg_no + 1, alloc->size());
+
+ const int64 alignment = [&] {
+ if (alloc->is_entry_computation_parameter()) {
+ return kEntryParameterAlignBytes;
+ } else if (alloc->is_constant()) {
+ return kConstantBufferAlignBytes;
+ } else {
+ return kXlaAllocatedBufferAlignBytes;
+ }
+ }();
+
kernel->addParamAttr(
- arg_no, llvm::Attribute::get(context, llvm::Attribute::Alignment,
- alloc->is_entry_computation_parameter()
- ? kEntryParameterAlignBytes
- : kXlaAllocatedBufferAlignBytes));
+ arg_no,
+ llvm::Attribute::get(context, llvm::Attribute::Alignment, alignment));
if (alloc->IsPreallocatedTempBuffer()) {
fn_arg->setName("temp_buf");
@@ -1762,6 +1774,8 @@ Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) {
.GetUniqueTopLevelSlice(tuple_element)
.ok();
});
+ // TODO(b/111689850): This logic isn't quite correct.
+ //
// Tuples (especially tuples that are the final result of a computation) can
// be so huge that if we were to emit a kernel that took each tuple element as
// a parameter, we would exceed the max allowable number of parameters to a
@@ -1769,9 +1783,9 @@ Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) {
// buffer, we collect their buffer addresses in a host array, and then copy
// that array to the tuple's buffer.
//
- // Some tuple elements (e.g. const or bitcast of const) might not have a
- // buffer -- their contents are stored in code. In that case, we fall back to
- // emitting kernels which have access to their buffer addresses in code.
+ // Some tuple elements might not have an unambiguous buffer (like the result
+ // of a select-tuple). In that case, we fall back to emitting kernels which
+ // have access to their buffer addresses in code.
if (all_tuple_elements_have_buffer) {
std::vector<BufferAllocation::Slice> tuple_element_buffers;
for (const HloInstruction* tuple_element : tuple->operands()) {
@@ -2054,28 +2068,34 @@ Status IrEmitterUnnested::HandleSelect(HloInstruction* select) {
Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
std::vector<std::unique_ptr<Thunk>> thunks;
+ auto keys = sort->operand(0);
auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr;
+ ShapeIndex keys_shape_index({});
+ ShapeIndex values_shape_index({});
if (values != nullptr) {
- // TODO(b/26783907): Also sort the values by their corresponding key.
- return Unimplemented("Key/Value Sort is not implemented on GPU");
+ keys_shape_index = ShapeIndex({0});
+ values_shape_index = ShapeIndex({1});
}
+ auto keys_destination = GetAllocationSlice(*sort, keys_shape_index);
+ auto values_destination = GetAllocationSlice(*sort, values_shape_index);
- // First copy the operand to the output, so that we can sort in-place.
- // TODO(b/26783907): Share buffer of output and operand when it is possible.
- if (sort->operand(0)->IsConstant()) {
- thunks.push_back(MakeUnique<HostToDeviceCopyThunk>(
- /*source_address=*/sort->operand(0)->literal().untyped_data(),
- /*destination_buffer=*/GetAllocationSlice(*sort),
- /*mem_size=*/ShapeUtil::ByteSizeOf(sort->shape()), sort));
- } else {
+ if (keys_destination != GetAllocationSlice(*keys)) {
+ thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>(
+ /*source_address=*/GetAllocationSlice(*keys),
+ /*destination_buffer=*/keys_destination,
+ /*mem_size=*/ShapeUtil::ByteSizeOf(keys->shape()), nullptr));
+ }
+ if (values != nullptr && values_destination != GetAllocationSlice(*values)) {
+ // TODO(b/26783907): Figure out why we never seem to share buffers for
+ // key/value sort.
thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>(
- /*source_address=*/GetAllocationSlice(*sort->operand(0)),
- /*destination_buffer=*/GetAllocationSlice(*sort),
- /*mem_size=*/ShapeUtil::ByteSizeOf(sort->shape()), sort));
+ /*source_address=*/GetAllocationSlice(*values),
+ /*destination_buffer=*/values_destination,
+ /*mem_size=*/ShapeUtil::ByteSizeOf(values->shape()), nullptr));
}
int64 dimension_to_sort = sort->dimensions(0);
- int64 dimension_to_sort_bound = sort->shape().dimensions(dimension_to_sort);
+ int64 dimension_to_sort_bound = keys->shape().dimensions(dimension_to_sort);
int64 num_stages = tensorflow::Log2Ceiling(dimension_to_sort_bound);
auto index_type = b_.getInt64Ty();
@@ -2099,7 +2119,7 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
thunks.push_back(
BuildKernelThunk(sort, /*implements_whole_instruction=*/false));
LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
- sort->shape(), ir_emitter_context_->device_description());
+ keys->shape(), ir_emitter_context_->device_description());
UpdateLaunchDimensions(launch_dimensions, thunks.back().get(),
ir_emitter_context_->llvm_module());
@@ -2111,8 +2131,11 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
}
TF_RETURN_IF_ERROR(llvm_ir::EmitSortInPlace(
- dimension_to_sort, GetIrArray(*sort, *sort), IrName(sort), xor_mask,
- &b_, &launch_dimensions));
+ dimension_to_sort, GetIrArray(*sort, *sort, keys_shape_index),
+ values != nullptr ? tensorflow::gtl::make_optional<IrArray>(
+ GetIrArray(*sort, *sort, values_shape_index))
+ : tensorflow::gtl::nullopt,
+ IrName(sort), xor_mask, &b_, &launch_dimensions));
}
}
@@ -2274,11 +2297,6 @@ GetHloBufferSlices(const HloInstruction* hlo,
// Adds entries for all subshapes of instr to `slices`.
auto add_slices_for = [&](const HloInstruction* instr) {
- // GPU constants don't have buffers; don't bother looking for one.
- if (instr->IsConstant()) {
- return;
- }
-
ShapeUtil::ForEachSubshape(
instr->shape(), [&](const Shape& /*shape*/, const ShapeIndex& index) {
if (slices.count({instr, index})) {
@@ -2340,21 +2358,25 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
// We'll pass a pointer to each of the elements of `buffers` to our kernel, in
// this order.
- std::vector<const BufferAllocation*> buffers(buffers_needed.begin(),
- buffers_needed.end());
- std::sort(buffers.begin(), buffers.end(),
+ std::vector<const BufferAllocation*> non_constant_buffers;
+ c_copy_if(buffers_needed, std::back_inserter(non_constant_buffers),
+ [](const BufferAllocation* allocation) {
+ return !allocation->is_constant();
+ });
+
+ std::sort(non_constant_buffers.begin(), non_constant_buffers.end(),
[](const BufferAllocation* a, const BufferAllocation* b) {
return a->index() < b->index();
});
- llvm::Function* kernel = BuildKernelPrototype(*inst, buffers);
+ llvm::Function* kernel = BuildKernelPrototype(*inst, non_constant_buffers);
// Build a map from a BufferAllocation to the corresponding argument in our
// kernel.
std::unordered_map<const BufferAllocation*, llvm::Value*> kernel_args;
{
auto arg_it = kernel->arg_begin();
- auto buffers_it = buffers.begin();
+ auto buffers_it = non_constant_buffers.begin();
for (; arg_it != kernel->arg_end(); ++arg_it, ++buffers_it) {
kernel_args[*buffers_it] = arg_it;
}
@@ -2372,8 +2394,16 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
<< " is found in slice " << slice.ToString() << " at GTE index "
<< gte_index.ToString();
- llvm::Value* loc = b_.CreateInBoundsGEP(kernel_args.at(slice.allocation()),
- {b_.getInt64(slice.offset())});
+ llvm::Value* loc;
+ if (slice.allocation()->is_constant()) {
+ loc = ir_emitter_context_->llvm_module()->getGlobalVariable(
+ llvm_ir::AsStringRef(llvm_ir::ConstantBufferAllocationToGlobalName(
+ *slice.allocation())));
+ CHECK_NE(loc, nullptr);
+ } else {
+ loc = b_.CreateInBoundsGEP(kernel_args.at(slice.allocation()),
+ {b_.getInt64(slice.offset())});
+ }
// If gte_index is nonempty, we have to dereference `loc` to get to the
// value we're ultimately interested in.
@@ -2396,9 +2426,9 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
llvm::ConstantPointerNull::get(b_.getInt8PtrTy()));
}
- return MakeUnique<KernelThunk>(buffers, llvm_ir::AsString(kernel->getName()),
- implements_whole_instruction ? inst : nullptr,
- unroll_factor);
+ return MakeUnique<KernelThunk>(
+ non_constant_buffers, llvm_ir::AsString(kernel->getName()),
+ implements_whole_instruction ? inst : nullptr, unroll_factor);
}
std::unique_ptr<Thunk> IrEmitterUnnested::BuildHostToDeviceCopyThunk(
@@ -2635,7 +2665,17 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
// If the init_value was fused into this reduce we have to generate it first.
if (fused && init_value_operand->opcode() != HloOpcode::kParameter) {
CHECK_EQ(HloOpcode::kConstant, init_value_operand->opcode());
- TF_RETURN_IF_ERROR(HandleConstant(const_cast<HloInstruction*>(init_value)));
+
+ const Literal& literal = init_value_operand->literal();
+ llvm::Constant* initializer =
+ llvm_ir::ConvertLiteralToIrConstant(literal, module_);
+
+ llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable(
+ *module_, initializer->getType(),
+ /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, initializer,
+ /*Name=*/"");
+ global_for_const->setAlignment(kConstantBufferAlignBytes);
+ bindings_.BindHloToIrValue(*init_value_operand, global_for_const);
}
TF_RETURN_IF_ERROR(ParallelLoopEmitter(
[=](const IrArray::Index& index) {
@@ -3367,5 +3407,47 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) {
return true;
}
+Status IrEmitterUnnested::EmitConstantGlobals() {
+ for (const BufferAllocation& allocation :
+ ir_emitter_context_->buffer_assignment().Allocations()) {
+ if (!allocation.is_constant()) {
+ continue;
+ }
+
+ const Literal& literal = llvm_ir::LiteralForConstantAllocation(allocation);
+ const bool should_emit_initializer = ShouldEmitLiteralInLlvmIr(literal);
+ llvm::ArrayType* global_type =
+ llvm::ArrayType::get(b_.getInt8Ty(), allocation.size());
+ llvm::Constant* initializer =
+ should_emit_initializer
+ ? llvm_ir::ConvertLiteralToIrConstant(literal, module_)
+ : llvm::ConstantAggregateZero::get(global_type);
+ if (should_emit_initializer) {
+ VLOG(3) << "Emitted initializer for constant with shape "
+ << ShapeUtil::HumanString(literal.shape());
+ }
+
+ // These globals will be looked up by name by GpuExecutable so we need to
+ // give them an external linkage. Not all of their uses are visible in the
+ // LLVM IR (e.g. TupleThunk) so we can't give then a linkage that merely
+ // preserves their names (like available_externally), we also need to ensure
+ // that they stick around even if they're "unused".
+ //
+ // We may have to be more more clever here in the future if we notice that
+ // we're keeping around too many globals because of their linkage.
+ llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable(
+ global_type, /*isConstant=*/should_emit_initializer,
+ llvm::GlobalValue::ExternalLinkage,
+ /*Initializer=*/initializer,
+ llvm_ir::AsStringRef(
+ llvm_ir::ConstantBufferAllocationToGlobalName(allocation)));
+ global_for_const->setAlignment(kConstantBufferAlignBytes);
+ ir_emitter_context_->llvm_module()->getGlobalList().push_back(
+ global_for_const);
+ }
+
+ return Status::OK();
+}
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index 616d8a2206..5254419907 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -92,6 +92,9 @@ class IrEmitterUnnested : public IrEmitter {
const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter,
KernelThunk* thunk);
+ // Emits LLVM global variables corresponding to constant instructions.
+ Status EmitConstantGlobals();
+
private:
// Builds the appropriate thunk for the instruction hlo and returns the owning
// pointer to it. The caller needs to make sure `inst` outlives the lifetime
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
index 6fef720853..c67dcbce77 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
@@ -113,7 +113,7 @@ bool IsInputFusibleReduction(HloInstruction* instr) {
// of input parameters differ. In such situtations it is beneficial not to fuse.
// We consider input params with maximum rank only. Params with smaller ranks
// will be broadcasted and have not been observed to cause data locality issues.
-// TODO(b/110927656): Improve reduce emitters to remove this limitation.
+// TODO(b/111977086): Improve reduce emitters to remove this limitation.
bool ReduceFriendlyInputLayouts(HloInstruction* instr) {
int64 max_rank = 0;
const Layout* max_rank_layout;
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index 2eefadebcd..7a683ede54 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -52,9 +52,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h"
#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
+#include "tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h"
#include "tensorflow/compiler/xla/service/gpu/pad_insertion.h"
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
+#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -199,6 +201,12 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
pipeline.AddInvariantChecker<HloVerifier>();
pipeline.AddPass<CudnnConvolutionRewriter>();
pipeline.AddPass<PadInsertion>();
+ if (IsVoltaOrLater(*stream_exec)) {
+ pipeline.AddPass<PadForTensorCores>();
+ // PadForTensorCores leaves behind unnecessary tuple/get-tuple-element
+ // pairs that TupleSimplifier fixes.
+ pipeline.AddPass<TupleSimplifier>();
+ }
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
}
@@ -540,11 +548,13 @@ StatusOr<std::unique_ptr<Executable>> NVPTXCompiler::RunBackend(
// temporary buffers are required to run the computation.
TF_ASSIGN_OR_RETURN(
std::unique_ptr<BufferAssignment> buffer_assignment,
- BufferAssigner::Run(module.get(), hlo_schedule->ConsumeHloOrdering(),
- BufferSizeBytesFunction(),
- /*color_alignment=*/[](LogicalBuffer::Color) {
- return kXlaAllocatedBufferAlignBytes;
- }));
+ BufferAssigner::Run(
+ module.get(), hlo_schedule->ConsumeHloOrdering(),
+ BufferSizeBytesFunction(),
+ /*color_alignment=*/
+ [](LogicalBuffer::Color) { return kXlaAllocatedBufferAlignBytes; },
+ /*allow_input_output_aliasing=*/false,
+ /*allocate_buffers_for_constants=*/true));
// BufferAssignment::Stats::ToString() and BufferAssignment::ToString()
// include headers, so no need for us to print them ourselves.
XLA_VLOG_LINES(1, buffer_assignment->GetStats().ToString());
@@ -565,6 +575,9 @@ StatusOr<std::unique_ptr<Executable>> NVPTXCompiler::RunBackend(
HloComputation* entry_computation = module->entry_computation();
IrEmitterUnnested ir_emitter(module->config(), entry_computation,
&ir_emitter_context);
+
+ TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());
+
{
XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunBackend - IR emission");
TF_RETURN_IF_ERROR(entry_computation->Accept(&ir_emitter));
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
new file mode 100644
index 0000000000..79f7d31816
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
@@ -0,0 +1,233 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h"
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/window_util.h"
+
+namespace xla {
+namespace gpu {
+
+using tensorflow::gtl::ArraySlice;
+
+// We want the input/output feature counts of an f16 conv to be factors of 8,
+// because without this cudnn can't use tensor cores on the conv.
+static constexpr int64 kDesiredNumFeaturesFactor = 8;
+
+// We won't pad a conv if doing so increases the total number of bytes in the
+// lhs, rhs, or result by more than this amount.
+//
+// TODO(jlebar): This number was tuned experimentally. It represents a
+// compromise on our current benchmarks; it speeds some up significantly, and
+// doesn't slow any down. But we can observe by changing this value that
+// there's additional room for speedups. Achieving those speedups without also
+// slowing other things down will likely require a more sophisticated heuristic,
+// possibly some form of auto-tuning.
+static constexpr double kMaxBytesTouchedIncrease = 1.2;
+
+// Pads the given dimensions in the given shape up to a multiple of
+// kDesiredNumFeaturesFactor.
+static Shape PadShape(Shape s, ArraySlice<int64> dims) {
+ for (int64 dim : dims) {
+ int64 dim_to_pad_size = s.dimensions(dim);
+ int64 new_dim_to_pad_size =
+ RoundUpToNearest(dim_to_pad_size, kDesiredNumFeaturesFactor);
+ s.set_dimensions(dim, new_dim_to_pad_size);
+ }
+ return s;
+}
+
+// Creates and returns an HLO that zero-pads one or more dimensions in the given
+// instruction so that its shape is equal to the given shape.
+//
+// Padding is added to the end of each relevant dimension.
+//
+// If the instruction already has the given shape, simply returns it without an
+// intervening pad.
+static HloInstruction* PadInstruction(HloInstruction* instr,
+ const Shape& new_shape) {
+ HloComputation* comp = instr->parent();
+
+ const Shape& shape = instr->shape();
+ auto* zero = comp->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(shape.element_type()).CloneToUnique()));
+
+ PaddingConfig pad_config = MakeNoPaddingConfig(ShapeUtil::Rank(shape));
+
+ bool added_padding = false;
+ for (int64 dim = 0; dim < ShapeUtil::Rank(shape); ++dim) {
+ if (shape.dimensions(dim) == new_shape.dimensions(dim)) {
+ continue;
+ }
+ CHECK_GT(new_shape.dimensions(dim), shape.dimensions(dim));
+ pad_config.mutable_dimensions(dim)->set_edge_padding_high(
+ new_shape.dimensions(dim) - shape.dimensions(dim));
+ added_padding = true;
+ }
+
+ if (!added_padding) {
+ return instr;
+ }
+ return comp->AddInstruction(
+ HloInstruction::CreatePad(new_shape, instr, zero, pad_config));
+}
+
+// Pads the input/output feature dimensions of the given cudnn convolution
+// custom-call to be multiples of kDesiredNumFeaturesFactor.
+static StatusOr<bool> PadFeaturesDims(HloInstruction* conv) {
+ CHECK_EQ(0, conv->shape().tuple_shapes(1).dimensions(0))
+ << "conv must use 0 scratch bytes, i.e. this pass must be run "
+ "before CudnnConvolutionAlgorithmPicker.";
+
+ const auto& target = conv->custom_call_target();
+ const auto& dnums = conv->convolution_dimension_numbers();
+ auto* lhs = conv->mutable_operand(0);
+ auto* rhs = conv->mutable_operand(1);
+ const Shape& result_shape = conv->shape().tuple_shapes(0);
+
+ Shape new_lhs_shape = [&] {
+ if (target == kCudnnConvForwardCallTarget ||
+ target == kCudnnConvBackwardFilterCallTarget) {
+ // LHS is "input".
+ return PadShape(lhs->shape(), {dnums.input_feature_dimension()});
+ }
+ CHECK_EQ(target, kCudnnConvBackwardInputCallTarget);
+ // LHS is "output".
+ return PadShape(lhs->shape(), {dnums.output_feature_dimension()});
+ }();
+
+ Shape new_rhs_shape = [&] {
+ if (target == kCudnnConvForwardCallTarget ||
+ target == kCudnnConvBackwardInputCallTarget) {
+ // RHS is "filter".
+ return PadShape(rhs->shape(), {dnums.kernel_input_feature_dimension(),
+ dnums.kernel_output_feature_dimension()});
+ }
+ CHECK_EQ(target, kCudnnConvBackwardFilterCallTarget);
+ // RHS is "output".
+ return PadShape(rhs->shape(), {dnums.output_feature_dimension()});
+ }();
+
+ if (ShapeUtil::Equal(lhs->shape(), new_lhs_shape) &&
+ ShapeUtil::Equal(rhs->shape(), new_rhs_shape)) {
+ VLOG(3) << "No need to pad features of " << conv->ToString();
+ return false;
+ }
+
+ Shape new_result_shape = [&] {
+ if (target == kCudnnConvForwardCallTarget) {
+ // Result is "output".
+ return PadShape(result_shape, {dnums.output_feature_dimension()});
+ }
+ if (target == kCudnnConvBackwardInputCallTarget) {
+ // Result is "input".
+ return PadShape(result_shape, {dnums.input_feature_dimension()});
+ }
+ CHECK_EQ(target, kCudnnConvBackwardFilterCallTarget);
+ // Result is "filter".
+ return PadShape(result_shape, {dnums.kernel_input_feature_dimension(),
+ dnums.kernel_output_feature_dimension()});
+ }();
+
+ // Check that padding wouldn't increase the total bytes read/written by this
+ // operation too much.
+ auto check_size_increase = [&](const Shape& old_shape,
+ const Shape& new_shape) {
+ int64 old_bytes = ShapeUtil::ByteSizeOf(old_shape);
+ int64 new_bytes = ShapeUtil::ByteSizeOf(new_shape);
+ if (new_bytes <= old_bytes * kMaxBytesTouchedIncrease) {
+ return true;
+ }
+ VLOG(3) << "Not padding convolution; doing so would change input / result "
+ "shape from "
+ << ShapeUtil::HumanString(old_shape) << " to "
+ << ShapeUtil::HumanString(new_shape) << ", a size increase of "
+ << new_bytes / static_cast<double>(old_bytes) << "x > "
+ << kMaxBytesTouchedIncrease << "x: " << conv->ToString();
+ return false;
+ };
+ if (!check_size_increase(lhs->shape(), new_lhs_shape) ||
+ !check_size_increase(rhs->shape(), new_rhs_shape) ||
+ !check_size_increase(result_shape, new_result_shape)) {
+ return false;
+ }
+
+ // OK, let's do the transformation!
+
+ auto* new_lhs = PadInstruction(lhs, new_lhs_shape);
+ auto* new_rhs = PadInstruction(rhs, new_rhs_shape);
+ CHECK(new_lhs != lhs || new_rhs != rhs)
+ << "We should have had to pad either LHS or RHS.";
+
+ auto add = [&](std::unique_ptr<HloInstruction> new_instr) {
+ return conv->parent()->AddInstruction(std::move(new_instr));
+ };
+
+ Shape new_conv_shape = ShapeUtil::MakeTupleShape(
+ {new_result_shape, ShapeUtil::MakeShape(U8, {0})});
+ auto* new_conv =
+ add(conv->CloneWithNewOperands(new_conv_shape, {new_lhs, new_rhs}));
+
+ // Slice the new conv result if necessary, keeping in mind that new_conv has
+ // tuple shape (new_result_shape, u8[0]).
+ if (!ShapeUtil::Equal(result_shape, new_result_shape)) {
+ std::vector<int64> start_indices(result_shape.dimensions_size(), 0);
+ std::vector<int64> end_indices(result_shape.dimensions().begin(),
+ result_shape.dimensions().end());
+ std::vector<int64> strides(result_shape.dimensions_size(), 1);
+
+ auto* new_conv_result = add(
+ HloInstruction::CreateGetTupleElement(new_result_shape, new_conv, 0));
+ auto* empty_temp_buffer =
+ add(HloInstruction::CreateConstant(LiteralUtil::CreateR1<uint8>({})));
+ auto* sliced_result = add(HloInstruction::CreateSlice(
+ result_shape, new_conv_result, start_indices, end_indices, strides));
+ new_conv =
+ add(HloInstruction::CreateTuple({sliced_result, empty_temp_buffer}));
+ }
+
+ VLOG(2) << "Padded features of " << conv->ToString() << ", replaced with "
+ << new_conv->ToString();
+ TF_RETURN_IF_ERROR(conv->parent()->ReplaceInstruction(conv, new_conv));
+ return true;
+}
+
+static std::vector<HloInstruction*> GetRelevantConvs(HloComputation* comp) {
+ std::vector<HloInstruction*> convs;
+ for (HloInstruction* instr : comp->instructions()) {
+ if (IsCustomCallToDnnConvolution(*instr) &&
+ instr->operand(0)->shape().element_type() == F16) {
+ convs.push_back(instr);
+ }
+ }
+ return convs;
+}
+
+StatusOr<bool> PadForTensorCores::Run(HloModule* module) {
+ bool changed = false;
+ for (HloComputation* comp : module->MakeNonfusionComputations()) {
+ for (HloInstruction* conv : GetRelevantConvs(comp)) {
+ TF_ASSIGN_OR_RETURN(bool result, PadFeaturesDims(conv));
+ changed |= result;
+ }
+ }
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h
new file mode 100644
index 0000000000..192359f026
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h
@@ -0,0 +1,45 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_FOR_TENSOR_CORES_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_FOR_TENSOR_CORES_H_
+
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Ensures that f16 cudnn convolutions have input/output channel dimensions that
+// are multiples of 8, inserting pads/slices as necessary.
+//
+// This is useful primarily for Volta and newer GPUs, where tensor cores can
+// only be used if the channel dims are multiples of 8. It's probably the
+// opposite of useful on other GPUs, so you should check what GPU you're
+// targeting before running this pass.
+//
+// TODO(jlebar): Also pad dots.
+class PadForTensorCores : public HloPassInterface {
+ public:
+ tensorflow::StringPiece name() const override {
+ return "pad for tensor cores";
+ }
+
+ StatusOr<bool> Run(HloModule* module) override;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_FOR_TENSOR_CORES_H_
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc
new file mode 100644
index 0000000000..99e7580b82
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc
@@ -0,0 +1,164 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h"
+
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
+#include "tensorflow/compiler/xla/util.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace op = xla::testing::opcode_matchers;
+using ::testing::_;
+
+using PadForTensorCoresTest = HloVerifiedTestBase;
+
+TEST_F(PadForTensorCoresTest, PadF16ForwardConvInputChannels) {
+ ParseAndVerifyModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = f16[10,20,30,41] parameter(0)
+ filter = f16[2,2,41,40] parameter(1)
+ ROOT result = (f16[10,20,30,40], u8[0]) custom-call(input, filter),
+ window={size=2x2}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convForward"
+ })");
+ EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie());
+ auto* root = module().entry_computation()->root_instruction();
+
+ SCOPED_TRACE(module().ToString());
+ EXPECT_THAT(root, op::CustomCall(kCudnnConvForwardCallTarget,
+ op::Pad(op::Parameter(0), _),
+ op::Pad(op::Parameter(1), _)));
+ EXPECT_TRUE(ShapeUtil::Equal(root->operand(0)->shape(),
+ ShapeUtil::MakeShape(F16, {10, 20, 30, 48})));
+ EXPECT_TRUE(ShapeUtil::Equal(root->operand(1)->shape(),
+ ShapeUtil::MakeShape(F16, {2, 2, 48, 40})));
+}
+
+TEST_F(PadForTensorCoresTest, PadF16BackwardInputConvOutputChannels) {
+ ParseAndVerifyModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ output = f16[10,20,30,41] parameter(0)
+ filter = f16[2,2,40,41] parameter(1)
+ ROOT result = (f16[10,20,30,40], u8[0]) custom-call(output, filter),
+ window={size=2x2}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convBackwardInput"
+ })");
+ EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie());
+ auto* root = module().entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::CustomCall(kCudnnConvBackwardInputCallTarget,
+ op::Pad(op::Parameter(0), _),
+ op::Pad(op::Parameter(1), _)));
+ EXPECT_TRUE(ShapeUtil::Equal(root->operand(0)->shape(),
+ ShapeUtil::MakeShape(F16, {10, 20, 30, 48})));
+ EXPECT_TRUE(ShapeUtil::Equal(root->operand(1)->shape(),
+ ShapeUtil::MakeShape(F16, {2, 2, 40, 48})));
+}
+
+TEST_F(PadForTensorCoresTest, PadF16ForwardConvOutputChannels) {
+ ParseAndVerifyModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = f16[10,20,30,40] parameter(0)
+ filter = f16[2,2,40,41] parameter(1)
+ ROOT result = (f16[10,20,30,41], u8[0]) custom-call(input, filter),
+ window={size=2x2}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convForward"
+ })");
+ EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie());
+ auto* root = module().entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Tuple(op::Slice(op::GetTupleElement(op::CustomCall(
+ kCudnnConvForwardCallTarget, op::Parameter(0),
+ op::Pad(op::Parameter(1), _)))),
+ _));
+}
+
+TEST_F(PadForTensorCoresTest, PadF16BackwardInputConvInputChannels) {
+ ParseAndVerifyModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ output = f16[10,20,30,40] parameter(0)
+ filter = f16[2,2,41,40] parameter(1)
+ result = (f16[10,20,30,41], u8[0]) custom-call(output, filter),
+ window={size=2x2}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convBackwardInput"
+ ROOT gte = f16[10,20,30,41] get-tuple-element(result), index=0
+ })");
+ EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie());
+ auto* root = module().entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::GetTupleElement(op::Tuple(
+ op::Slice(op::GetTupleElement(op::CustomCall(
+ kCudnnConvBackwardInputCallTarget, op::Parameter(0),
+ op::Pad(op::Parameter(1), _)))),
+ _)));
+}
+
+TEST_F(PadForTensorCoresTest, PadF16BackwardFilterConvInputChannels) {
+ ParseAndVerifyModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = f16[10,20,30,41] parameter(0)
+ output = f16[10,20,30,40] parameter(1)
+ result = (f16[2,2,41,40], u8[0]) custom-call(input, output),
+ window={size=2x2}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convBackwardFilter"
+ ROOT gte = f16[2,2,41,40] get-tuple-element(result), index=0
+ })");
+ EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie());
+ auto* root = module().entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::GetTupleElement(op::Tuple(
+ op::Slice(op::GetTupleElement(op::CustomCall(
+ kCudnnConvBackwardFilterCallTarget,
+ op::Pad(op::Parameter(0), _), op::Parameter(1)))),
+ _)));
+}
+
+TEST_F(PadForTensorCoresTest, PadF16BackwardFilterConvOutputChannels) {
+ ParseAndVerifyModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = f16[10,20,30,40] parameter(0)
+ output = f16[10,20,30,41] parameter(1)
+ result = (f16[2,2,40,41], u8[0]) custom-call(input, output),
+ window={size=2x2}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convBackwardFilter"
+ ROOT gte = f16[2,2,40,41] get-tuple-element(result), index=0
+ })");
+ EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie());
+ auto* root = module().entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::GetTupleElement(op::Tuple(
+ op::Slice(op::GetTupleElement(op::CustomCall(
+ kCudnnConvBackwardFilterCallTarget,
+ op::Parameter(0), op::Pad(op::Parameter(1), _)))),
+ _)));
+}
+
+} // anonymous namespace
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc
index a50ddf6ac6..05b305ea4c 100644
--- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc
+++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc
@@ -20,10 +20,17 @@ limitations under the License.
namespace xla {
namespace gpu {
-using stream_executor::dnn::DataLayout;
-using stream_executor::dnn::DataLayoutString;
-using stream_executor::dnn::FilterLayout;
-using stream_executor::dnn::FilterLayoutString;
+using se::dnn::DataLayout;
+using se::dnn::DataLayoutString;
+using se::dnn::FilterLayout;
+using se::dnn::FilterLayoutString;
+
+bool IsVoltaOrLater(const se::StreamExecutor& stream_executor) {
+ int major, minor;
+ CHECK(stream_executor.GetDeviceDescription().cuda_compute_capability(&major,
+ &minor));
+ return major >= 7;
+}
StatusOr<std::tuple<Layout, Layout, Layout>>
StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums,
diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h
index 39a6a38d00..1fc46bafa1 100644
--- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h
+++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_
#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -25,18 +26,20 @@ limitations under the License.
namespace xla {
namespace gpu {
+// Returns true if the given StreamExecutor is for a Volta or newer nvidia GPU.
+bool IsVoltaOrLater(const se::StreamExecutor& stream_exec);
+
// Returns (input, filter, output) XLA Layout protos given the StreamExecutor
// layouts.
StatusOr<std::tuple<Layout, Layout, Layout>>
StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums,
- stream_executor::dnn::DataLayout input,
- stream_executor::dnn::FilterLayout filter,
- stream_executor::dnn::DataLayout output);
+ se::dnn::DataLayout input,
+ se::dnn::FilterLayout filter,
+ se::dnn::DataLayout output);
// Returns (input, filter, output) StreamExecutor layouts given the XLA layouts.
-StatusOr<std::tuple<stream_executor::dnn::DataLayout,
- stream_executor::dnn::FilterLayout,
- stream_executor::dnn::DataLayout>>
+StatusOr<
+ std::tuple<se::dnn::DataLayout, se::dnn::FilterLayout, se::dnn::DataLayout>>
XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums,
const Layout& input, const Layout& filter,
const Layout& output);
diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD
index 686c3c16c9..4fad3f46cf 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD
@@ -111,8 +111,8 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:arithmetic",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc
index ba5cd2d84d..9072b30317 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test_helpers.h"
diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc
index 1315a4183a..d81d87e7dc 100644
--- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc
@@ -57,6 +57,7 @@ Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
while (true) {
// Invoke thunk sequence for while 'condition' computation.
profiler->StartHloComputation();
+ VLOG(3) << "Executing condition computation";
TF_RETURN_IF_ERROR(condition_thunk_sequence_->ExecuteOnStream(
buffer_allocations, stream, profiler));
profiler->FinishHloComputation(hlo_instruction()->while_condition());
@@ -64,6 +65,7 @@ Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
// Copy the result of condition computation and break the loop if 'false'.
bool condition_result;
stream->ThenMemcpy(&condition_result, condition_result_data, sizeof(bool));
+ VLOG(3) << "condition_result = " << condition_result;
Status block_status = stream->BlockHostUntilDone();
if (!block_status.ok()) {
return InternalError(
@@ -78,6 +80,7 @@ Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
// We measure the time of one execution of the while body computation. The
// while body may be executed more than once, the last measurement "wins".
profiler->StartHloComputation();
+ VLOG(3) << "Executing body computation";
// Invoke thunk sequence for while 'body' computation, and pass on
// 'profiler' to measure the timing of the thunks in 'body_thunk_sequence_'.
TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(buffer_allocations,
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
index b2241cd423..2c854eea18 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/padding.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/local_service.h"
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index de1a32d8bd..bbfb0c253f 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -1017,19 +1017,17 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
}
if (user->opcode() == HloOpcode::kFusion) {
+ if (fusion_can_share_buffer_ != nullptr) {
+ return fusion_can_share_buffer_(user, operand);
+ }
// Get the parameter associated with 'operand';
HloInstruction* fusion_param =
user->fused_parameter(user->operand_index(operand));
const HloValue& value = GetValueDefinedAt(fusion_param, operand_index);
- if (value.uses().size() != 1) {
- if (MultiDynamicSliceUseShareSameIndices(value.uses())) {
- return true;
- }
- return false;
+ if (MultiDynamicSliceUseShareSameIndices(value.uses())) {
+ return true;
}
- const HloUse& use = value.uses()[0];
-
if (user->fusion_kind() == HloInstruction::FusionKind::kLoop ||
user->fusion_kind() == HloInstruction::FusionKind::kInput) {
if (user->fused_expression_root()->opcode() ==
@@ -1039,13 +1037,17 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
// Returns true iff there is exactly one use of 'operand' at shape index
// 'operand_index', and this singleton use is the fused root at operand
// index 0.
- return use.instruction == user->fused_expression_root() &&
- use.operand_number == 0;
- } else {
- return AreTransitiveUsesElementwiseOrTuple(fusion_param);
+ if (value.uses().size() == 1) {
+ const HloUse& use = value.uses()[0];
+ return use.instruction == user->fused_expression_root() &&
+ use.operand_number == 0;
+ }
+ return false;
}
- } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput &&
- user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
+ return AreTransitiveUsesElementwiseOrTuple(fusion_param);
+ }
+ if (user->fusion_kind() == HloInstruction::FusionKind::kOutput &&
+ user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
// Output fusion with kAdd fused root.
// Check if one operand of kAdd fused root is kDot or kConvolution.
@@ -1066,11 +1068,12 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
// Returns true iff there is exactly one use of 'operand' at shape index
// 'operand_index', and this singleton use is the fused root (at operand
// index 'other_add_operand_index').
- return use.instruction == user->fused_expression_root() &&
- use.operand_number == other_add_operand_index;
- } else if (fusion_can_share_buffer_ != nullptr &&
- fusion_can_share_buffer_(user, operand)) {
- return true;
+ if (value.uses().size() == 1) {
+ const HloUse& use = value.uses()[0];
+ return use.instruction == user->fused_expression_root() &&
+ use.operand_number == other_add_operand_index;
+ }
+ return false;
}
}
@@ -1081,6 +1084,21 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
std::vector<int64> operand_indices = user->OperandIndices(operand);
return operand_indices.size() == 1 && operand_indices[0] == 0;
}
+ if (user->opcode() == HloOpcode::kSort) {
+ // Only valid if there are no other users.
+ if (operand->users().size() != 1) {
+ return false;
+ }
+ // If we only sort keys, the output of sort is not a tuple, so we can always
+ // share the buffer.
+ if (user->operand_count() == 1) {
+ return true;
+ }
+ CHECK(!user_index.empty());
+ // Only share with the right tuple element buffer.
+ std::vector<int64> operand_indices = user->OperandIndices(operand);
+ return operand_indices.size() == 1 && user_index[0] == operand_indices[0];
+ }
if (user->opcode() == HloOpcode::kCall) {
// Get all uses of value defined by 'operand' at 'operand_index'.
const auto& uses = GetValueDefinedAt(operand, operand_index).uses();
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index 37bc2d2c9d..2ec31a9148 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -2232,6 +2232,48 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) {
dataflow_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {}));
}
+TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) {
+ auto builder = HloComputation::Builder(TestName());
+
+ Shape keys_shape = ShapeUtil::MakeShape(F32, {8});
+ auto keys = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, keys_shape, "keys"));
+ auto sort =
+ builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys));
+
+ BuildModuleAndRunAnalysis(builder.Build());
+
+ EXPECT_TRUE(
+ dataflow_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {}));
+}
+
+TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) {
+ auto builder = HloComputation::Builder(TestName());
+
+ Shape keys_shape = ShapeUtil::MakeShape(F32, {8});
+ Shape values_shape = ShapeUtil::MakeShape(F32, {8});
+ auto keys = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, keys_shape, "keys"));
+ auto values = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, values_shape, "values"));
+ auto sort = builder.AddInstruction(HloInstruction::CreateSort(
+ ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values));
+
+ BuildModuleAndRunAnalysis(builder.Build());
+
+ // The buffer for the keys can be shared with the first tuple entry.
+ EXPECT_TRUE(
+ dataflow_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {0}));
+ // The buffer for the values can be shared with the second tuple entry.
+ EXPECT_TRUE(
+ dataflow_analysis_->CanShareOperandBufferWithUser(values, {}, sort, {1}));
+ // Verify that the buffers are not shared with the "wrong" tuple entry.
+ EXPECT_FALSE(
+ dataflow_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {1}));
+ EXPECT_FALSE(
+ dataflow_analysis_->CanShareOperandBufferWithUser(values, {}, sort, {0}));
+}
+
TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
auto builder = HloComputation::Builder(TestName());
Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index 5f575b24a1..cba72469ce 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include <utility>
#include <vector>
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 30bff286c2..70441b879d 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -447,8 +447,7 @@ class HloInstruction {
HloComputation* reduce_computation,
tensorflow::gtl::ArraySlice<int64> replica_group_ids,
tensorflow::StringPiece barrier,
- const tensorflow::gtl::optional<int64>& all_reduce_id =
- tensorflow::gtl::nullopt);
+ const tensorflow::gtl::optional<int64>& all_reduce_id);
// Creates a conversion instruction, where operand is the data to convert and
// shape is the target shape for the conversion.
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index e4031f04d5..132e767420 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -224,8 +224,7 @@ class HloAllReduceInstruction : public HloInstruction {
HloComputation* reduce_computation,
tensorflow::gtl::ArraySlice<int64> replica_group_ids,
tensorflow::StringPiece barrier,
- const tensorflow::gtl::optional<int64>& all_reduce_id =
- tensorflow::gtl::nullopt);
+ const tensorflow::gtl::optional<int64>& all_reduce_id);
// Returns the group ids of each replica for CrossReplicaSum op.
const std::vector<int64>& replica_group_ids() const {
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index e8eaf54949..d71d3c8170 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -1132,13 +1132,24 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
}
case HloOpcode::kCustomCall: {
optional<string> custom_call_target;
+ optional<Window> window;
+ optional<ConvolutionDimensionNumbers> dnums;
attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
&custom_call_target};
+ attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
+ attrs["dim_labels"] = {/*required=*/false,
+ AttrTy::kConvolutionDimensionNumbers, &dnums};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateCustomCall(
shape, operands, *custom_call_target));
+ if (window.has_value()) {
+ instruction->set_window(*window);
+ }
+ if (dnums.has_value()) {
+ instruction->set_convolution_dimension_numbers(*dnums);
+ }
break;
}
case HloOpcode::kHostCompute: {
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 1f0572c576..1c08c51220 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -1016,6 +1016,17 @@ ENTRY Iota {
}
)"
+},
+// custom-call with window and dim_labels
+{
+"CustomCallWithWindowAndDimLabels",
+R"(HloModule CustomCallWithWindowAndDimLabels
+
+ENTRY Computation {
+ ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, custom_call_target="target"
+}
+
+)"
}
});
// clang-format on
diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD
index 309a186e58..cdd3daf73b 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/BUILD
+++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD
@@ -225,6 +225,15 @@ cc_library(
)
cc_library(
+ name = "buffer_assignment_util",
+ srcs = ["buffer_assignment_util.cc"],
+ hdrs = ["buffer_assignment_util.h"],
+ deps = [
+ "//tensorflow/compiler/xla/service:buffer_assignment",
+ ],
+)
+
+cc_library(
name = "math_ops",
srcs = ["math_ops.cc"],
hdrs = ["math_ops.h"],
diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc
index 2552ff4a6a..941d940684 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc
@@ -58,7 +58,7 @@ ENTRY while3 {
CompileAndVerifyIr(hlo_string, R"(
; CHECK-LABEL: @body(i8* align 4 dereferenceable(4) %retval
; CHECK: %[[add_result:.*]] = fadd fast float %[[fadd_lhs:.*]], %[[fadd_rhs:.*]]
-; CHECK: store float %[[add_result]], float* %[[store_dest:.*]], !alias.scope ![[alias_scope_md_for_store:.*]]
+; CHECK: store float %[[add_result]], float* %[[store_dest:.*]], !alias.scope ![[alias_scope_md_for_store:[0-9]+]]
;
; CHECK-LABEL: @condition(i8* align 1 dereferenceable(1) %fusion, i8* noalias %run_options, i8** noalias %params
; CHECK: %[[cond_state_buf_ptr:.*]] = getelementptr inbounds i8*, i8** %params, i64 0
diff --git a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc
new file mode 100644
index 0000000000..4eb5d9fb47
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc
@@ -0,0 +1,59 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
+
+namespace xla {
+namespace llvm_ir {
+static const HloInstruction& InstrForConstantBufferAllocation(
+ const BufferAllocation& allocation) {
+ CHECK(allocation.is_constant());
+ HloInstruction* const_instr = nullptr;
+ for (const auto& buffer_offset_pair : allocation.assigned_buffers()) {
+ const LogicalBuffer* buffer = buffer_offset_pair.first;
+ // BufferAssignment may have assigned non-constant instructions to this
+ // allocation too so we can't CHECK this condition. E.g. for
+ //
+ // while(init = constant, body = identity, cond = ...)
+ //
+ // the LogicalBuffer for the kWhile instruction will have the same
+ // BufferAllocation as the LogicalBuffer for the (init) constant.
+ if (buffer->instruction()->opcode() == HloOpcode::kConstant) {
+ CHECK_EQ(const_instr, nullptr)
+ << const_instr->ToString() << " " << buffer->ToString();
+ const_instr = buffer->instruction();
+ }
+ }
+ CHECK_NE(const_instr, nullptr);
+ return *const_instr;
+}
+
+string ConstantBufferAllocationToGlobalName(
+ const BufferAllocation& allocation) {
+ string instr_name = InstrForConstantBufferAllocation(allocation).name();
+ for (char& c : instr_name) {
+ if (c == '.') {
+ c = '_';
+ }
+ }
+ return tensorflow::strings::StrCat("buffer_for_", instr_name);
+}
+
+const Literal& LiteralForConstantAllocation(
+ const BufferAllocation& allocation) {
+ return InstrForConstantBufferAllocation(allocation).literal();
+}
+} // namespace llvm_ir
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h
new file mode 100644
index 0000000000..bfb6eecb87
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h
@@ -0,0 +1,34 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_BUFFER_ASSIGNMENT_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_BUFFER_ASSIGNMENT_UTIL_H_
+
+#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+
+namespace xla {
+namespace llvm_ir {
+// In XLA:GPU we map constant buffer allocations to globals in the generated
+// LLVM IR. This function gives us the name of the global variable a constant
+// buffer is mapped to. Not used on XLA:CPU.
+string ConstantBufferAllocationToGlobalName(const BufferAllocation& allocation);
+
+// Returns the Literal corresponding to `allocation`, which must be a constant
+// allocation.
+const Literal& LiteralForConstantAllocation(const BufferAllocation& allocation);
+} // namespace llvm_ir
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_BUFFER_ASSIGNMENT_UTIL_H_
diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc
index 6f261c32f4..5187948e29 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -38,19 +39,18 @@ namespace llvm_ir {
namespace {
// Adds the inner comparison loop where we compare elements pointed to by
// 'keys_index' and 'compare_keys_index'.
-void EmitCompareLoop(int64 dimension_to_sort,
- const llvm_ir::IrArray::Index& keys_index,
- const llvm_ir::IrArray::Index& compare_keys_index,
- const llvm_ir::IrArray& keys_array, llvm::IRBuilder<>* b) {
- // TODO(b/26783907): parallelize this loop.
-
+void EmitCompareLoop(int64 dimension_to_sort, const IrArray::Index& keys_index,
+ const IrArray::Index& compare_keys_index,
+ const IrArray& keys_array,
+ const tensorflow::gtl::optional<IrArray>& values_array,
+ llvm::IRBuilder<>* b) {
// if (is_smaller_index &&
// compare_keys[dimension_to_sort] < dimension_to_sort_bound)
llvm::Value* is_smaller_index = b->CreateICmpSLT(
keys_index[dimension_to_sort], compare_keys_index[dimension_to_sort]);
int64 dimension_to_sort_bound =
keys_array.GetShape().dimensions(dimension_to_sort);
- auto if_data = llvm_ir::EmitIfThenElse(
+ auto if_data = EmitIfThenElse(
b->CreateAnd(is_smaller_index,
b->CreateICmpSLT(compare_keys_index[dimension_to_sort],
keys_index.GetConstantWithIndexType(
@@ -63,19 +63,31 @@ void EmitCompareLoop(int64 dimension_to_sort,
auto comparison =
primitive_util::IsFloatingPointType(key_type)
// TODO(b/26783907): Figure out how to handle NaNs.
- ? b->CreateFCmp(llvm::FCmpInst::FCMP_ULT, key1, key2)
+ ? b->CreateFCmp(llvm::FCmpInst::FCMP_ULT, key2, key1)
: b->CreateICmp(primitive_util::IsSignedIntegralType(key_type)
? llvm::ICmpInst::ICMP_SLT
: llvm::ICmpInst::ICMP_ULT,
- key1, key2);
- auto min_key = b->CreateSelect(comparison, key1, key2);
- auto max_key = b->CreateSelect(comparison, key2, key1);
- keys_array.EmitWriteArrayElement(keys_index, min_key, b);
- keys_array.EmitWriteArrayElement(compare_keys_index, max_key, b);
+ key2, key1);
+ // If key2 < key1
+ auto if_smaller_data =
+ EmitIfThenElse(comparison, "is_smaller_than", b, /*emit_else=*/false);
+ SetToFirstInsertPoint(if_smaller_data.true_block, b);
+ // Swap key1 with key2.
+ keys_array.EmitWriteArrayElement(keys_index, key2, b);
+ keys_array.EmitWriteArrayElement(compare_keys_index, key1, b);
+ if (values_array.has_value()) {
+ // Also swap the values.
+ auto value1 = values_array.value().EmitReadArrayElement(keys_index, b);
+ auto value2 =
+ values_array.value().EmitReadArrayElement(compare_keys_index, b);
+ values_array.value().EmitWriteArrayElement(keys_index, value2, b);
+ values_array.value().EmitWriteArrayElement(compare_keys_index, value1, b);
+ }
}
} // namespace
Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array,
+ const tensorflow::gtl::optional<IrArray>& values_array,
tensorflow::StringPiece name, llvm::Value* xor_mask,
llvm::IRBuilder<>* b,
const gpu::LaunchDimensions* launch_dimensions) {
@@ -131,7 +143,7 @@ Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array,
compare_keys_index[dimension_to_sort] =
b->CreateXor(compare_index[0], xor_mask);
EmitCompareLoop(dimension_to_sort, keys_index, compare_keys_index,
- keys_array, b);
+ keys_array, values_array, b);
return Status::OK();
};
if (launch_dimensions != nullptr) {
diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h
index e75f9b08fb..8458744c6b 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -30,6 +31,7 @@ namespace llvm_ir {
// implements the inner loop of BitonicSort. If 'launch_dimensions' is nullptr,
// the inner compare loop will not be parallelized.
Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array,
+ const tensorflow::gtl::optional<IrArray>& values_array,
tensorflow::StringPiece name, llvm::Value* xor_mask,
llvm::IRBuilder<>* b,
const gpu::LaunchDimensions* launch_dimensions);
diff --git a/tensorflow/compiler/xla/service/pool.h b/tensorflow/compiler/xla/service/pool.h
deleted file mode 100644
index 8e710ebb6d..0000000000
--- a/tensorflow/compiler/xla/service/pool.h
+++ /dev/null
@@ -1,84 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_XLA_POOL_H_
-#define TENSORFLOW_COMPILER_XLA_POOL_H_
-
-#include <functional>
-#include <vector>
-
-#include "tensorflow/compiler/xla/ptr_util.h"
-#include "tensorflow/core/platform/mutex.h"
-
-namespace xla {
-
-// Pool of values, which are created as needed and destroyed when the `Pool` is
-// destroyed
-template <typename T>
-class Pool {
- public:
- struct Deleter {
- void operator()(T* ptr) { pool->Deallocate(ptr); }
-
- Pool<T>* pool;
- };
-
- // A pointer to a taken element of a `Pool` which returns it to the pool on
- // destruction
- using SmartPtr = std::unique_ptr<T, Deleter>;
-
- // Constructs a `Pool` with given factory function, which need not be
- // thread-safe.
- explicit Pool(std::function<std::unique_ptr<T>()> factory)
- : factory_(factory) {}
-
- explicit Pool() : Pool([]() { return MakeUnique<T>(); }) {}
-
- // Returns a pointer to a value in the pool, creating a new value if none is
- // free. The returned smart pointer returns the element to the pool on
- // destruction.
- //
- // This method is thread-safe.
- SmartPtr Allocate() {
- tensorflow::mutex_lock lock(mu_);
- T* ptr;
- if (!xs_.empty()) {
- ptr = std::move(xs_.back()).release();
- xs_.pop_back();
- } else {
- ptr = factory_().release();
- }
- Deleter del = {this};
- return std::unique_ptr<T, Deleter>(ptr, del);
- }
-
- private:
- // Puts a pointer to a value back into the pool, leaving it free for future
- // use.
- //
- // This method is thread-safe.
- void Deallocate(T* ptr) {
- tensorflow::mutex_lock lock(mu_);
- xs_.push_back(std::unique_ptr<T>(ptr));
- }
-
- const std::function<std::unique_ptr<T>()> factory_ GUARDED_BY(mu_);
- std::vector<std::unique_ptr<T>> xs_ GUARDED_BY(mu_);
- tensorflow::mutex mu_;
-};
-
-} // namespace xla
-
-#endif // TENSORFLOW_COMPILER_XLA_POOL_H_
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 636013cbb5..ce070bc5b6 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/service/source_map_util.h"
+#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -376,7 +377,7 @@ Service::ExecuteParallelAndRegisterResult(
ExecutionProfile* profile) {
// Streams where the computation are launched, so we can wait on the streams
// to complete.
- std::vector<Pool<se::Stream>::SmartPtr> streams;
+ std::vector<StreamPool::Ptr> streams;
std::vector<std::unique_ptr<se::Timer>> timers;
// Global data handles for the computation results, one for each computation.
@@ -403,7 +404,7 @@ Service::ExecuteParallelAndRegisterResult(
CHECK_EQ(replicas.size(), arguments[i].size());
std::vector<ScopedShapedBuffer> result_buffers;
for (int64 replica = 0; replica < replicas.size(); ++replica) {
- TF_ASSIGN_OR_RETURN(Pool<se::Stream>::SmartPtr stream,
+ TF_ASSIGN_OR_RETURN(StreamPool::Ptr stream,
backend->BorrowStream(replicas[replica]));
streams.push_back(std::move(stream));
@@ -515,13 +516,13 @@ StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
arguments,
Backend* backend, const string& result_tag, ExecutionProfile* profile) {
// Set up streams.
- std::vector<Pool<se::Stream>::SmartPtr> streams;
+ std::vector<StreamPool::Ptr> streams;
TF_ASSIGN_OR_RETURN(auto replicas,
Replicas(*backend, SingleComputationDeviceHandle()));
TF_RET_CHECK(!replicas.empty());
for (se::StreamExecutor* executor : replicas) {
- TF_ASSIGN_OR_RETURN(Pool<se::Stream>::SmartPtr stream,
+ TF_ASSIGN_OR_RETURN(StreamPool::Ptr stream,
backend->BorrowStream(executor));
streams.push_back(std::move(stream));
}
@@ -533,7 +534,7 @@ StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
// Set up run options.
std::vector<ServiceExecutableRunOptions> run_options;
- for (const Pool<se::Stream>::SmartPtr& stream : streams) {
+ for (const StreamPool::Ptr& stream : streams) {
ExecutableRunOptions options;
options.set_stream(stream.get());
options.set_device_ordinal(stream->parent()->device_ordinal());
diff --git a/tensorflow/compiler/xla/service/service_executable_run_options.h b/tensorflow/compiler/xla/service/service_executable_run_options.h
index 7f3910cdb0..dbfed628bf 100644
--- a/tensorflow/compiler/xla/service/service_executable_run_options.h
+++ b/tensorflow/compiler/xla/service/service_executable_run_options.h
@@ -17,7 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_EXECUTABLE_RUN_OPTIONS_H_
#include "tensorflow/compiler/xla/executable_run_options.h"
-#include "tensorflow/compiler/xla/service/pool.h"
+#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/stream_executor/stream_executor.h"
@@ -27,8 +27,7 @@ namespace xla {
// data, now only a stream cache for GPU backend.
class ServiceExecutableRunOptions {
public:
- using StreamBorrower =
- std::function<StatusOr<Pool<se::Stream>::SmartPtr>(int)>;
+ using StreamBorrower = std::function<StatusOr<StreamPool::Ptr>(int)>;
ServiceExecutableRunOptions()
: ServiceExecutableRunOptions(ExecutableRunOptions()) {}
@@ -51,7 +50,7 @@ class ServiceExecutableRunOptions {
// Borrows a stream and returns a smart pointer which returns the stream on
// destruction.
- StatusOr<Pool<se::Stream>::SmartPtr> BorrowStream(int device_ordinal) const {
+ StatusOr<StreamPool::Ptr> BorrowStream(int device_ordinal) const {
return borrow_stream_
? borrow_stream_(device_ordinal)
: Status(tensorflow::error::UNIMPLEMENTED, "No stream cache");
diff --git a/tensorflow/compiler/xla/service/stream_pool.cc b/tensorflow/compiler/xla/service/stream_pool.cc
new file mode 100644
index 0000000000..92bb21b816
--- /dev/null
+++ b/tensorflow/compiler/xla/service/stream_pool.cc
@@ -0,0 +1,56 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/stream_pool.h"
+
+#include "tensorflow/compiler/xla/ptr_util.h"
+
+namespace xla {
+
+StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor) {
+ std::unique_ptr<se::Stream> stream;
+ {
+ tensorflow::mutex_lock lock(mu_);
+ if (!streams_.empty()) {
+ // Re-use an existing stream from the pool.
+ stream = std::move(streams_.back());
+ streams_.pop_back();
+ }
+ }
+
+ if (!stream) {
+ // Create a new stream.
+ stream = MakeUnique<se::Stream>(executor);
+ stream->Init();
+ }
+
+ // Return the stream wrapped in Ptr, which has our special deleter semantics.
+ PtrDeleter deleter = {this};
+ return Ptr(stream.release(), deleter);
+}
+
+void StreamPool::ReturnStream(se::Stream* stream) {
+ if (stream->ok()) {
+ tensorflow::mutex_lock lock(mu_);
+ streams_.emplace_back(stream);
+ } else {
+ // If the stream has encountered any errors, all subsequent
+ // operations on it will fail. So just delete the stream, and rely
+ // on new streams to be created in the future.
+ delete stream;
+ }
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/stream_pool.h b/tensorflow/compiler/xla/service/stream_pool.h
new file mode 100644
index 0000000000..7221d323a6
--- /dev/null
+++ b/tensorflow/compiler/xla/service/stream_pool.h
@@ -0,0 +1,64 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_STREAM_POOL_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_STREAM_POOL_H_
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+
+// Pool of stream_executor::Streams, which are created as needed and
+// destroyed when the pool is destroyed.
+class StreamPool {
+ public:
+ struct PtrDeleter {
+ void operator()(se::Stream* stream) { pool->ReturnStream(stream); }
+ StreamPool* pool;
+ };
+
+ // Stream pointer type returned by BorrowStream, which returns the
+ // stream to the pool on destruction.
+ using Ptr = std::unique_ptr<se::Stream, PtrDeleter>;
+
+ StreamPool() {}
+
+ // Returns a pointer to a stream in the pool, creating a new stream
+ // if none are available in the pool. The returned smart pointer
+ // returns the stream to the pool on destruction.
+ //
+ // This method is thread-safe.
+ Ptr BorrowStream(se::StreamExecutor* executor);
+
+ private:
+ // Puts a pointer to a stream back into the pool, leaving it free
+ // for future use. Streams that have previously encountered errors
+ // are deleted, and not returned to the pool.
+ //
+ // This method is thread-safe.
+ void ReturnStream(se::Stream* stream);
+
+ tensorflow::mutex mu_;
+ std::vector<std::unique_ptr<se::Stream>> streams_ GUARDED_BY(mu_);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_STREAM_POOL_H_
diff --git a/tensorflow/compiler/xla/service/stream_pool_test.cc b/tensorflow/compiler/xla/service/stream_pool_test.cc
new file mode 100644
index 0000000000..aaf5c37b0d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/stream_pool_test.cc
@@ -0,0 +1,136 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/stream_pool.h"
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace {
+
+class StreamPoolTest : public ::testing::Test {
+ protected:
+ std::unique_ptr<se::StreamExecutor> NewStreamExecutor() {
+ se::Platform* platform =
+ se::MultiPlatformManager::PlatformWithName("Host").ConsumeValueOrDie();
+ se::StreamExecutorConfig config(/*ordinal=*/0);
+ return platform->GetUncachedExecutor(config).ConsumeValueOrDie();
+ }
+};
+
+TEST_F(StreamPoolTest, EmptyPool) { StreamPool pool; }
+
+TEST_F(StreamPoolTest, OneStreamPool) {
+ std::unique_ptr<se::StreamExecutor> executor = NewStreamExecutor();
+ StreamPool pool;
+
+ // Borrow and return a stream.
+ StreamPool::Ptr stream1 = pool.BorrowStream(executor.get());
+ se::Stream* stream1_ptr = stream1.get();
+ EXPECT_TRUE(stream1->ok());
+ stream1 = nullptr;
+
+ // Borrow and return another stream.
+ StreamPool::Ptr stream2 = pool.BorrowStream(executor.get());
+ se::Stream* stream2_ptr = stream2.get();
+ EXPECT_TRUE(stream2->ok());
+ stream2 = nullptr;
+
+ // The underlying streams should be the same, since stream1 was the
+ // only stream available in the pool when stream2 was borrowed.
+ EXPECT_EQ(stream1_ptr, stream2_ptr);
+}
+
+TEST_F(StreamPoolTest, TwoStreamPool) {
+ std::unique_ptr<se::StreamExecutor> executor = NewStreamExecutor();
+ StreamPool pool;
+
+ // Borrow two streams.
+ StreamPool::Ptr stream1 = pool.BorrowStream(executor.get());
+ se::Stream* stream1_ptr = stream1.get();
+ EXPECT_TRUE(stream1->ok());
+ StreamPool::Ptr stream2 = pool.BorrowStream(executor.get());
+ se::Stream* stream2_ptr = stream2.get();
+ EXPECT_TRUE(stream2->ok());
+
+ // The underlying streams should be different, since we haven't
+ // returned either of them yet.
+ EXPECT_NE(stream1_ptr, stream2_ptr);
+
+ // Return stream1 and borrow stream3.
+ stream1 = nullptr;
+ StreamPool::Ptr stream3 = pool.BorrowStream(executor.get());
+ se::Stream* stream3_ptr = stream3.get();
+ EXPECT_TRUE(stream3->ok());
+
+ // stream1 and stream3 should be the same.
+ EXPECT_EQ(stream1_ptr, stream3_ptr);
+ EXPECT_NE(stream2_ptr, stream3_ptr);
+
+ // Return stream2, and borrow stream4.
+ stream2 = nullptr;
+ StreamPool::Ptr stream4 = pool.BorrowStream(executor.get());
+ se::Stream* stream4_ptr = stream4.get();
+ EXPECT_TRUE(stream4->ok());
+
+ // Stream2 and stream4 should be the same.
+ EXPECT_EQ(stream2_ptr, stream4_ptr);
+ EXPECT_NE(stream3_ptr, stream4_ptr);
+}
+
+TEST_F(StreamPoolTest, BadStreamDiscarded) {
+ std::unique_ptr<se::StreamExecutor> executor = NewStreamExecutor();
+ StreamPool pool;
+
+ // Borrow a stream.
+ StreamPool::Ptr stream1 = pool.BorrowStream(executor.get());
+ EXPECT_TRUE(stream1->ok());
+
+ // Force an error on the stream; here we call a method that requires
+ // DNN support, which we know the Host platform doesn't support.
+ stream1->ThenDepthConcatenate({}, {}, nullptr);
+ EXPECT_FALSE(stream1->ok());
+
+ // Return stream1 and borrow stream2.
+ stream1 = nullptr;
+ StreamPool::Ptr stream2 = pool.BorrowStream(executor.get());
+ se::Stream* stream2_ptr = stream2.get();
+ EXPECT_TRUE(stream2->ok());
+
+ // The underlying streams should be different. They would have been
+ // the same, but since we forced an error on stream1, it cannot be
+ // put back into the pool. Sadly we can't just check:
+ // EXPECT_NE(stream1_ptr, stream2_ptr);
+ //
+ // The above should hold logically, but it may fail if the new
+ // stream instance allocated for stream2 happens to reside in the
+ // same memory address as stream1, which has been deleted.
+ //
+ // The check that stream2->ok() serves as a good-enough check.
+
+ // Return stream2 and borrow stream3. The previous error on stream1
+ // has no effect on these streams, and they are the same.
+ stream2 = nullptr;
+ StreamPool::Ptr stream3 = pool.BorrowStream(executor.get());
+ se::Stream* stream3_ptr = stream3.get();
+ EXPECT_TRUE(stream3->ok());
+ EXPECT_EQ(stream2_ptr, stream3_ptr);
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc
index 7051a4cf51..58f767e913 100644
--- a/tensorflow/compiler/xla/service/transpose_folding_test.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include <unordered_set>
#include <vector>
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
index 990dfc410c..0effdc80a4 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
@@ -718,6 +718,7 @@ bool TuplePointsToAnalysis::HasUniqueFusedUseOfOperandAt(
// root at operand 0 or 1. Or...
// (4) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index
// 0.
+// (5) The 'user' of 'operand' is Sort, and it is the only user.
//
// (2) and (3) can only be determined if points-to analysis is available.
bool TuplePointsToAnalysis::CanShareOperandBufferWithUser(
@@ -783,6 +784,21 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser(
std::vector<int64> operand_indices = user->OperandIndices(operand);
return operand_indices.size() == 1 && operand_indices[0] == 0;
}
+ if (user->opcode() == HloOpcode::kSort) {
+ // Only valid if there are no other users.
+ if (operand->users().size() != 1) {
+ return false;
+ }
+ // If we only sort keys, the output of sort is not a tuple, so we can always
+ // share the buffer.
+ if (user->operand_count() == 1) {
+ return true;
+ }
+ CHECK(!user_index.empty());
+ // Only share with the right tuple element buffer.
+ std::vector<int64> operand_indices = user->OperandIndices(operand);
+ return operand_indices.size() == 1 && user_index[0] == operand_indices[0];
+ }
if (user->opcode() == HloOpcode::kCall) {
// TODO(b/62548313): Remove when buffer assignment is module scoped and
// does not assign buffers to calls.
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
index 0ac8df4271..2e5f646804 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -1012,6 +1012,48 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) {
points_to_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {}));
}
+TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) {
+ auto builder = HloComputation::Builder(TestName());
+
+ Shape keys_shape = ShapeUtil::MakeShape(F32, {8});
+ auto keys = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, keys_shape, "keys"));
+ auto sort =
+ builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys));
+
+ BuildModuleAndRunAnalysis(builder.Build());
+
+ EXPECT_TRUE(
+ points_to_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {}));
+}
+
+TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) {
+ auto builder = HloComputation::Builder(TestName());
+
+ Shape keys_shape = ShapeUtil::MakeShape(F32, {8});
+ Shape values_shape = ShapeUtil::MakeShape(F32, {8});
+ auto keys = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, keys_shape, "keys"));
+ auto values = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, values_shape, "values"));
+ auto sort = builder.AddInstruction(HloInstruction::CreateSort(
+ ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values));
+
+ BuildModuleAndRunAnalysis(builder.Build());
+
+ // The buffer for the keys can be shared with the first tuple entry.
+ EXPECT_TRUE(
+ points_to_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {0}));
+ // The buffer for the values can be shared with the second tuple entry.
+ EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser(values, {},
+ sort, {1}));
+ // Verify that the buffers are not shared with the "wrong" tuple entry.
+ EXPECT_FALSE(
+ points_to_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {1}));
+ EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(values, {},
+ sort, {0}));
+}
+
TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
auto builder = HloComputation::Builder(TestName());
Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 099431d949..42d52aee78 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -154,8 +154,8 @@ tf_cc_binary(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/service/cpu:cpu_compiler",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
@@ -192,8 +192,8 @@ cc_library(
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/service:interpreter_plugin", # reference backend
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/tests:literal_test_util",
@@ -290,8 +290,8 @@ xla_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
@@ -314,8 +314,8 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
@@ -334,8 +334,8 @@ xla_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
@@ -356,9 +356,9 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
@@ -376,9 +376,10 @@ xla_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/service:platform_util",
+ "//tensorflow/compiler/xla/service:stream_pool",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
@@ -395,8 +396,8 @@ xla_test(
],
deps = [
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -419,9 +420,9 @@ xla_test(
"//tensorflow/compiler/xla:xla_proto",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
@@ -445,8 +446,8 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -464,9 +465,9 @@ xla_test(
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
@@ -483,8 +484,8 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -501,8 +502,8 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -519,9 +520,9 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -543,8 +544,8 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -562,8 +563,8 @@ xla_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
@@ -586,8 +587,8 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
@@ -612,8 +613,8 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
@@ -638,7 +639,7 @@ xla_test(
deps = [
":client_library_test_base",
":literal_test_util",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
],
@@ -658,7 +659,7 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/service:reduce_precision_insertion",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
@@ -681,8 +682,8 @@ xla_test(
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
@@ -702,8 +703,7 @@ xla_test(
"//tensorflow/compiler/xla:execution_options_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:test",
- "//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
@@ -726,8 +726,8 @@ xla_test(
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
@@ -750,8 +750,8 @@ xla_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
@@ -774,8 +774,8 @@ xla_test(
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -796,7 +796,7 @@ CONVOLUTION_TEST_DEPS = [
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:padding",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -839,8 +839,8 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:padding",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -863,8 +863,8 @@ xla_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:padding",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -892,10 +892,10 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/lib:math",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
@@ -925,9 +925,9 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
@@ -951,8 +951,8 @@ xla_test(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
@@ -973,7 +973,7 @@ xla_test(
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla/client:local_client",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -992,8 +992,8 @@ xla_test(
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1014,7 +1014,7 @@ xla_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:local_service",
@@ -1045,8 +1045,8 @@ xla_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
@@ -1066,9 +1066,9 @@ xla_test(
"//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1097,9 +1097,9 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1124,9 +1124,9 @@ xla_test_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:padding",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
@@ -1165,9 +1165,9 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:padding",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1188,7 +1188,7 @@ xla_test(
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
@@ -1242,8 +1242,8 @@ xla_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1261,7 +1261,7 @@ xla_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service/cpu:custom_call_target_registry",
"//tensorflow/compiler/xla/tests:client_library_test_base",
@@ -1284,8 +1284,8 @@ xla_test(
"//tensorflow/compiler/xla:array4d",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1307,8 +1307,8 @@ xla_test(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1328,9 +1328,8 @@ xla_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:local_client",
- "//tensorflow/compiler/xla/client:xla_computation",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:arithmetic",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1347,8 +1346,8 @@ xla_test(
],
deps = [
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1364,8 +1363,8 @@ xla_test(
],
deps = [
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1389,8 +1388,8 @@ xla_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
@@ -1410,8 +1409,8 @@ xla_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
@@ -1440,8 +1439,8 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1460,7 +1459,7 @@ xla_test(
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array4d",
"//tensorflow/compiler/xla/client:local_client",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1483,9 +1482,9 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1509,8 +1508,8 @@ xla_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1529,8 +1528,8 @@ xla_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1551,8 +1550,8 @@ xla_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
@@ -1574,7 +1573,7 @@ xla_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:local_client",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1595,8 +1594,8 @@ xla_test(
"//tensorflow/compiler/xla:xla_proto",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
@@ -1614,8 +1613,8 @@ xla_test(
],
deps = [
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1637,8 +1636,8 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1658,8 +1657,8 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
@@ -1675,8 +1674,8 @@ xla_test(
deps = [
":client_library_test_base",
"//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
@@ -1689,8 +1688,8 @@ xla_test(
deps = [
":client_library_test_base",
"//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
@@ -1710,8 +1709,8 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
@@ -1795,8 +1794,8 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_runner",
@@ -1823,8 +1822,8 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_runner",
"//tensorflow/compiler/xla/service:platform_util",
@@ -1860,8 +1859,8 @@ xla_test(
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/service:local_service",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xla/tests:literal_test_util",
@@ -1888,8 +1887,8 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:local_service",
"//tensorflow/compiler/xla/service:platform_util",
@@ -1924,7 +1923,7 @@ tf_cc_test(
":local_client_test_base",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/client:local_client",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/service:cpu_plugin",
"//tensorflow/compiler/xla/service:local_service",
"//tensorflow/core:test_main",
@@ -1966,8 +1965,8 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1980,7 +1979,7 @@ xla_test(
name = "deep_graph_test",
srcs = ["deep_graph_test.cc"],
deps = [
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
@@ -2013,6 +2012,7 @@ xla_test(
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:generic_transfer_manager",
"//tensorflow/compiler/xla/service:shaped_buffer",
+ "//tensorflow/compiler/xla/service:stream_pool",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core:test",
@@ -2065,8 +2065,8 @@ xla_test(
":local_client_test_base",
":test_utils",
"//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
@@ -2087,7 +2087,7 @@ xla_test(
":client_library_test_base",
":literal_test_util",
":xla_internal_test_main",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
index 3ae96fa1bc..74f2e36f82 100644
--- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
@@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
diff --git a/tensorflow/compiler/xla/tests/axpy_simple_test.cc b/tensorflow/compiler/xla/tests/axpy_simple_test.cc
index 8d15b7841b..caeb0bf49a 100644
--- a/tensorflow/compiler/xla/tests/axpy_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/axpy_simple_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
diff --git a/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc
index 71dbe4f0b6..af0b852239 100644
--- a/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc
+++ b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
index 033382708a..d372d1ca43 100644
--- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc
+++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/reference_util.h"
diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc
index 747c82b502..6c20f654fe 100644
--- a/tensorflow/compiler/xla/tests/bfloat16_test.cc
+++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
diff --git a/tensorflow/compiler/xla/tests/binop_scaling_test.cc b/tensorflow/compiler/xla/tests/binop_scaling_test.cc
index 20cb989751..0d7a3aa46a 100644
--- a/tensorflow/compiler/xla/tests/binop_scaling_test.cc
+++ b/tensorflow/compiler/xla/tests/binop_scaling_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
diff --git a/tensorflow/compiler/xla/tests/bitcast_convert_test.cc b/tensorflow/compiler/xla/tests/bitcast_convert_test.cc
index d531e8fa82..c6b5108fe9 100644
--- a/tensorflow/compiler/xla/tests/bitcast_convert_test.cc
+++ b/tensorflow/compiler/xla/tests/bitcast_convert_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
index 50dd574624..1d28e85b16 100644
--- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc
index 05c1c361bb..b1d18210ea 100644
--- a/tensorflow/compiler/xla/tests/call_test.cc
+++ b/tensorflow/compiler/xla/tests/call_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include <memory>
#include <utility>
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
index 0bc8facfe2..a4eb57fc7b 100644
--- a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
+++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc
index 515c0201d1..59d917054b 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.cc
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc
@@ -19,8 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index edc1ba8a57..4a6e8a3124 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -26,7 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/global_data.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc
index f97008bee2..c898dacf48 100644
--- a/tensorflow/compiler/xla/tests/client_test.cc
+++ b/tensorflow/compiler/xla/tests/client_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc
index 2b407ed263..7c52c9fbbb 100644
--- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc
+++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc
index 672fb06de6..5a06d061f0 100644
--- a/tensorflow/compiler/xla/tests/compute_constant_test.cc
+++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/global_data.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc
index e63d2480b6..be017477d8 100644
--- a/tensorflow/compiler/xla/tests/concat_test.cc
+++ b/tensorflow/compiler/xla/tests/concat_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/statusor.h"
diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc
index d9d42bf061..b27c1044ba 100644
--- a/tensorflow/compiler/xla/tests/conditional_test.cc
+++ b/tensorflow/compiler/xla/tests/conditional_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc
index 71d72a9828..4937574831 100644
--- a/tensorflow/compiler/xla/tests/constants_test.cc
+++ b/tensorflow/compiler/xla/tests/constants_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc
index 0fb6853e3f..1adc68cc48 100644
--- a/tensorflow/compiler/xla/tests/convert_test.cc
+++ b/tensorflow/compiler/xla/tests/convert_test.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
index 944366410b..7b6bbc4f57 100644
--- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/padding.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/statusor.h"
diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc
index a8b8f74ca9..5ed8122e00 100644
--- a/tensorflow/compiler/xla/tests/convolution_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_test.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/padding.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc
index 8792e7781b..6784c16715 100644
--- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc
@@ -27,7 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/padding.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc
index 1dc6ff0f4f..5ef273e5a2 100644
--- a/tensorflow/compiler/xla/tests/copy_test.cc
+++ b/tensorflow/compiler/xla/tests/copy_test.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <utility>
#include "tensorflow/compiler/xla/array2d.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc
index 90f3d1b874..13c777835e 100644
--- a/tensorflow/compiler/xla/tests/custom_call_test.cc
+++ b/tensorflow/compiler/xla/tests/custom_call_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include <memory>
#include <utility>
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
diff --git a/tensorflow/compiler/xla/tests/deallocation_test.cc b/tensorflow/compiler/xla/tests/deallocation_test.cc
index 062b8cb8c4..5f234f36a8 100644
--- a/tensorflow/compiler/xla/tests/deallocation_test.cc
+++ b/tensorflow/compiler/xla/tests/deallocation_test.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
index 6795130cd1..2db6503afa 100644
--- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
+++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
diff --git a/tensorflow/compiler/xla/tests/deep_graph_test.cc b/tensorflow/compiler/xla/tests/deep_graph_test.cc
index 810947ab01..3f3e8ab712 100644
--- a/tensorflow/compiler/xla/tests/deep_graph_test.cc
+++ b/tensorflow/compiler/xla/tests/deep_graph_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index d86fd7cc2d..cfd36abf47 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
index 88ac96d6b0..7f6f203a1b 100644
--- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/local_service.h"
diff --git a/tensorflow/compiler/xla/tests/execution_profile_test.cc b/tensorflow/compiler/xla/tests/execution_profile_test.cc
index e2c145b795..5116e60ca6 100644
--- a/tensorflow/compiler/xla/tests/execution_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/execution_profile_test.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/global_data.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc
index 86bfaea4ef..bf1de02ba9 100644
--- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc
+++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
diff --git a/tensorflow/compiler/xla/tests/floor_ceil_test.cc b/tensorflow/compiler/xla/tests/floor_ceil_test.cc
index 30dc639f11..39cc6c5927 100644
--- a/tensorflow/compiler/xla/tests/floor_ceil_test.cc
+++ b/tensorflow/compiler/xla/tests/floor_ceil_test.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <string>
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
diff --git a/tensorflow/compiler/xla/tests/fmax_test.cc b/tensorflow/compiler/xla/tests/fmax_test.cc
index 0254ae1baa..c5bbbe778d 100644
--- a/tensorflow/compiler/xla/tests/fmax_test.cc
+++ b/tensorflow/compiler/xla/tests/fmax_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/core/platform/test.h"
diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc
index 607bcdd51e..792be0d3fc 100644
--- a/tensorflow/compiler/xla/tests/fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/fusion_test.cc
@@ -25,7 +25,7 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc
index 2008d69237..b77bece85a 100644
--- a/tensorflow/compiler/xla/tests/gather_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc
@@ -13,8 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
diff --git a/tensorflow/compiler/xla/tests/half_test.cc b/tensorflow/compiler/xla/tests/half_test.cc
index 249a4b2493..51450314b6 100644
--- a/tensorflow/compiler/xla/tests/half_test.cc
+++ b/tensorflow/compiler/xla/tests/half_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include <cmath>
#include <vector>
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
diff --git a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc
index 4d82442f7e..5511190caf 100644
--- a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc
+++ b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/local_client_test_base.h"
diff --git a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc
index 0df50150ae..e2cd5bcc5a 100644
--- a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc
+++ b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
index 0b44090702..74494e60e8 100644
--- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
+++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "llvm/ADT/Triple.h"
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
@@ -92,9 +92,10 @@ int main(int argc, char** argv) {
// It's lame to hard-code the buffer assignments, but we need
// local_client_aot_test.cc to be able to easily invoke the function.
CHECK_EQ(result->result_buffer_index(), 1);
- CHECK_EQ(result->buffer_sizes().size(), 2);
+ CHECK_EQ(result->buffer_sizes().size(), 3);
CHECK_EQ(result->buffer_sizes()[0], -1); // param buffer
CHECK_EQ(result->buffer_sizes()[1], sizeof(float)); // result buffer
+ CHECK_EQ(result->buffer_sizes()[2], -1); // const buffer
if (triple.isOSBinFormatELF()) {
// Check the ELF magic.
CHECK_EQ(result->object_file_data()[0], 0x7F);
diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc
index 5c3498c84c..1a823cf189 100644
--- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc
+++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
diff --git a/tensorflow/compiler/xla/tests/log_test.cc b/tensorflow/compiler/xla/tests/log_test.cc
index cdf70ee418..2d622242e6 100644
--- a/tensorflow/compiler/xla/tests/log_test.cc
+++ b/tensorflow/compiler/xla/tests/log_test.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc
index 34bcaef513..0732e195d4 100644
--- a/tensorflow/compiler/xla/tests/map_test.cc
+++ b/tensorflow/compiler/xla/tests/map_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
index 4fca90af77..da8c42d465 100644
--- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
diff --git a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc
index e576f000ef..955dbef6dc 100644
--- a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc
+++ b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc
index d8c17202f2..ca21b0b2ba 100644
--- a/tensorflow/compiler/xla/tests/pad_test.cc
+++ b/tensorflow/compiler/xla/tests/pad_test.cc
@@ -20,8 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc
index bf3b5f2b65..f6c762e7a4 100644
--- a/tensorflow/compiler/xla/tests/params_test.cc
+++ b/tensorflow/compiler/xla/tests/params_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
diff --git a/tensorflow/compiler/xla/tests/pred_test.cc b/tensorflow/compiler/xla/tests/pred_test.cc
index 5c351b2d11..2fc7f816b5 100644
--- a/tensorflow/compiler/xla/tests/pred_test.cc
+++ b/tensorflow/compiler/xla/tests/pred_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc
index 3f98099be6..029af69573 100644
--- a/tensorflow/compiler/xla/tests/prng_test.cc
+++ b/tensorflow/compiler/xla/tests/prng_test.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
diff --git a/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc
index 526a38e8d1..fab2a65de1 100644
--- a/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc
+++ b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test_helpers.h"
diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc
index 04c7f31646..531648fe3e 100644
--- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc
index 638b0825a1..2065271a7f 100644
--- a/tensorflow/compiler/xla/tests/reduce_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_test.cc
@@ -37,7 +37,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index 161b74a5c8..1bd6fdab31 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/padding.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc
index f026ad6c42..d891451381 100644
--- a/tensorflow/compiler/xla/tests/replay_test.cc
+++ b/tensorflow/compiler/xla/tests/replay_test.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
diff --git a/tensorflow/compiler/xla/tests/reshape_motion_test.cc b/tensorflow/compiler/xla/tests/reshape_motion_test.cc
index 7c0389cfa3..368f5583c9 100644
--- a/tensorflow/compiler/xla/tests/reshape_motion_test.cc
+++ b/tensorflow/compiler/xla/tests/reshape_motion_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/reference_util.h"
diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc
index a6e985293a..382d1b1ae7 100644
--- a/tensorflow/compiler/xla/tests/reshape_test.cc
+++ b/tensorflow/compiler/xla/tests/reshape_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc
index 23f0d26d93..41e49b4003 100644
--- a/tensorflow/compiler/xla/tests/reverse_test.cc
+++ b/tensorflow/compiler/xla/tests/reverse_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
index 5a3bcaf086..e42c71eb28 100644
--- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc
+++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
diff --git a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc
index ceb795219a..e3d4f98dd7 100644
--- a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc
+++ b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/padding.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
diff --git a/tensorflow/compiler/xla/tests/select_test.cc b/tensorflow/compiler/xla/tests/select_test.cc
index 59409ab26e..1c01402798 100644
--- a/tensorflow/compiler/xla/tests/select_test.cc
+++ b/tensorflow/compiler/xla/tests/select_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc
index a593faca00..b8ad6668f8 100644
--- a/tensorflow/compiler/xla/tests/slice_test.cc
+++ b/tensorflow/compiler/xla/tests/slice_test.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc
index 8f424ae81f..a2f0338e25 100644
--- a/tensorflow/compiler/xla/tests/test_utils_test.cc
+++ b/tensorflow/compiler/xla/tests/test_utils_test.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_utils.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/local_client_test_base.h"
diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc
index 0f86b7f20f..125513ddfd 100644
--- a/tensorflow/compiler/xla/tests/transfer_manager_test.cc
+++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/generic_transfer_manager.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
@@ -60,7 +61,7 @@ class TransferManagerTest : public LocalClientTestBase {
}
protected:
- Backend::StreamPtr stream_ptr_;
+ StreamPool::Ptr stream_ptr_;
se::Stream* stream_;
private:
diff --git a/tensorflow/compiler/xla/tests/transpose_test.cc b/tensorflow/compiler/xla/tests/transpose_test.cc
index 6ebb4324f8..fbe9d1b64a 100644
--- a/tensorflow/compiler/xla/tests/transpose_test.cc
+++ b/tensorflow/compiler/xla/tests/transpose_test.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc
index ad46eaa1c3..2fd70b72b5 100644
--- a/tensorflow/compiler/xla/tests/tuple_test.cc
+++ b/tensorflow/compiler/xla/tests/tuple_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc
index a90a6fb0a5..20ae68ab74 100644
--- a/tensorflow/compiler/xla/tests/unary_op_test.cc
+++ b/tensorflow/compiler/xla/tests/unary_op_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
diff --git a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc
index ea3aba6df1..ef1b1445bb 100644
--- a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc
+++ b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc
index cacbe83b86..3848ec1684 100644
--- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc
index 0a39778002..c81c27891c 100644
--- a/tensorflow/compiler/xla/tests/while_test.cc
+++ b/tensorflow/compiler/xla/tests/while_test.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
index 7a75e5102c..0ee8e68c88 100644
--- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
@@ -18,10 +18,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
+#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
@@ -133,7 +134,7 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client,
DeviceMemoryAllocator* allocator = backend->memory_allocator();
auto* transfer_manager = backend->transfer_manager();
TF_ASSERT_OK_AND_ASSIGN(
- Backend::StreamPtr stream_ptr,
+ StreamPool::Ptr stream_ptr,
backend->BorrowStream(backend->default_device_ordinal()));
TF_ASSERT_OK_AND_ASSIGN(
diff --git a/tensorflow/contrib/autograph/converters/break_statements.py b/tensorflow/contrib/autograph/converters/break_statements.py
index 2a60750bda..180779670d 100644
--- a/tensorflow/contrib/autograph/converters/break_statements.py
+++ b/tensorflow/contrib/autograph/converters/break_statements.py
@@ -42,7 +42,7 @@ class BreakTransformer(converter.Base):
var_name = self.state[_Break].control_var_name
# TODO(mdan): This will fail when expanded inside a top-level else block.
template = """
- var_name = True
+ var_name = tf.constant(True)
continue
"""
return templates.replace(template, var_name=var_name)
@@ -85,7 +85,7 @@ class BreakTransformer(converter.Base):
guarded_orelse = self._guard_if_present(node.orelse, break_var)
template = """
- var_name = False
+ var_name = tf.constant(False)
while test and not var_name:
body
else:
@@ -122,7 +122,7 @@ class BreakTransformer(converter.Base):
# the control variable is marked as used.
# TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name)
template = """
- var_name = False
+ var_name = tf.constant(False)
for target in iter_:
(var_name,)
body
diff --git a/tensorflow/contrib/autograph/converters/break_statements_test.py b/tensorflow/contrib/autograph/converters/break_statements_test.py
index c26ca2946c..fcae7d68c0 100644
--- a/tensorflow/contrib/autograph/converters/break_statements_test.py
+++ b/tensorflow/contrib/autograph/converters/break_statements_test.py
@@ -20,13 +20,16 @@ from __future__ import print_function
from tensorflow.contrib.autograph.converters import break_statements
from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.eager import context as tfe_ctx
+from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test
class BreakCanonicalizationTest(converter_testing.TestCase):
def assertTransformedEquivalent(self, test_fn, *inputs):
- with self.converted(test_fn, break_statements, {}) as result:
+ with self.converted(test_fn, break_statements, {},
+ constant_op.constant) as result:
self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
def test_while_loop(self):
@@ -40,9 +43,10 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v
- self.assertTransformedEquivalent(test_fn, 0)
- self.assertTransformedEquivalent(test_fn, 1)
- self.assertTransformedEquivalent(test_fn, 4)
+ with tfe_ctx.eager_mode():
+ self.assertTransformedEquivalent(test_fn, 0)
+ self.assertTransformedEquivalent(test_fn, 1)
+ self.assertTransformedEquivalent(test_fn, 4)
def test_for_loop(self):
@@ -55,7 +59,8 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v
- with self.converted(test_fn, break_statements, {}) as result:
+ with self.converted(test_fn, break_statements, {},
+ constant_op.constant) as result:
# The break is incompletely canonicalized. The loop will not interrupt,
# but the section following the break will be skipped.
self.assertEqual([3], result.test_fn([5, 4]))
@@ -77,9 +82,10 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v, u, w
- self.assertTransformedEquivalent(test_fn, 0)
- self.assertTransformedEquivalent(test_fn, 3)
- self.assertTransformedEquivalent(test_fn, 11)
+ with tfe_ctx.eager_mode():
+ self.assertTransformedEquivalent(test_fn, 0)
+ self.assertTransformedEquivalent(test_fn, 3)
+ self.assertTransformedEquivalent(test_fn, 11)
def test_nested_loops(self):
@@ -99,10 +105,11 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v, u
- self.assertTransformedEquivalent(test_fn, 0)
- self.assertTransformedEquivalent(test_fn, 2)
- self.assertTransformedEquivalent(test_fn, 3)
- self.assertTransformedEquivalent(test_fn, 5)
+ with tfe_ctx.eager_mode():
+ self.assertTransformedEquivalent(test_fn, 0)
+ self.assertTransformedEquivalent(test_fn, 2)
+ self.assertTransformedEquivalent(test_fn, 3)
+ self.assertTransformedEquivalent(test_fn, 5)
def test_loop_orelse(self):
@@ -120,9 +127,10 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v, u
- self.assertTransformedEquivalent(test_fn, 0)
- self.assertTransformedEquivalent(test_fn, 2)
- self.assertTransformedEquivalent(test_fn, 3)
+ with tfe_ctx.eager_mode():
+ self.assertTransformedEquivalent(test_fn, 0)
+ self.assertTransformedEquivalent(test_fn, 2)
+ self.assertTransformedEquivalent(test_fn, 3)
if __name__ == '__main__':
diff --git a/tensorflow/contrib/autograph/converters/continue_statements.py b/tensorflow/contrib/autograph/converters/continue_statements.py
index 958bde0a58..0476e97c15 100644
--- a/tensorflow/contrib/autograph/converters/continue_statements.py
+++ b/tensorflow/contrib/autograph/converters/continue_statements.py
@@ -37,7 +37,7 @@ class ContinueCanonicalizationTransformer(converter.Base):
def visit_Continue(self, node):
self.set_local(CONTINUE_USED, True)
template = """
- var_name = True
+ var_name = tf.constant(True)
"""
return templates.replace(
template, var_name=self.get_local(CONTROL_VAR_NAME))
@@ -92,7 +92,7 @@ class ContinueCanonicalizationTransformer(converter.Base):
if self.get_local(CONTINUE_USED, False):
template = """
- var_name = False
+ var_name = tf.constant(False)
"""
control_var_init = templates.replace(template, var_name=continue_var)
nodes = control_var_init + nodes
diff --git a/tensorflow/contrib/autograph/converters/continue_statements_test.py b/tensorflow/contrib/autograph/converters/continue_statements_test.py
index 3a7c7d1486..37c15211b4 100644
--- a/tensorflow/contrib/autograph/converters/continue_statements_test.py
+++ b/tensorflow/contrib/autograph/converters/continue_statements_test.py
@@ -20,13 +20,16 @@ from __future__ import print_function
from tensorflow.contrib.autograph.converters import continue_statements
from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.eager import context as tfe_ctx
+from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test
class ContinueCanonicalizationTest(converter_testing.TestCase):
def assertTransformedEquivalent(self, test_fn, *inputs):
- with self.converted(test_fn, continue_statements, {}) as result:
+ with self.converted(test_fn, continue_statements, {},
+ constant_op.constant) as result:
self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
def test_basic(self):
@@ -40,10 +43,11 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v
- self.assertTransformedEquivalent(test_fn, 0)
- self.assertTransformedEquivalent(test_fn, 1)
- self.assertTransformedEquivalent(test_fn, 3)
- self.assertTransformedEquivalent(test_fn, 4)
+ with tfe_ctx.eager_mode():
+ self.assertTransformedEquivalent(test_fn, 0)
+ self.assertTransformedEquivalent(test_fn, 1)
+ self.assertTransformedEquivalent(test_fn, 3)
+ self.assertTransformedEquivalent(test_fn, 4)
def test_for_loop(self):
@@ -56,10 +60,11 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v
- self.assertTransformedEquivalent(test_fn, [])
- self.assertTransformedEquivalent(test_fn, [1])
- self.assertTransformedEquivalent(test_fn, [2])
- self.assertTransformedEquivalent(test_fn, [1, 2, 3])
+ with tfe_ctx.eager_mode():
+ self.assertTransformedEquivalent(test_fn, [])
+ self.assertTransformedEquivalent(test_fn, [1])
+ self.assertTransformedEquivalent(test_fn, [2])
+ self.assertTransformedEquivalent(test_fn, [1, 2, 3])
def test_nested(self):
@@ -78,10 +83,11 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v, u, w
- self.assertTransformedEquivalent(test_fn, 0)
- self.assertTransformedEquivalent(test_fn, 1)
- self.assertTransformedEquivalent(test_fn, 3)
- self.assertTransformedEquivalent(test_fn, 4)
+ with tfe_ctx.eager_mode():
+ self.assertTransformedEquivalent(test_fn, 0)
+ self.assertTransformedEquivalent(test_fn, 1)
+ self.assertTransformedEquivalent(test_fn, 3)
+ self.assertTransformedEquivalent(test_fn, 4)
if __name__ == '__main__':
diff --git a/tensorflow/contrib/autograph/converters/error_handlers.py b/tensorflow/contrib/autograph/converters/error_handlers.py
index 3f23662152..1936821394 100644
--- a/tensorflow/contrib/autograph/converters/error_handlers.py
+++ b/tensorflow/contrib/autograph/converters/error_handlers.py
@@ -37,7 +37,8 @@ class ErrorRewritingTransformer(converter.Base):
def visit_FunctionDef(self, node):
node = self.generic_visit(node)
- if anno.hasanno(node, anno.Basic.ORIGIN):
+ if (anno.hasanno(node, anno.Basic.ORIGIN) and
+ len(self.enclosing_entities) <= 1):
template = """
try:
body
diff --git a/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb b/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb
index a3109fa5db..7e9cc54d4c 100644
--- a/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb
+++ b/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb
@@ -392,7 +392,7 @@
"output_type": "stream",
"text": [
"Got error message: assertion failed: [Do not pass zero!]\n",
- "\t [[Node: f/Assert/Assert = Assert[T=[DT_STRING], summarize=3, _device=\"/job:localhost/replica:0/task:0/device:CPU:0\"](f/NotEqual, f/Assert/Assert/data_0)]]\n"
+ "\t [[{{node f/Assert/Assert}} = Assert[T=[DT_STRING], summarize=3, _device=\"/job:localhost/replica:0/task:0/device:CPU:0\"](f/NotEqual, f/Assert/Assert/data_0)]]\n"
]
}
],
diff --git a/tensorflow/contrib/autograph/impl/api.py b/tensorflow/contrib/autograph/impl/api.py
index ee71f4f9ac..0adff76a9f 100644
--- a/tensorflow/contrib/autograph/impl/api.py
+++ b/tensorflow/contrib/autograph/impl/api.py
@@ -258,25 +258,27 @@ def to_graph(e,
# Avoid overwriting entities that have been transformed.
if key not in compiled_module.__dict__:
compiled_module.__dict__[key] = val
- compiled_fn = getattr(compiled_module, name)
+ compiled = getattr(compiled_module, name)
# Need this so the source_mapping attribute is available for the context
# manager to access for runtime errors.
#
# Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
# symbol to the compiled module.
+ # TODO(mdan): Record this statically in the generated code.
+ # TODO(mdan): Rename this attribute to 'autograph_info__'
source_map_attribute_name = 'ag_source_map'
- if getattr(compiled_fn, source_map_attribute_name, None) is not None:
+ if getattr(compiled, source_map_attribute_name, None) is not None:
raise ValueError('cannot convert %s because is has an attribute '
'"%s", which is reserved for AutoGraph.' %
- (compiled_fn, source_map_attribute_name))
- setattr(compiled_fn, source_map_attribute_name,
+ (compiled, source_map_attribute_name))
+ setattr(compiled, source_map_attribute_name,
compiled_module.__dict__['ag_source_map__'])
if verbose:
logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)
- return compiled_fn
+ return compiled
def to_code(e,
diff --git a/tensorflow/contrib/autograph/impl/api_test.py b/tensorflow/contrib/autograph/impl/api_test.py
index 4de7df6572..754baa87b0 100644
--- a/tensorflow/contrib/autograph/impl/api_test.py
+++ b/tensorflow/contrib/autograph/impl/api_test.py
@@ -280,6 +280,20 @@ class ApiTest(test.TestCase):
x = tc.test_method()
self.assertEqual(1, sess.run(x))
+ def test_converted_call_already_converted(self):
+
+ def f(x):
+ return x == 0
+
+ with self.test_session() as sess:
+ x = api.converted_call(f, False, False, {}, constant_op.constant(0))
+ self.assertTrue(sess.run(x))
+
+ converted_f = api.to_graph(f)
+ x = api.converted_call(converted_f, False, False, {},
+ constant_op.constant(0))
+ self.assertTrue(sess.run(x))
+
def test_to_graph_basic(self):
def test_fn(x, s):
diff --git a/tensorflow/contrib/autograph/impl/conversion.py b/tensorflow/contrib/autograph/impl/conversion.py
index 57ec739a80..afb10d4d8b 100644
--- a/tensorflow/contrib/autograph/impl/conversion.py
+++ b/tensorflow/contrib/autograph/impl/conversion.py
@@ -48,6 +48,7 @@ from tensorflow.contrib.autograph.pyct import inspect_utils
from tensorflow.contrib.autograph.pyct import origin_info
from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import qual_names
+from tensorflow.contrib.autograph.pyct import templates
from tensorflow.contrib.autograph.pyct import transformer
from tensorflow.python.util import tf_inspect
@@ -70,6 +71,8 @@ def is_whitelisted_for_graph(o):
for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
if m.__name__.startswith(prefix):
return True
+ if hasattr(o, 'autograph_info__'):
+ return True
return False
@@ -120,7 +123,16 @@ def entity_to_graph(o, program_ctx, arg_values, arg_types):
'Entity "%s" has unsupported type "%s". Only functions and classes are '
'supported for now.' % (o, type(o)))
+ # TODO(mdan): This is temporary. it should be created using a converter.
+ # TODO(mdan): The attribute should be added with a helper, not directly.
+ # The helper can ensure there are no collisions.
+ template = '''
+ entity.autograph_info__ = {}
+ '''
+ node.extend(templates.replace(template, entity=name))
+
program_ctx.add_to_cache(o, node)
+
if program_ctx.recursive:
while True:
candidate = None
@@ -268,18 +280,18 @@ def function_to_graph(f,
context = converter.EntityContext(namer, entity_info, program_ctx)
node = node_to_graph(node, context, rewrite_errors=rewrite_errors)
- # TODO(mdan): This somewhat duplicates the call rename logic in call_treest.py
+ # TODO(mdan): This somewhat duplicates the call rename logic in call_trees.py
new_name, did_rename = namer.compiled_function_name(f.__name__, f, owner_type)
if not did_rename:
new_name = f.__name__
if node.name != f.__name__:
raise NotImplementedError('Strange corner case. Send us offending code!')
-
node.name = new_name
+
program_ctx.update_name_map(namer)
# TODO(mdan): Use this at compilation.
- return (node,), new_name, namespace
+ return [node], new_name, namespace
def node_to_graph(node, context, rewrite_errors=True):
diff --git a/tensorflow/contrib/autograph/impl/conversion_test.py b/tensorflow/contrib/autograph/impl/conversion_test.py
index bfc51365a3..1c5d4d09c4 100644
--- a/tensorflow/contrib/autograph/impl/conversion_test.py
+++ b/tensorflow/contrib/autograph/impl/conversion_test.py
@@ -61,7 +61,7 @@ class ConversionTest(test.TestCase):
program_ctx = self._simple_program_ctx()
nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None)
- fn_node, = nodes
+ fn_node, _ = nodes
self.assertIsInstance(fn_node, gast.FunctionDef)
self.assertEqual('tf__f', name)
self.assertIs(ns['b'], b)
@@ -115,10 +115,12 @@ class ConversionTest(test.TestCase):
self.assertTrue(TestBase in program_ctx.dependency_cache)
self.assertTrue(TestSubclass in program_ctx.dependency_cache)
+ # The returned nodes will include:
+ # <import nodes>, <class node>, <assignment node>
self.assertEqual('TfTestBase',
- program_ctx.dependency_cache[TestBase][-1].name)
+ program_ctx.dependency_cache[TestBase][-2].name)
self.assertEqual('TfTestSubclass',
- program_ctx.dependency_cache[TestSubclass][-1].name)
+ program_ctx.dependency_cache[TestSubclass][-2].name)
def test_entity_to_graph_class_hierarchy_whitelisted(self):
@@ -138,8 +140,10 @@ class ConversionTest(test.TestCase):
self.assertFalse(training.Model in program_ctx.dependency_cache)
self.assertEqual(
'Model', program_ctx.dependency_cache[TestSubclass][0].names[0].name)
+ # The returned nodes will include:
+ # <import nodes>, <class node>, <assignment node>
self.assertEqual('TfTestSubclass',
- program_ctx.dependency_cache[TestSubclass][-1].name)
+ program_ctx.dependency_cache[TestSubclass][-2].name)
def test_entity_to_graph_lambda(self):
f = lambda a: a
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/BUILD b/tensorflow/contrib/autograph/pyct/common_transformers/BUILD
index ca1441cf6f..a0938b3e5f 100644
--- a/tensorflow/contrib/autograph/pyct/common_transformers/BUILD
+++ b/tensorflow/contrib/autograph/pyct/common_transformers/BUILD
@@ -24,6 +24,7 @@ py_library(
deps = [
"//tensorflow/contrib/autograph/pyct",
"@gast_archive//:gast",
+ "@six_archive//:six",
],
)
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf.py b/tensorflow/contrib/autograph/pyct/common_transformers/anf.py
index cc039986c2..e42f679cfe 100644
--- a/tensorflow/contrib/autograph/pyct/common_transformers/anf.py
+++ b/tensorflow/contrib/autograph/pyct/common_transformers/anf.py
@@ -12,12 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Conversion to A-normal form."""
+"""Conversion to A-normal form.
+
+The general idea of A-normal form is that every intermediate value is
+explicitly named with a variable. For more, see
+https://en.wikipedia.org/wiki/A-normal_form.
+
+The specific converters used here are based on Python AST semantics as
+documented at https://greentreesnakes.readthedocs.io/en/latest/.
+"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import gast
+import six
+
+from tensorflow.contrib.autograph.pyct import templates
from tensorflow.contrib.autograph.pyct import transformer
@@ -32,26 +44,375 @@ class DummyGensym(object):
# * the symbols generated so far
self._idx = 0
- def new_name(self, stem):
+ def new_name(self, stem='tmp'):
self._idx += 1
return stem + '_' + str(1000 + self._idx)
class AnfTransformer(transformer.Base):
- """Performs the actual conversion."""
+ """Performs the conversion to A-normal form (ANF)."""
- # TODO(mdan): Link to a reference.
- # TODO(mdan): Implement.
+ # The algorithm is a postorder recursive tree walk. Any given node A may, in
+ # general, require creation of a series B of Assign statements, which compute
+ # and explicitly name the intermediate values needed to compute the value of
+ # A. If A was already a statement, it can be replaced with the sequence B +
+ # [A]. If A was an expression, B needs to be propagated up the tree until a
+ # statement is encountered. Since the `ast.NodeTransformer` framework makes
+ # no provision for subtraversals returning side information, this class
+ # accumulates the sequence B in an instance variable.
- def __init__(self, entity_info):
- """Creates a transformer.
+ # The only other subtlety is that some Python statements (like `if`) have both
+ # expression fields (`test`) and statement list fields (`body` and `orelse`).
+ # Any additional assignments needed to name all the intermediate values in the
+ # `test` can be prepended to the `if` node, but assignments produced by
+ # processing the `body` and the `orelse` need to be kept together with them,
+ # and not accidentally lifted out of the `if`.
+
+ def __init__(self, entity_info, gensym_source=None):
+ """Creates an ANF transformer.
Args:
entity_info: transformer.EntityInfo
+ gensym_source: An optional object with the same interface as `DummyGensym`
+ for generating unique names
"""
super(AnfTransformer, self).__init__(entity_info)
- self._gensym = DummyGensym(entity_info)
+ if gensym_source is None:
+ self._gensym = DummyGensym(entity_info)
+ else:
+ self._gensym = gensym_source(entity_info)
+ self._pending_statements = []
+
+ def _consume_pending_statements(self):
+ ans = self._pending_statements
+ self._pending_statements = []
+ return ans
+
+ def _add_pending_statement(self, stmt):
+ self._pending_statements.append(stmt)
+
+ _trivial_nodes = (
+ # Non-nodes that show up as AST fields
+ bool, six.string_types,
+ # Leaf nodes that are already in A-normal form
+ gast.expr_context, gast.Name, gast.Num, gast.Str, gast.Bytes,
+ gast.NameConstant, gast.Ellipsis,
+ # Binary operators
+ gast.Add, gast.Sub, gast.Mult, gast.Div, gast.Mod, gast.Pow, gast.LShift,
+ gast.RShift, gast.BitOr, gast.BitXor, gast.BitAnd, gast.FloorDiv,
+ # Unary operators
+ gast.Invert, gast.Not, gast.UAdd, gast.USub,
+ # Comparison operators
+ gast.Eq, gast.NotEq, gast.Lt, gast.LtE, gast.Gt, gast.GtE,
+ gast.Is, gast.IsNot, gast.In, gast.NotIn,
+ )
+
+ def _is_node_trivial(self, node):
+ if node is None:
+ return True
+ elif isinstance(node, self._trivial_nodes):
+ return True
+ elif isinstance(node, gast.keyword):
+ return self._is_node_trivial(node.value)
+ elif isinstance(node, (gast.Starred, gast.withitem, gast.slice)):
+ return self._are_children_trivial(node)
+ return False
+
+ def _are_children_trivial(self, node):
+ for field in node._fields:
+ if not field.startswith('__'):
+ if not self._is_node_trivial(getattr(node, field)):
+ return False
+ return True
+
+ def _ensure_node_is_trivial(self, node):
+ if node is None:
+ return node
+ elif isinstance(node, self._trivial_nodes):
+ return node
+ elif isinstance(node, list):
+ # If something's field was actually a list, e.g., variadic arguments.
+ return [self._ensure_node_is_trivial(n) for n in node]
+ elif isinstance(node, gast.keyword):
+ node.value = self._ensure_node_is_trivial(node.value)
+ return node
+ elif isinstance(node, (gast.Starred, gast.withitem, gast.slice)):
+ return self._ensure_fields_trivial(node)
+ elif isinstance(node, gast.expr):
+ temp_name = self._gensym.new_name()
+ temp_assign = templates.replace(
+ 'temp_name = expr', temp_name=temp_name, expr=node)[0]
+ self._add_pending_statement(temp_assign)
+ answer = templates.replace('temp_name', temp_name=temp_name)[0]
+ return answer
+ else:
+ raise ValueError('Do not know how to treat {}'.format(node))
+
+ def _ensure_fields_trivial(self, node):
+ for field in node._fields:
+ if field.startswith('__'):
+ continue
+ setattr(node, field, self._ensure_node_is_trivial(getattr(node, field)))
+ return node
+
+ def _visit_strict_statement(self, node, trivialize_children=True):
+ assert not self._pending_statements
+ node = self.generic_visit(node)
+ if trivialize_children:
+ self._ensure_fields_trivial(node)
+ results = self._consume_pending_statements()
+ results.append(node)
+ return results
+
+ def _visit_strict_expression(self, node):
+ node = self.generic_visit(node)
+ self._ensure_fields_trivial(node)
+ return node
+
+ # Note on code order: These are listed in the same order as the grammar
+ # elements on https://github.com/serge-sans-paille/gast
+
+ # FunctionDef, AsyncFunctionDef, and ClassDef should be correct by default.
+
+ def visit_Return(self, node):
+ return self._visit_strict_statement(node)
+
+ def visit_Delete(self, node):
+ return self._visit_strict_statement(node, trivialize_children=False)
+
+ def visit_Assign(self, node):
+ return self._visit_strict_statement(node, trivialize_children=False)
+
+ def visit_AugAssign(self, node):
+ return self._visit_strict_statement(node, trivialize_children=False)
+
+ def visit_Print(self, node):
+ return self._visit_strict_statement(node)
+
+ def visit_For(self, node):
+ assert not self._pending_statements
+ # It's important to visit node.iter first, because any statements created
+ # thereby need to live outside the body.
+ self.visit(node.iter)
+ node.iter = self._ensure_node_is_trivial(node.iter)
+ iter_stmts = self._consume_pending_statements()
+ # This generic_visit will revisit node.iter, but that is both correct and
+ # cheap because by this point node.iter is trivial.
+ node = self.generic_visit(node)
+ assert not self._pending_statements
+ iter_stmts.append(node)
+ return iter_stmts
+
+ def visit_AsyncFor(self, node):
+ if not self._are_children_trivial(node):
+ msg = ('Nontrivial AsyncFor nodes not supported yet '
+ '(need to think through the semantics).')
+ raise ValueError(msg)
+ return self.generic_visit(node)
+
+ def visit_While(self, node):
+ if not self._is_node_trivial(node.test):
+ msg = ('While with nontrivial test not supported yet '
+ '(need to avoid precomputing the test).')
+ raise ValueError(msg)
+ return self.generic_visit(node)
+
+ def visit_If(self, node):
+ assert not self._pending_statements
+ # It's important to visit node.test first, because any statements created
+ # thereby need to live outside the body.
+ self.visit(node.test)
+ node.test = self._ensure_node_is_trivial(node.test)
+ condition_stmts = self._consume_pending_statements()
+ # This generic_visit will revisit node.test, but that is both correct and
+ # cheap because by this point node.test is trivial.
+ node = self.generic_visit(node)
+ assert not self._pending_statements
+ condition_stmts.append(node)
+ return condition_stmts
+
+ def visit_With(self, node):
+ assert not self._pending_statements
+ # It's important to visit node.items first, because any statements created
+ # thereby need to live outside the body.
+ for item in node.items:
+ self.visit(item)
+ node.items = [self._ensure_node_is_trivial(n) for n in node.items]
+ contexts_stmts = self._consume_pending_statements()
+ # This generic_visit will revisit node.items, but that is both correct and
+ # cheap because by this point node.items is trivial.
+ node = self.generic_visit(node)
+ assert not self._pending_statements
+ contexts_stmts.append(node)
+ return contexts_stmts
+
+ def visit_AsyncWith(self, node):
+ if not self._are_children_trivial(node):
+ msg = ('Nontrivial AsyncWith nodes not supported yet '
+ '(need to think through the semantics).')
+ raise ValueError(msg)
+ return self.generic_visit(node)
+
+ def visit_Raise(self, node):
+ return self._visit_strict_statement(node)
+
+ # Try should be correct by default.
+
+ def visit_Assert(self, node):
+ if not self._are_children_trivial(node):
+ msg = ('Nontrivial Assert nodes not supported yet '
+ '(need to avoid computing the test when assertions are off, and '
+ 'avoid computing the irritant when the assertion does not fire).')
+ raise ValueError(msg)
+ return self.generic_visit(node)
+
+ # Import and ImportFrom should be correct by default.
+
+ def visit_Exec(self, node):
+ return self._visit_strict_statement(node)
+
+ # Global and Nonlocal should be correct by default.
+
+ def visit_Expr(self, node):
+ return self._visit_strict_statement(node, trivialize_children=False)
+
+ # Pass, Break, and Continue should be correct by default.
+
+ def visit_BoolOp(self, node):
+ if not self._are_children_trivial(node):
+ msg = ('Nontrivial BoolOp nodes not supported yet '
+ '(need to preserve short-circuiting semantics).')
+ raise ValueError(msg)
+ return self.generic_visit(node)
+
+ def visit_BinOp(self, node):
+ return self._visit_strict_expression(node)
+
+ def visit_UnaryOp(self, node):
+ return self._visit_strict_expression(node)
+
+ def visit_Lambda(self, node):
+ if not self._are_children_trivial(node):
+ msg = ('Nontrivial Lambda nodes not supported '
+ '(cannot insert statements into lambda bodies).')
+ raise ValueError(msg)
+ return self.generic_visit(node)
+
+ def visit_IfExp(self, node):
+ if not self._are_children_trivial(node):
+ msg = ('Nontrivial IfExp nodes not supported yet '
+ '(need to convert to If statement, to evaluate branches lazily '
+ 'and insert statements into them).')
+ raise ValueError(msg)
+ return self.generic_visit(node)
+
+ def visit_Dict(self, node):
+ return self._visit_strict_expression(node)
+
+ def visit_Set(self, node):
+ return self._visit_strict_expression(node)
+
+ def visit_ListComp(self, node):
+ msg = ('ListComp nodes not supported '
+ '(need to convert to a form that tolerates '
+ 'assignment statements in clause bodies).')
+ raise ValueError(msg)
+
+ def visit_SetComp(self, node):
+ msg = ('SetComp nodes not supported '
+ '(need to convert to a form that tolerates '
+ 'assignment statements in clause bodies).')
+ raise ValueError(msg)
+
+ def visit_DictComp(self, node):
+ msg = ('DictComp nodes not supported '
+ '(need to convert to a form that tolerates '
+ 'assignment statements in clause bodies).')
+ raise ValueError(msg)
+
+ def visit_GeneratorExp(self, node):
+ msg = ('GeneratorExp nodes not supported '
+ '(need to convert to a form that tolerates '
+ 'assignment statements in clause bodies).')
+ raise ValueError(msg)
+
+ def visit_Await(self, node):
+ if not self._are_children_trivial(node):
+ msg = ('Nontrivial Await nodes not supported yet '
+ '(need to think through the semantics).')
+ raise ValueError(msg)
+ return self.generic_visit(node)
+
+ def visit_Yield(self, node):
+ return self._visit_strict_expression(node)
+
+ def visit_YieldFrom(self, node):
+ if not self._are_children_trivial(node):
+ msg = ('Nontrivial YieldFrom nodes not supported yet '
+ '(need to unit-test them in Python 2).')
+ raise ValueError(msg)
+ return self.generic_visit(node)
+
+ def visit_Compare(self, node):
+ if len(node.ops) > 1:
+ msg = ('Multi-ary compare nodes not supported yet '
+ '(need to preserve short-circuiting semantics).')
+ raise ValueError(msg)
+ return self._visit_strict_expression(node)
+
+ def visit_Call(self, node):
+ return self._visit_strict_expression(node)
+
+ def visit_Repr(self, node):
+ if not self._are_children_trivial(node):
+ msg = ('Nontrivial Repr nodes not supported yet '
+ '(need to research their syntax and semantics).')
+ raise ValueError(msg)
+ return self.generic_visit(node)
+
+ def visit_FormattedValue(self, node):
+ if not self._are_children_trivial(node):
+ msg = ('Nontrivial FormattedValue nodes not supported yet '
+ '(need to unit-test them in Python 2).')
+ raise ValueError(msg)
+ return self.generic_visit(node)
+
+ def visit_JoinedStr(self, node):
+ if not self._are_children_trivial(node):
+ msg = ('Nontrivial JoinedStr nodes not supported yet '
+ '(need to unit-test them in Python 2).')
+ raise ValueError(msg)
+ return self.generic_visit(node)
+
+ def visit_Attribute(self, node):
+ return self._visit_strict_expression(node)
+
+ def visit_Subscript(self, node):
+ return self._visit_strict_expression(node)
+
+ # Starred and Name are correct by default, because the right thing to do is to
+ # just recur.
+
+ def visit_List(self, node):
+ return self._visit_strict_expression(node)
+
+ def visit_Tuple(self, node):
+ return self._visit_strict_expression(node)
+
+
+def transform(node, entity_info, gensym_source=None):
+ """Converts the given node to A-normal form (ANF).
+
+ The general idea of A-normal form: https://en.wikipedia.org/wiki/A-normal_form
+ The specific converters used here are based on Python AST semantics as
+ documented at https://greentreesnakes.readthedocs.io/en/latest/.
-def transform(node, entity_info):
- return AnfTransformer(entity_info).visit(node)
+ Args:
+ node: The node to transform.
+ entity_info: transformer.EntityInfo. TODO(mdan): What information does this
+ argument provide?
+ gensym_source: An optional object with the same interface as `DummyGensym`
+ for generating unique names.
+ """
+ return AnfTransformer(entity_info, gensym_source=gensym_source).visit(node)
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py
index aefbc69d8c..951974820c 100644
--- a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py
+++ b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import textwrap
+
from tensorflow.contrib.autograph.pyct import compiler
from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import transformer
@@ -25,6 +27,22 @@ from tensorflow.contrib.autograph.pyct.common_transformers import anf
from tensorflow.python.platform import test
+class DummyGensym(object):
+ """A dumb gensym that suffixes a stem by sequential numbers from 1000."""
+
+ def __init__(self, entity_info):
+ del entity_info
+ # A proper implementation needs to account for:
+ # * entity_info.namespace
+ # * all the symbols defined in the AST
+ # * the symbols generated so far
+ self._idx = 0
+
+ def new_name(self, stem='tmp'):
+ self._idx += 1
+ return stem + '_' + str(1000 + self._idx)
+
+
class AnfTransformerTest(test.TestCase):
def _simple_source_info(self):
@@ -37,17 +55,349 @@ class AnfTransformerTest(test.TestCase):
owner_type=None)
def test_basic(self):
-
def test_function():
a = 0
return a
-
node, _ = parser.parse_entity(test_function)
node = anf.transform(node.body[0], self._simple_source_info())
result, _ = compiler.ast_to_object(node)
-
self.assertEqual(test_function(), result.test_function())
+ def assert_same_ast(self, expected_node, node, msg=None):
+ expected_source = compiler.ast_to_source(expected_node, indentation=' ')
+ expected_str = textwrap.dedent(expected_source).strip()
+ got_source = compiler.ast_to_source(node, indentation=' ')
+ got_str = textwrap.dedent(got_source).strip()
+ self.assertEqual(expected_str, got_str, msg=msg)
+
+ def assert_body_anfs_as_expected(self, expected_fn, test_fn):
+ # Testing the code bodies only. Wrapping them in functions so the
+ # syntax highlights nicely, but Python doesn't try to execute the
+ # statements.
+ exp_node, _ = parser.parse_entity(expected_fn)
+ node, _ = parser.parse_entity(test_fn)
+ node = anf.transform(
+ node, self._simple_source_info(), gensym_source=DummyGensym)
+ exp_name = exp_node.body[0].name
+ # Ignoring the function names in the result because they can't be
+ # the same (because both functions have to exist in the same scope
+ # at the same time).
+ node.body[0].name = exp_name
+ self.assert_same_ast(exp_node, node)
+ # Check that ANF is idempotent
+ node_repeated = anf.transform(
+ node, self._simple_source_info(), gensym_source=DummyGensym)
+ self.assert_same_ast(node_repeated, node)
+
+ def test_binop_basic(self):
+
+ def test_function(x, y, z):
+ a = x + y + z
+ return a
+
+ def expected_result(x, y, z):
+ tmp_1001 = x + y
+ a = tmp_1001 + z
+ return a
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_if_basic(self):
+
+ def test_function(a, b, c, e, f, g):
+ if a + b + c:
+ d = e + f + g
+ return d
+
+ def expected_result(a, b, c, e, f, g):
+ tmp_1001 = a + b
+ tmp_1002 = tmp_1001 + c
+ if tmp_1002:
+ tmp_1003 = e + f
+ d = tmp_1003 + g
+ return d
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_nested_binop_and_return(self):
+
+ def test_function(b, c, d, e):
+ return (2 * b + c) + (d + e)
+
+ def expected_result(b, c, d, e):
+ tmp_1001 = 2 * b
+ tmp_1002 = tmp_1001 + c
+ tmp_1003 = d + e
+ tmp_1004 = tmp_1002 + tmp_1003
+ return tmp_1004
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_function_call_and_expr(self):
+
+ def test_function(call_something, a, b, y, z, c, d, e, f, g, h, i):
+ call_something(a + b, y * z, kwarg=c + d, *(e + f), **(g + h + i))
+
+ def expected_result(call_something, a, b, y, z, c, d, e, f, g, h, i):
+ tmp_1001 = g + h
+ tmp_1002 = a + b
+ tmp_1003 = y * z
+ tmp_1004 = e + f
+ tmp_1005 = c + d
+ tmp_1006 = tmp_1001 + i
+ call_something(tmp_1002, tmp_1003, kwarg=tmp_1005, *tmp_1004, **tmp_1006)
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_with_and_print(self):
+
+ def test_function(a, b, c):
+ with a + b + c as d:
+ print(2 * d + 1)
+
+ def expected_result(a, b, c):
+ tmp_1001 = a + b
+ tmp_1002 = tmp_1001 + c
+ with tmp_1002 as d:
+ tmp_1003 = 2 * d
+ tmp_1004 = tmp_1003 + 1
+ print(tmp_1004)
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_local_definition_and_binary_compare(self):
+
+ def test_function():
+ def foo(a, b):
+ return 2 * a < b
+ return foo
+
+ def expected_result():
+ def foo(a, b):
+ tmp_1001 = 2 * a
+ tmp_1002 = tmp_1001 < b
+ return tmp_1002
+ return foo
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_list_literal(self):
+
+ def test_function(a, b, c, d, e, f):
+ return [a + b, c + d, e + f]
+
+ def expected_result(a, b, c, d, e, f):
+ tmp_1001 = a + b
+ tmp_1002 = c + d
+ tmp_1003 = e + f
+ tmp_1004 = [tmp_1001, tmp_1002, tmp_1003]
+ return tmp_1004
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_tuple_literal_and_unary(self):
+
+ def test_function(a, b, c, d, e, f):
+ return (a + b, -(c + d), e + f)
+
+ def expected_result(a, b, c, d, e, f):
+ tmp_1001 = c + d
+ tmp_1002 = a + b
+ tmp_1003 = -tmp_1001
+ tmp_1004 = e + f
+ tmp_1005 = (tmp_1002, tmp_1003, tmp_1004)
+ return tmp_1005
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_set_literal(self):
+
+ def test_function(a, b, c, d, e, f):
+ return set(a + b, c + d, e + f)
+
+ def expected_result(a, b, c, d, e, f):
+ tmp_1001 = a + b
+ tmp_1002 = c + d
+ tmp_1003 = e + f
+ tmp_1004 = set(tmp_1001, tmp_1002, tmp_1003)
+ return tmp_1004
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_dict_literal_and_repr(self):
+
+ def test_function(foo, bar, baz):
+ return repr({foo + bar + baz: 7 | 8})
+
+ def expected_result(foo, bar, baz):
+ tmp_1001 = foo + bar
+ tmp_1002 = tmp_1001 + baz
+ tmp_1003 = 7 | 8
+ tmp_1004 = {tmp_1002: tmp_1003}
+ tmp_1005 = repr(tmp_1004)
+ return tmp_1005
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_field_read_and_write(self):
+
+ def test_function(a, d):
+ a.b.c = d.e.f + 3
+
+ def expected_result(a, d):
+ tmp_1001 = a.b
+ tmp_1002 = d.e
+ tmp_1003 = tmp_1002.f
+ tmp_1001.c = tmp_1003 + 3
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_subscript_read_and_write(self):
+
+ def test_function(a, b, c, d, e, f):
+ a[b][c] = d[e][f] + 3
+
+ def expected_result(a, b, c, d, e, f):
+ tmp_1001 = a[b]
+ tmp_1002 = d[e]
+ tmp_1003 = tmp_1002[f]
+ tmp_1001[c] = tmp_1003 + 3
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_augassign_and_delete(self):
+
+ def test_function(a, x, y, z):
+ a += x + y + z
+ del a
+ del z[y][x]
+
+ def expected_result(a, x, y, z):
+ tmp_1001 = x + y
+ a += tmp_1001 + z
+ del a
+ tmp_1002 = z[y]
+ del tmp_1002[x]
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_raise_yield_and_raise(self):
+
+ def test_function(a, c, some_computed, exception):
+ yield a ** c
+ raise some_computed('complicated' + exception)
+
+ def expected_result(a, c, some_computed, exception):
+ tmp_1001 = a ** c
+ yield tmp_1001
+ tmp_1002 = 'complicated' + exception
+ tmp_1003 = some_computed(tmp_1002)
+ raise tmp_1003
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_with_and_if_with_expressions(self):
+
+ def test_function(foo, bar, function, quux, quozzle, w, x, y, z):
+ with foo + bar:
+ function(x + y)
+ if quux + quozzle:
+ function(z / w)
+
+ def expected_result(foo, bar, function, quux, quozzle, w, x, y, z):
+ tmp_1001 = foo + bar
+ with tmp_1001:
+ tmp_1002 = x + y
+ function(tmp_1002)
+ tmp_1003 = quux + quozzle
+ if tmp_1003:
+ tmp_1004 = z / w
+ function(tmp_1004)
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_exec(self):
+
+ def test_function():
+ # The point is to test A-normal form conversion of exec
+ # pylint: disable=exec-used
+ exec('computed' + 5 + 'stuff', globals(), locals())
+
+ def expected_result():
+ # pylint: disable=exec-used
+ tmp_1001 = 'computed' + 5
+ tmp_1002 = tmp_1001 + 'stuff'
+ tmp_1003 = globals()
+ tmp_1004 = locals()
+ exec(tmp_1002, tmp_1003, tmp_1004)
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_simple_while_and_assert(self):
+
+ def test_function(foo, quux):
+ while foo:
+ assert quux
+ foo = foo + 1 * 3
+
+ def expected_result(foo, quux):
+ while foo:
+ assert quux
+ tmp_1001 = 1 * 3
+ foo = foo + tmp_1001
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_for(self):
+
+ def test_function(compute, something, complicated, foo):
+ for foo in compute(something + complicated):
+ bar = foo + 1 * 3
+ return bar
+
+ def expected_result(compute, something, complicated, foo):
+ tmp_1001 = something + complicated
+ tmp_1002 = compute(tmp_1001)
+ for foo in tmp_1002:
+ tmp_1003 = 1 * 3
+ bar = foo + tmp_1003
+ return bar
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ # This test collects several examples where the definition of A-normal form
+ # implemented by this transformer is questionable. Mostly it's here to spell
+ # out what the definition is in these cases.
+ def test_controversial(self):
+
+ def test_function(b, c, d, f):
+ a = c + d
+ a.b = c + d
+ a[b] = c + d
+ a += c + d
+ a, b = c
+ a, b = c, d
+ a = f(c)
+ a = f(c + d)
+ a[b + d] = f.e(c + d)
+
+ def expected_result(b, c, d, f):
+ a = c + d
+ a.b = c + d # Should be a.b = tmp? (Definitely not tmp = c + d)
+ a[b] = c + d # Should be a[b] = tmp? (Definitely not tmp = c + d)
+ a += c + d # Should be a += tmp? (Definitely not tmp = c + d)
+ a, b = c # Should be a = c[0], b = c[1]? Or not?
+ a, b = c, d # Should be a = c, b = d? Or not?
+ a = f(c)
+ tmp_1001 = c + d
+ a = f(tmp_1001)
+ tmp_1002 = b + d
+ tmp_1003 = f.e
+ tmp_1004 = c + d
+ a[tmp_1002] = tmp_1003(tmp_1004) # Or should be a[tmp1] = tmp2?
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/autograph/utils/builtins.py b/tensorflow/contrib/autograph/utils/builtins.py
index 71079cfdc0..ccbe5fc954 100644
--- a/tensorflow/contrib/autograph/utils/builtins.py
+++ b/tensorflow/contrib/autograph/utils/builtins.py
@@ -27,6 +27,7 @@ from tensorflow.contrib.autograph.utils import type_check
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import list_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
@@ -50,7 +51,9 @@ def dynamic_builtin(f, *args, **kwargs):
def dynamic_len(list_or_tensor):
"""Implementation of len using dynamic dispatch."""
- if tensor_util.is_tensor(list_or_tensor):
+ if _is_tensor_list(list_or_tensor):
+ return list_ops.tensor_list_length(list_or_tensor)
+ elif tensor_util.is_tensor(list_or_tensor):
shape = list_or_tensor.shape
if not shape.ndims:
raise ValueError(
@@ -59,6 +62,11 @@ def dynamic_len(list_or_tensor):
return len(list_or_tensor)
+def _is_tensor_list(list_or_tensor):
+ return (tensor_util.is_tensor(list_or_tensor)
+ and list_or_tensor.dtype == dtypes.variant)
+
+
def dynamic_int(num_or_tensor, **kwargs):
"""Implementation of int() using dynamic dispatch."""
if tensor_util.is_tensor(num_or_tensor):
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD
index ef0e80cd09..f4a375328e 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD
@@ -147,6 +147,7 @@ py_library(
deps = [
":distillation_loss",
":estimator_utils",
+ ":model",
":trainer_hooks",
"//tensorflow/contrib/boosted_trees:gbdt_batch",
"//tensorflow/contrib/boosted_trees:model_ops_py",
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py
index 7eb429b636..dbfa69edcb 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py
@@ -26,6 +26,7 @@ from __future__ import print_function
import six
from tensorflow.contrib import layers
+from tensorflow.contrib.boosted_trees.estimator_batch import model
from tensorflow.contrib.boosted_trees.estimator_batch import distillation_loss
from tensorflow.contrib.boosted_trees.estimator_batch import estimator_utils
from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks
@@ -34,6 +35,7 @@ from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batc
from tensorflow.contrib.layers.python.layers import optimizers
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
+from tensorflow.python.estimator import estimator as core_estimator
from tensorflow.contrib.learn.python.learn.estimators import model_fn
from tensorflow.python.feature_column import feature_column as feature_column_lib
from tensorflow.python.framework import ops
@@ -62,27 +64,29 @@ def _add_hidden_layer_summary(value, tag):
summary.histogram("%s_activation" % tag, value)
-def _dnn_tree_combined_model_fn(features,
- labels,
- mode,
- head,
- dnn_hidden_units,
- dnn_feature_columns,
- tree_learner_config,
- num_trees,
- tree_examples_per_layer,
- config=None,
- dnn_optimizer="Adagrad",
- dnn_activation_fn=nn.relu,
- dnn_dropout=None,
- dnn_input_layer_partitioner=None,
- dnn_input_layer_to_tree=True,
- dnn_steps_to_train=10000,
- predict_with_tree_only=False,
- tree_feature_columns=None,
- tree_center_bias=False,
- dnn_to_tree_distillation_param=None,
- use_core_versions=False):
+def _dnn_tree_combined_model_fn(
+ features,
+ labels,
+ mode,
+ head,
+ dnn_hidden_units,
+ dnn_feature_columns,
+ tree_learner_config,
+ num_trees,
+ tree_examples_per_layer,
+ config=None,
+ dnn_optimizer="Adagrad",
+ dnn_activation_fn=nn.relu,
+ dnn_dropout=None,
+ dnn_input_layer_partitioner=None,
+ dnn_input_layer_to_tree=True,
+ dnn_steps_to_train=10000,
+ predict_with_tree_only=False,
+ tree_feature_columns=None,
+ tree_center_bias=False,
+ dnn_to_tree_distillation_param=None,
+ use_core_versions=False,
+ output_type=model.ModelBuilderOutputType.MODEL_FN_OPS):
"""DNN and GBDT combined model_fn.
Args:
@@ -156,6 +160,10 @@ def _dnn_tree_combined_model_fn(features,
partitioned_variables.min_max_variable_partitioner(
max_partitions=config.num_ps_replicas, min_slice_size=64 << 20))
+ if (output_type == model.ModelBuilderOutputType.ESTIMATOR_SPEC and
+ not use_core_versions):
+ raise ValueError("You must use core versions with Estimator Spec")
+
with variable_scope.variable_scope(
dnn_parent_scope,
values=tuple(six.itervalues(features)),
@@ -235,7 +243,8 @@ def _dnn_tree_combined_model_fn(features,
learner_config=tree_learner_config,
feature_columns=tree_feature_columns,
logits_dimension=head.logits_dimension,
- features=tree_features)
+ features=tree_features,
+ use_core_columns=use_core_versions)
with ops.name_scope("gbdt"):
predictions_dict = gbdt_model.predict(mode)
@@ -284,63 +293,96 @@ def _dnn_tree_combined_model_fn(features,
del loss
return control_flow_ops.no_op()
- if use_core_versions:
- model_fn_ops = head.create_estimator_spec(
- features=features,
- mode=mode,
- labels=labels,
- train_op_fn=_no_train_op_fn,
- logits=tree_train_logits)
- dnn_train_op = head.create_estimator_spec(
- features=features,
- mode=mode,
- labels=labels,
- train_op_fn=_dnn_train_op_fn,
- logits=dnn_logits)
- dnn_train_op = estimator_utils.estimator_spec_to_model_fn_ops(
- dnn_train_op).train_op
+ if tree_center_bias:
+ num_trees += 1
+ finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor()
- tree_train_op = head.create_estimator_spec(
- features=tree_features,
- mode=mode,
- labels=labels,
- train_op_fn=_tree_train_op_fn,
- logits=tree_train_logits)
- tree_train_op = estimator_utils.estimator_spec_to_model_fn_ops(
- tree_train_op).train_op
+ if output_type == model.ModelBuilderOutputType.MODEL_FN_OPS:
+ if use_core_versions:
+ model_fn_ops = head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_no_train_op_fn,
+ logits=tree_train_logits)
+ dnn_train_op = head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_dnn_train_op_fn,
+ logits=dnn_logits)
+ dnn_train_op = estimator_utils.estimator_spec_to_model_fn_ops(
+ dnn_train_op).train_op
- model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(model_fn_ops)
- else:
- model_fn_ops = head.create_model_fn_ops(
+ tree_train_op = head.create_estimator_spec(
+ features=tree_features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_tree_train_op_fn,
+ logits=tree_train_logits)
+ tree_train_op = estimator_utils.estimator_spec_to_model_fn_ops(
+ tree_train_op).train_op
+
+ model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(
+ model_fn_ops)
+ else:
+ model_fn_ops = head.create_model_fn_ops(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_no_train_op_fn,
+ logits=tree_train_logits)
+ dnn_train_op = head.create_model_fn_ops(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_dnn_train_op_fn,
+ logits=dnn_logits).train_op
+ tree_train_op = head.create_model_fn_ops(
+ features=tree_features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_tree_train_op_fn,
+ logits=tree_train_logits).train_op
+
+ # Add the hooks
+ model_fn_ops.training_hooks.extend([
+ trainer_hooks.SwitchTrainOp(dnn_train_op, dnn_steps_to_train,
+ tree_train_op),
+ trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
+ finalized_trees)
+ ])
+ return model_fn_ops
+
+ elif output_type == model.ModelBuilderOutputType.ESTIMATOR_SPEC:
+ fusion_spec = head.create_estimator_spec(
features=features,
mode=mode,
labels=labels,
train_op_fn=_no_train_op_fn,
logits=tree_train_logits)
- dnn_train_op = head.create_model_fn_ops(
+ dnn_spec = head.create_estimator_spec(
features=features,
mode=mode,
labels=labels,
train_op_fn=_dnn_train_op_fn,
- logits=dnn_logits).train_op
- tree_train_op = head.create_model_fn_ops(
+ logits=dnn_logits)
+ tree_spec = head.create_estimator_spec(
features=tree_features,
mode=mode,
labels=labels,
train_op_fn=_tree_train_op_fn,
- logits=tree_train_logits).train_op
-
- if tree_center_bias:
- num_trees += 1
- finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor()
-
- model_fn_ops.training_hooks.extend([
- trainer_hooks.SwitchTrainOp(dnn_train_op, dnn_steps_to_train,
- tree_train_op),
- trainer_hooks.StopAfterNTrees(num_trees, attempted_trees, finalized_trees)
- ])
+ logits=tree_train_logits)
- return model_fn_ops
+ training_hooks = [
+ trainer_hooks.SwitchTrainOp(dnn_spec.train_op, dnn_steps_to_train,
+ tree_spec.train_op),
+ trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
+ finalized_trees)
+ ]
+ fusion_spec = fusion_spec._replace(training_hooks=training_hooks +
+ list(fusion_spec.training_hooks))
+ return fusion_spec
class DNNBoostedTreeCombinedClassifier(estimator.Estimator):
@@ -697,3 +739,100 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator):
model_dir=model_dir,
config=config,
feature_engineering_fn=feature_engineering_fn)
+
+
+class CoreDNNBoostedTreeCombinedEstimator(core_estimator.Estimator):
+ """Initializes a core version of DNNBoostedTreeCombinedEstimator.
+
+ Args:
+ dnn_hidden_units: List of hidden units per layer for DNN.
+ dnn_feature_columns: An iterable containing all the feature columns
+ used by the model's DNN.
+ tree_learner_config: A config for the tree learner.
+ num_trees: Number of trees to grow model to after training DNN.
+ tree_examples_per_layer: Number of examples to accumulate before
+ growing the tree a layer. This value has a big impact on model
+ quality and should be set equal to the number of examples in
+ training dataset if possible. It can also be a function that computes
+ the number of examples based on the depth of the layer that's
+ being built.
+ head: `Head` instance.
+ model_dir: Directory for model exports.
+ config: `RunConfig` of the estimator.
+ dnn_optimizer: string, `Optimizer` object, or callable that defines the
+ optimizer to use for training the DNN. If `None`, will use the Adagrad
+ optimizer with default learning rate.
+ dnn_activation_fn: Activation function applied to each layer of the DNN.
+ If `None`, will use `tf.nn.relu`.
+ dnn_dropout: When not `None`, the probability to drop out a given
+ unit in the DNN.
+ dnn_input_layer_partitioner: Partitioner for input layer of the DNN.
+ Defaults to `min_max_variable_partitioner` with `min_slice_size`
+ 64 << 20.
+ dnn_input_layer_to_tree: Whether to provide the DNN's input layer
+ as a feature to the tree.
+ dnn_steps_to_train: Number of steps to train dnn for before switching
+ to gbdt.
+ predict_with_tree_only: Whether to use only the tree model output as the
+ final prediction.
+ tree_feature_columns: An iterable containing all the feature columns
+ used by the model's boosted trees. If dnn_input_layer_to_tree is
+ set to True, these features are in addition to dnn_feature_columns.
+ tree_center_bias: Whether a separate tree should be created for
+ first fitting the bias.
+ dnn_to_tree_distillation_param: A Tuple of (float, loss_fn), where the
+ float defines the weight of the distillation loss, and the loss_fn, for
+ computing distillation loss, takes dnn_logits, tree_logits and weight
+ tensor. If the entire tuple is None, no distillation will be applied. If
+ only the loss_fn is None, we will take the sigmoid/softmax cross entropy
+ loss be default. When distillation is applied, `predict_with_tree_only`
+ will be set to True.
+ """
+
+ def __init__(self,
+ dnn_hidden_units,
+ dnn_feature_columns,
+ tree_learner_config,
+ num_trees,
+ tree_examples_per_layer,
+ head,
+ model_dir=None,
+ config=None,
+ dnn_optimizer="Adagrad",
+ dnn_activation_fn=nn.relu,
+ dnn_dropout=None,
+ dnn_input_layer_partitioner=None,
+ dnn_input_layer_to_tree=True,
+ dnn_steps_to_train=10000,
+ predict_with_tree_only=False,
+ tree_feature_columns=None,
+ tree_center_bias=False,
+ dnn_to_tree_distillation_param=None):
+
+ def _model_fn(features, labels, mode, config):
+ return _dnn_tree_combined_model_fn(
+ features=features,
+ labels=labels,
+ mode=mode,
+ head=head,
+ dnn_hidden_units=dnn_hidden_units,
+ dnn_feature_columns=dnn_feature_columns,
+ tree_learner_config=tree_learner_config,
+ num_trees=num_trees,
+ tree_examples_per_layer=tree_examples_per_layer,
+ config=config,
+ dnn_optimizer=dnn_optimizer,
+ dnn_activation_fn=dnn_activation_fn,
+ dnn_dropout=dnn_dropout,
+ dnn_input_layer_partitioner=dnn_input_layer_partitioner,
+ dnn_input_layer_to_tree=dnn_input_layer_to_tree,
+ dnn_steps_to_train=dnn_steps_to_train,
+ predict_with_tree_only=predict_with_tree_only,
+ tree_feature_columns=tree_feature_columns,
+ tree_center_bias=tree_center_bias,
+ dnn_to_tree_distillation_param=dnn_to_tree_distillation_param,
+ output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC,
+ use_core_versions=True)
+
+ super(CoreDNNBoostedTreeCombinedEstimator, self).__init__(
+ model_fn=_model_fn, model_dir=model_dir, config=config)
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py
index 9b7acfa664..839eedd3a8 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py
@@ -28,10 +28,11 @@ from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.feature_column import feature_column_lib as core_feature_column
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import googletest
-
+from tensorflow.python.training import checkpoint_utils
def _train_input_fn():
features = {
@@ -156,5 +157,72 @@ class DNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase):
classifier.evaluate(input_fn=_eval_input_fn, steps=1)
+class CoreDNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase):
+
+ def _assert_checkpoint(self, model_dir, global_step):
+ reader = checkpoint_utils.load_checkpoint(model_dir)
+ self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP))
+
+ def testTrainEvaluateInferDoesNotThrowErrorWithNoDnnInput(self):
+ head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+ learner_config.constraints.max_tree_depth = 3
+ model_dir = tempfile.mkdtemp()
+ config = run_config.RunConfig()
+
+ est = estimator.CoreDNNBoostedTreeCombinedEstimator(
+ head=head_fn,
+ dnn_hidden_units=[1],
+ dnn_feature_columns=[core_feature_column.numeric_column("x")],
+ tree_learner_config=learner_config,
+ num_trees=1,
+ tree_examples_per_layer=3,
+ model_dir=model_dir,
+ config=config,
+ dnn_steps_to_train=10,
+ dnn_input_layer_to_tree=False,
+ tree_feature_columns=[core_feature_column.numeric_column("x")])
+
+ # Train for a few steps.
+ est.train(input_fn=_train_input_fn, steps=1000)
+ # 10 steps for dnn, 3 for 1 tree of depth 3 + 1 after the tree finished
+ self._assert_checkpoint(est.model_dir, global_step=14)
+ res = est.evaluate(input_fn=_eval_input_fn, steps=1)
+ self.assertLess(0.5, res["auc"])
+ est.predict(input_fn=_eval_input_fn)
+
+ def testTrainEvaluateInferDoesNotThrowErrorWithDnnInput(self):
+ head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+ learner_config.constraints.max_tree_depth = 3
+ model_dir = tempfile.mkdtemp()
+ config = run_config.RunConfig()
+
+ est = estimator.CoreDNNBoostedTreeCombinedEstimator(
+ head=head_fn,
+ dnn_hidden_units=[1],
+ dnn_feature_columns=[core_feature_column.numeric_column("x")],
+ tree_learner_config=learner_config,
+ num_trees=1,
+ tree_examples_per_layer=3,
+ model_dir=model_dir,
+ config=config,
+ dnn_steps_to_train=10,
+ dnn_input_layer_to_tree=True,
+ tree_feature_columns=[])
+
+ # Train for a few steps.
+ est.train(input_fn=_train_input_fn, steps=1000)
+ res = est.evaluate(input_fn=_eval_input_fn, steps=1)
+ self.assertLess(0.5, res["auc"])
+ est.predict(input_fn=_eval_input_fn)
+
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
index 38fa8c3834..2df879f924 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
@@ -26,6 +26,12 @@ from tensorflow.python.estimator import estimator as core_estimator
from tensorflow.python.ops import math_ops
+# ================== Old estimator interface===================================
+# The estimators below were designed for old feature columns and old estimator
+# interface. They can be used with new feature columns and losses by setting
+# use_core_libs = True.
+
+
class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
"""An estimator using gradient boosted decision trees."""
@@ -356,9 +362,16 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator):
config=config,
feature_engineering_fn=feature_engineering_fn)
+# ================== New Estimator interface===================================
+# The estimators below use new core Estimator interface and must be used with
+# new feature columns and heads.
+
class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator):
- """An estimator using gradient boosted decision trees."""
+ """An estimator using gradient boosted decision trees.
+
+ Useful for training with user specified `Head`.
+ """
def __init__(self,
learner_config,
@@ -374,6 +387,36 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator):
logits_modifier_function=None,
center_bias=True,
output_leaf_index=False):
+ """Initializes a core version of GradientBoostedDecisionTreeEstimator.
+
+ Args:
+ learner_config: A config for the learner.
+ examples_per_layer: Number of examples to accumulate before growing a
+ layer. It can also be a function that computes the number of examples
+ based on the depth of the layer that's being built.
+ head: `Head` instance.
+ num_trees: An int, number of trees to build.
+ feature_columns: A list of feature columns.
+ weight_column_name: Name of the column for weights, or None if not
+ weighted.
+ model_dir: Directory for model exports, etc.
+ config: `RunConfig` object to configure the runtime settings.
+ label_keys: Optional list of strings with size `[n_classes]` defining the
+ label vocabulary. Only supported for `n_classes` > 2.
+ feature_engineering_fn: Feature engineering function. Takes features and
+ labels which are the output of `input_fn` and returns features and
+ labels which will be fed into the model.
+ logits_modifier_function: A modifier function for the logits.
+ center_bias: Whether a separate tree should be created for first fitting
+ the bias.
+ output_leaf_index: whether to output leaf indices along with predictions
+ during inference. The leaf node indexes are available in predictions
+ dict by the key 'leaf_index'. For example,
+ result_dict = classifier.predict(...)
+ for example_prediction_result in result_dict:
+ # access leaf index list by example_prediction_result["leaf_index"]
+ # which contains one leaf index per tree
+ """
def _model_fn(features, labels, mode, config):
return model.model_builder(
@@ -397,3 +440,87 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator):
super(CoreGradientBoostedDecisionTreeEstimator, self).__init__(
model_fn=_model_fn, model_dir=model_dir, config=config)
+
+
+class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator):
+ """A ranking estimator using gradient boosted decision trees."""
+
+ def __init__(
+ self,
+ learner_config,
+ examples_per_layer,
+ head,
+ ranking_model_pair_keys,
+ num_trees=None,
+ feature_columns=None,
+ weight_column_name=None,
+ model_dir=None,
+ config=None,
+ label_keys=None,
+ logits_modifier_function=None,
+ center_bias=False,
+ output_leaf_index=False,
+ ):
+ """Initializes a GradientBoostedDecisionTreeRanker instance.
+
+ This is an estimator that can be trained off the pairwise data and can be
+ used for inference on non-paired data. This is essentially LambdaMart.
+ Args:
+ learner_config: A config for the learner.
+ examples_per_layer: Number of examples to accumulate before growing a
+ layer. It can also be a function that computes the number of examples
+ based on the depth of the layer that's being built.
+ head: `Head` instance.
+ ranking_model_pair_keys: Keys to distinguish between features
+ for left and right part of the training pairs for ranking. For example,
+ for an Example with features "a.f1" and "b.f1", the keys would be
+ ("a", "b").
+ num_trees: An int, number of trees to build.
+ feature_columns: A list of feature columns.
+ weight_column_name: Name of the column for weights, or None if not
+ weighted.
+ model_dir: Directory for model exports, etc.
+ config: `RunConfig` object to configure the runtime settings.
+ label_keys: Optional list of strings with size `[n_classes]` defining the
+ label vocabulary. Only supported for `n_classes` > 2.
+ logits_modifier_function: A modifier function for the logits.
+ center_bias: Whether a separate tree should be created for first fitting
+ the bias.
+ output_leaf_index: whether to output leaf indices along with predictions
+ during inference. The leaf node indexes are available in predictions
+ dict by the key 'leaf_index'. It is a Tensor of rank 2 and its shape is
+ [batch_size, num_trees].
+ For example,
+ result_iter = classifier.predict(...)
+ for result_dict in result_iter:
+ # access leaf index list by result_dict["leaf_index"]
+ # which contains one leaf index per tree
+
+ Raises:
+ ValueError: If learner_config is not valid.
+ """
+
+ def _model_fn(features, labels, mode, config):
+ return model.ranking_model_builder(
+ features=features,
+ labels=labels,
+ mode=mode,
+ config=config,
+ params={
+ 'head': head,
+ 'n_classes': 2,
+ 'feature_columns': feature_columns,
+ 'learner_config': learner_config,
+ 'num_trees': num_trees,
+ 'weight_column_name': weight_column_name,
+ 'examples_per_layer': examples_per_layer,
+ 'center_bias': center_bias,
+ 'logits_modifier_function': logits_modifier_function,
+ 'use_core_libs': True,
+ 'output_leaf_index': output_leaf_index,
+ 'ranking_model_pair_keys': ranking_model_pair_keys,
+ },
+ output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC)
+
+ super(CoreGradientBoostedDecisionTreeRanker, self).__init__(
+ model_fn=_model_fn, model_dir=model_dir, config=config)
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
index f787d3cdb8..9e9febbbef 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
@@ -203,7 +203,7 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase):
model.predict(input_fn=_infer_ranking_train_input_fn)
-class CoreGradientBoostedDecisionTreeEstimator(test_util.TensorFlowTestCase):
+class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase):
def testTrainEvaluateInferDoesNotThrowError(self):
head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
@@ -229,6 +229,34 @@ class CoreGradientBoostedDecisionTreeEstimator(test_util.TensorFlowTestCase):
est.evaluate(input_fn=_eval_input_fn, steps=1)
est.predict(input_fn=_eval_input_fn)
+ def testRankingDontThrowExceptionForForEstimator(self):
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+ learner_config.constraints.max_tree_depth = 1
+ model_dir = tempfile.mkdtemp()
+ config = run_config.RunConfig()
+
+ head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+
+ est = estimator.CoreGradientBoostedDecisionTreeRanker(
+ head=head_fn,
+ learner_config=learner_config,
+ num_trees=1,
+ examples_per_layer=3,
+ model_dir=model_dir,
+ config=config,
+ feature_columns=[
+ core_feature_column.numeric_column("f1"),
+ core_feature_column.numeric_column("f2")
+ ],
+ ranking_model_pair_keys=("a", "b"))
+
+ # Train for a few steps.
+ est.train(input_fn=_ranking_train_input_fn, steps=1000)
+ est.evaluate(input_fn=_ranking_train_input_fn, steps=1)
+ est.predict(input_fn=_infer_ranking_train_input_fn)
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
index 2fbe72951a..161cc42cb0 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
@@ -59,6 +59,8 @@ def model_builder(features,
* center_bias: Whether a separate tree should be created for first fitting
the bias.
config: `RunConfig` of the estimator.
+ output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec
+ (new interface).
Returns:
A `ModelFnOps` object.
@@ -126,14 +128,15 @@ def model_builder(features,
create_estimator_spec_op = getattr(head, "create_estimator_spec", None)
+ training_hooks = []
if num_trees:
if center_bias:
num_trees += 1
+
finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor()
- training_hooks = [
+ training_hooks.append(
trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
- finalized_trees)
- ]
+ finalized_trees))
if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
if use_core_libs and callable(create_estimator_spec_op):
@@ -175,7 +178,12 @@ def model_builder(features,
return model_fn_ops
-def ranking_model_builder(features, labels, mode, params, config):
+def ranking_model_builder(features,
+ labels,
+ mode,
+ params,
+ config,
+ output_type=ModelBuilderOutputType.MODEL_FN_OPS):
"""Multi-machine batch gradient descent tree model for ranking.
Args:
@@ -199,6 +207,9 @@ def ranking_model_builder(features, labels, mode, params, config):
for an Example with features "a.f1" and "b.f1", the keys would be
("a", "b").
config: `RunConfig` of the estimator.
+ output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec
+ (new interface).
+
Returns:
A `ModelFnOps` object.
@@ -326,31 +337,54 @@ def ranking_model_builder(features, labels, mode, params, config):
return update_op
create_estimator_spec_op = getattr(head, "create_estimator_spec", None)
- if use_core_libs and callable(create_estimator_spec_op):
- model_fn_ops = head.create_estimator_spec(
- features=features,
- mode=mode,
- labels=labels,
- train_op_fn=_train_op_fn,
- logits=logits)
- model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(model_fn_ops)
- else:
- model_fn_ops = head.create_model_fn_ops(
- features=features,
- mode=mode,
- labels=labels,
- train_op_fn=_train_op_fn,
- logits=logits)
- if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict:
- model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[
- gbdt_batch.LEAF_INDEX]
+ training_hooks = []
if num_trees:
if center_bias:
num_trees += 1
+
finalized_trees, attempted_trees = (
gbdt_model_main.get_number_of_trees_tensor())
- model_fn_ops.training_hooks.append(
+ training_hooks.append(
trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
finalized_trees))
+
+ if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
+ if use_core_libs and callable(create_estimator_spec_op):
+ model_fn_ops = head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_train_op_fn,
+ logits=logits)
+ model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(
+ model_fn_ops)
+ else:
+ model_fn_ops = head.create_model_fn_ops(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_train_op_fn,
+ logits=logits)
+
+ if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict:
+ model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[
+ gbdt_batch.LEAF_INDEX]
+
+ model_fn_ops.training_hooks.extend(training_hooks)
+ return model_fn_ops
+
+ elif output_type == ModelBuilderOutputType.ESTIMATOR_SPEC:
+ assert callable(create_estimator_spec_op)
+ estimator_spec = head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_train_op_fn,
+ logits=logits)
+
+ estimator_spec = estimator_spec._replace(
+ training_hooks=training_hooks + list(estimator_spec.training_hooks))
+ return estimator_spec
+
return model_fn_ops
diff --git a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc
index 1bfeed3066..6d9a6ee5a0 100644
--- a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc
@@ -372,12 +372,18 @@ class GrowTreeEnsembleOp : public OpKernel {
return;
}
+ // Get the max tree depth.
+ const Tensor* max_tree_depth_t;
+ OP_REQUIRES_OK(context,
+ context->input("max_tree_depth", &max_tree_depth_t));
+ const int32 max_tree_depth = max_tree_depth_t->scalar<int32>()();
+
// Update and retrieve the growable tree.
// If the tree is fully built and dropout was applied, it also adjusts the
// weights of dropped and the last tree.
boosted_trees::trees::DecisionTreeConfig* const tree_config =
UpdateAndRetrieveGrowableTree(ensemble_resource, learning_rate,
- dropout_seed);
+ dropout_seed, max_tree_depth);
// Split tree nodes.
for (auto& split_entry : best_splits) {
@@ -494,7 +500,8 @@ class GrowTreeEnsembleOp : public OpKernel {
boosted_trees::trees::DecisionTreeConfig* UpdateAndRetrieveGrowableTree(
boosted_trees::models::DecisionTreeEnsembleResource* const
ensemble_resource,
- const float learning_rate, const uint64 dropout_seed) {
+ const float learning_rate, const uint64 dropout_seed,
+ const int32 max_tree_depth) {
const auto num_trees = ensemble_resource->num_trees();
if (num_trees <= 0 ||
ensemble_resource->LastTreeMetadata()->is_finalized()) {
@@ -506,8 +513,7 @@ class GrowTreeEnsembleOp : public OpKernel {
tree_config->add_nodes()->mutable_leaf();
boosted_trees::trees::DecisionTreeMetadata* const tree_metadata =
ensemble_resource->LastTreeMetadata();
- tree_metadata->set_is_finalized(
- learner_config_.constraints().max_tree_depth() <= 1);
+ tree_metadata->set_is_finalized(max_tree_depth <= 1);
tree_metadata->set_num_tree_weight_updates(1);
} else {
// The growable tree is by definition the last tree in the ensemble.
@@ -518,8 +524,7 @@ class GrowTreeEnsembleOp : public OpKernel {
<< num_trees - 1 << " of ensemble of " << num_trees << " trees.";
// Update growable tree metadata.
tree_metadata->set_num_layers_grown(new_num_layers);
- tree_metadata->set_is_finalized(
- new_num_layers >= learner_config_.constraints().max_tree_depth());
+ tree_metadata->set_is_finalized(new_num_layers >= max_tree_depth);
}
UpdateTreeWeightsIfDropout(ensemble_resource, dropout_seed);
return ensemble_resource->LastTree();
diff --git a/tensorflow/contrib/boosted_trees/ops/training_ops.cc b/tensorflow/contrib/boosted_trees/ops/training_ops.cc
index f63c199ad6..22ac9edb72 100644
--- a/tensorflow/contrib/boosted_trees/ops/training_ops.cc
+++ b/tensorflow/contrib/boosted_trees/ops/training_ops.cc
@@ -56,6 +56,7 @@ REGISTER_OP("GrowTreeEnsemble")
.Input("next_stamp_token: int64")
.Input("learning_rate: float")
.Input("dropout_seed: int64")
+ .Input("max_tree_depth: int32")
.Input("partition_ids: num_handlers * int32")
.Input("gains: num_handlers * float")
.Input("splits: num_handlers * string")
@@ -67,6 +68,8 @@ REGISTER_OP("GrowTreeEnsemble")
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_input));
// Dropout seed.
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_input));
+ // Maximum tree depth.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused_input));
return Status::OK();
})
.Doc(R"doc(
diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py
index 3e524efbea..e39e1de8d1 100644
--- a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py
+++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py
@@ -296,7 +296,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE,
# Dropout does not change anything here, tree is not finalized.
- dropout_probability=0.5).SerializeToString()
+ dropout_probability=0.5)
# Prepare handler inputs.
# Note that handlers 1 & 3 have the same gain but different splits.
@@ -321,9 +321,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
],
gains=[handler1_gains, handler2_gains, handler3_gains],
splits=[handler1_split, handler2_split, handler3_split],
- learner_config=learner_config,
+ learner_config=learner_config.SerializeToString(),
dropout_seed=123,
- center_bias=True)
+ center_bias=True,
+ max_tree_depth=learner_config.constraints.max_tree_depth)
session.run(grow_op)
# Expect the simpler split from handler 1 to be chosen.
@@ -443,7 +444,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE,
# Dropout does not change anything here - tree is not finalized.
- dropout_probability=0.5).SerializeToString()
+ dropout_probability=0.5)
# Prepare handler inputs.
# Handler 1 only has a candidate for partition 1, handler 2 has candidates
@@ -472,9 +473,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
],
gains=[handler1_gains, handler2_gains, handler3_gains],
splits=[handler1_split, handler2_split, handler3_split],
- learner_config=learner_config,
+ learner_config=learner_config.SerializeToString(),
dropout_seed=123,
- center_bias=True)
+ center_bias=True,
+ max_tree_depth=learner_config.constraints.max_tree_depth)
session.run(grow_op)
# Expect the split for partition 1 to be chosen from handler 1 and
@@ -632,8 +634,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
max_depth=1,
min_node_weight=0,
pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
- growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString(
- )
+ growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE)
# Prepare handler inputs.
handler1_partitions = np.array([0], dtype=np.int32)
@@ -657,9 +658,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
],
gains=[handler1_gains, handler2_gains, handler3_gains],
splits=[handler1_split, handler2_split, handler3_split],
- learner_config=learner_config,
+ learner_config=learner_config.SerializeToString(),
dropout_seed=123,
- center_bias=True)
+ center_bias=True,
+ max_tree_depth=learner_config.constraints.max_tree_depth)
session.run(grow_op)
# Expect a new tree to be added with the split from handler 1.
@@ -773,8 +775,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
max_depth=1,
min_node_weight=0,
pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
- growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString(
- )
+ growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE)
# Prepare handler inputs.
# All handlers have negative gain.
@@ -794,9 +795,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
partition_ids=[handler1_partitions, handler2_partitions],
gains=[handler1_gains, handler2_gains],
splits=[handler1_split, handler2_split],
- learner_config=learner_config,
+ learner_config=learner_config.SerializeToString(),
dropout_seed=123,
- center_bias=True)
+ center_bias=True,
+ max_tree_depth=learner_config.constraints.max_tree_depth)
session.run(grow_op)
# Expect the ensemble to be empty.
@@ -839,8 +841,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
max_depth=1,
min_node_weight=0,
pruning_mode=learner_pb2.LearnerConfig.POST_PRUNE,
- growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString(
- )
+ growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE)
# Prepare handler inputs.
# Note that handlers 1 & 3 have the same gain but different splits.
@@ -865,9 +866,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
],
gains=[handler1_gains, handler2_gains, handler3_gains],
splits=[handler1_split, handler2_split, handler3_split],
- learner_config=learner_config,
+ learner_config=learner_config.SerializeToString(),
dropout_seed=123,
- center_bias=True)
+ center_bias=True,
+ max_tree_depth=learner_config.constraints.max_tree_depth)
session.run(grow_op)
# Expect the simpler split from handler 1 to be chosen.
@@ -946,8 +948,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
max_depth=2,
min_node_weight=0,
pruning_mode=learner_pb2.LearnerConfig.POST_PRUNE,
- growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString(
- )
+ growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE)
# Prepare handler inputs.
# All handlers have negative gain.
@@ -967,9 +968,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
partition_ids=[handler1_partitions, handler2_partitions],
gains=[handler1_gains, handler2_gains],
splits=[handler1_split, handler2_split],
- learner_config=learner_config,
+ learner_config=learner_config.SerializeToString(),
dropout_seed=123,
- center_bias=True)
+ center_bias=True,
+ max_tree_depth=learner_config.constraints.max_tree_depth)
session.run(grow_op)
# Expect the split from handler 2 to be chosen despite the negative gain.
@@ -1048,9 +1050,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
partition_ids=[handler1_partitions],
gains=[handler1_gains],
splits=[handler1_split],
- learner_config=learner_config,
+ learner_config=learner_config.SerializeToString(),
dropout_seed=123,
- center_bias=True)
+ center_bias=True,
+ max_tree_depth=learner_config.constraints.max_tree_depth)
session.run(grow_op)
# Expect the ensemble to be empty as post-pruning will prune
@@ -1094,8 +1097,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
max_depth=2,
min_node_weight=0,
pruning_mode=learner_pb2.LearnerConfig.POST_PRUNE,
- growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString(
- )
+ growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE)
# Prepare handler inputs.
# Second handler has positive gain.
@@ -1115,9 +1117,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
partition_ids=[handler1_partitions, handler2_partitions],
gains=[handler1_gains, handler2_gains],
splits=[handler1_split, handler2_split],
- learner_config=learner_config,
+ learner_config=learner_config.SerializeToString(),
dropout_seed=123,
- center_bias=True)
+ center_bias=True,
+ max_tree_depth=learner_config.constraints.max_tree_depth)
session.run(grow_op)
# Expect the split from handler 2 to be chosen despite the negative gain.
@@ -1194,9 +1197,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
partition_ids=[handler1_partitions],
gains=[handler1_gains],
splits=[handler1_split],
- learner_config=learner_config,
+ learner_config=learner_config.SerializeToString(),
dropout_seed=123,
- center_bias=True)
+ center_bias=True,
+ max_tree_depth=learner_config.constraints.max_tree_depth)
session.run(grow_op)
# Expect the negative gain split of partition 1 to be pruned and the
@@ -1335,7 +1339,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
growing_mode=learner_pb2.LearnerConfig.LAYER_BY_LAYER,
# Dropout will have no effect, since the tree will not be fully grown.
- dropout_probability=1.0).SerializeToString()
+ dropout_probability=1.0)
# Prepare handler inputs.
# Handler 1 only has a candidate for partition 1, handler 2 has candidates
@@ -1364,9 +1368,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
],
gains=[handler1_gains, handler2_gains, handler3_gains],
splits=[handler1_split, handler2_split, handler3_split],
- learner_config=learner_config,
+ learner_config=learner_config.SerializeToString(),
dropout_seed=123,
- center_bias=True)
+ center_bias=True,
+ max_tree_depth=learner_config.constraints.max_tree_depth)
session.run(grow_op)
# Expect the split for partition 1 to be chosen from handler 1 and
@@ -1543,7 +1548,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
min_node_weight=0,
pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE,
- dropout_probability=1.0).SerializeToString()
+ dropout_probability=1.0)
# Prepare handler inputs.
handler1_partitions = np.array([0], dtype=np.int32)
@@ -1567,9 +1572,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
],
gains=[handler1_gains, handler2_gains, handler3_gains],
splits=[handler1_split, handler2_split, handler3_split],
- learner_config=learner_config,
+ learner_config=learner_config.SerializeToString(),
dropout_seed=123,
- center_bias=True)
+ center_bias=True,
+ max_tree_depth=learner_config.constraints.max_tree_depth)
session.run(grow_op)
# Expect a new tree to be added with the split from handler 1.
@@ -1669,7 +1675,6 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE)
learner_config.constraints.max_number_of_unique_feature_columns = 3
- learner_config = learner_config.SerializeToString()
# Prepare handler inputs.
handler1_partitions = np.array([0], dtype=np.int32)
handler1_gains = np.array([7.62], dtype=np.float32)
@@ -1692,9 +1697,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
],
gains=[handler1_gains, handler2_gains, handler3_gains],
splits=[handler1_split, handler2_split, handler3_split],
- learner_config=learner_config,
+ learner_config=learner_config.SerializeToString(),
dropout_seed=123,
- center_bias=True)
+ center_bias=True,
+ max_tree_depth=learner_config.constraints.max_tree_depth)
session.run(grow_op)
_, serialized = session.run(
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index e08b230f46..19e053fcb6 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
@@ -380,6 +380,8 @@ class GradientBoostedDecisionTreeModel(object):
self._learner_config = learner_config
self._feature_columns = feature_columns
self._learner_config_serialized = learner_config.SerializeToString()
+ self._max_tree_depth = variables.Variable(
+ initial_value=self._learner_config.constraints.max_tree_depth)
self._attempted_trees = variables.Variable(
initial_value=array_ops.zeros([], dtypes.int64),
trainable=False,
@@ -1051,7 +1053,8 @@ class GradientBoostedDecisionTreeModel(object):
splits=split_info_list,
learner_config=self._learner_config_serialized,
dropout_seed=dropout_seed,
- center_bias=self._center_bias)
+ center_bias=self._center_bias,
+ max_tree_depth=self._max_tree_depth)
def _grow_ensemble_not_ready_fn():
# Don't grow the ensemble, just update the stamp.
@@ -1065,7 +1068,8 @@ class GradientBoostedDecisionTreeModel(object):
splits=[],
learner_config=self._learner_config_serialized,
dropout_seed=dropout_seed,
- center_bias=self._center_bias)
+ center_bias=self._center_bias,
+ max_tree_depth=self._max_tree_depth)
def _grow_ensemble_fn():
# Conditionally grow an ensemble depending on whether the splits
@@ -1105,6 +1109,9 @@ class GradientBoostedDecisionTreeModel(object):
def get_number_of_trees_tensor(self):
return self._finalized_trees, self._attempted_trees
+ def get_max_tree_depth(self):
+ return self._max_tree_depth
+
def train(self, loss, predictions_dict, labels):
"""Updates the accumalator stats and grows the ensemble.
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
index 8f521ffee4..f9dc3effd0 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
@@ -259,11 +259,11 @@ class TPUClusterResolver(ClusterResolver):
if 'state' in response and response['state'] != 'READY':
raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' %
- (self._tpu, response['state']))
+ (compat.as_text(self._tpu), response['state']))
if 'health' in response and response['health'] != 'HEALTHY':
- raise RuntimeError('TPU "%s" is unhealthy: "%s"' % (self._tpu,
- response['health']))
+ raise RuntimeError('TPU "%s" is unhealthy: "%s"' %
+ (compat.as_text(self._tpu), response['health']))
if 'networkEndpoints' in response:
worker_list = [
diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt
index 6c93487e0d..f6c928e2be 100644
--- a/tensorflow/contrib/cmake/CMakeLists.txt
+++ b/tensorflow/contrib/cmake/CMakeLists.txt
@@ -471,7 +471,6 @@ if (tensorflow_ENABLE_GPU)
${CUDA_TOOLKIT_TARGET_DIR}/include/cuComplex.h
${CUDA_TOOLKIT_TARGET_DIR}/include/cublas_v2.h
${CUDA_TOOLKIT_TARGET_DIR}/include/cusolverDn.h
- ${CUDA_TOOLKIT_TARGET_DIR}/include/cuda_fp16.h
${CUDA_TOOLKIT_TARGET_DIR}/include/device_functions.h
${CUDA_TOOLKIT_TARGET_DIR}/include/cufft.h
${CUDA_TOOLKIT_TARGET_DIR}/include/curand.h
diff --git a/tensorflow/contrib/copy_graph/python/util/copy_elements.py b/tensorflow/contrib/copy_graph/python/util/copy_elements.py
index 5931c8a279..6c9ab6aeb8 100644
--- a/tensorflow/contrib/copy_graph/python/util/copy_elements.py
+++ b/tensorflow/contrib/copy_graph/python/util/copy_elements.py
@@ -219,8 +219,10 @@ def copy_op_to_graph(org_instance, to_graph, variables, scope=''):
op_def)
#Use Graph's hidden methods to add the op
to_graph._record_op_seen_by_control_dependencies(new_op)
- for device_function in reversed(to_graph._device_function_stack):
+ # pylint: disable=protected-access
+ for device_function in to_graph._device_functions_outer_to_inner:
new_op._set_device(device_function(new_op))
+ # pylint: enable=protected-access
return new_op
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py
index a4914f4cde..42fc20ec01 100644
--- a/tensorflow/contrib/data/python/ops/batching.py
+++ b/tensorflow/contrib/data/python/ops/batching.py
@@ -515,10 +515,7 @@ def batch_and_drop_remainder(batch_size):
def _apply_fn(dataset):
"""Function from `Dataset` to `Dataset` that applies the transformation."""
- # TODO(jsimsa): Switch to using `batch(..., drop_remainder=True)` any time
- # after 6/30/2018.
- batched = dataset.batch(batch_size)
- return _filter_irregular_batches(batch_size)(batched)
+ return dataset.batch(batch_size, drop_remainder=True)
return _apply_fn
@@ -553,11 +550,9 @@ def padded_batch_and_drop_remainder(batch_size,
def _apply_fn(dataset):
"""Function from `Dataset` to `Dataset` that applies the transformation."""
- # TODO(jsimsa): Switch to using `padded_batch(..., drop_remainder=True)`
- # any time after 6/30/2018.
- batched = dataset.padded_batch(
- batch_size, padded_shapes=padded_shapes, padding_values=padding_values)
- return _filter_irregular_batches(batch_size)(batched)
+ return dataset.padded_batch(
+ batch_size, padded_shapes=padded_shapes, padding_values=padding_values,
+ drop_remainder=True)
return _apply_fn
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index f5d7e24ae2..cbe741de5a 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -101,6 +101,23 @@ py_library(
)
py_library(
+ name = "parameter_server_strategy",
+ srcs = ["parameter_server_strategy.py"],
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ ":cross_tower_ops",
+ ":mirrored_strategy",
+ ":values",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:util",
+ ],
+)
+
+py_library(
name = "one_device_strategy",
srcs = ["one_device_strategy.py"],
visibility = ["//tensorflow:internal"],
@@ -207,6 +224,35 @@ py_test(
],
)
+py_test(
+ name = "parameter_server_strategy_test",
+ srcs = ["parameter_server_strategy_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_pip",
+ ],
+ deps = [
+ ":combinations",
+ ":multi_worker_test_base",
+ ":parameter_server_strategy",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:layers",
+ "//tensorflow/python:session",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/estimator:run_config",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
cuda_py_test(
name = "mirrored_strategy_multigpu_test",
srcs = ["mirrored_strategy_multigpu_test.py"],
diff --git a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py
index fe3df9cbb9..bcb977f640 100644
--- a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py
+++ b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py
@@ -49,17 +49,23 @@ class CheckpointUtilsWithDistributionStrategyTest(
def testInitFromCheckpoint(self, distribution, in_tower_mode):
checkpoint_dir = self.get_temp_dir()
with self.test_session() as session:
- v1_value, _, _, _ = checkpoint_utils_test._create_checkpoints(
+ v1_value, v2_value, _, _ = checkpoint_utils_test._create_checkpoints(
session, checkpoint_dir)
def init_and_verify(g):
v1 = variable_scope.get_variable("new_var1", [1, 10])
+ v2 = variable_scope.get_variable(
+ "new_var2", [10, 10],
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ aggregation=variable_scope.VariableAggregation.MEAN)
checkpoint_utils.init_from_checkpoint(checkpoint_dir, {
"var1": "new_var1",
+ "var2": "new_var2"
})
with self.test_session(graph=g) as session:
session.run(variables.global_variables_initializer())
self.assertAllEqual(v1_value, self.evaluate(v1))
+ self.assertAllEqual(v2_value, self.evaluate(v2))
with ops.Graph().as_default() as g, distribution.scope():
if in_tower_mode:
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py
index b0baf0dad1..b6037d2133 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py
@@ -28,18 +28,37 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import device_util
+def check_destinations(destinations):
+ """Checks whether `destinations` is not None and not empty.
+
+ Args:
+ destinations: a DistributedValues, Variable, string or a list of strings.
+
+ Returns:
+ Boolean indicating whether `destinations` is not None and not empty.
+ """
+ # Calling bool() on a ResourceVariable is not allowed.
+ if isinstance(destinations, resource_variable_ops.ResourceVariable):
+ return bool(destinations.device)
+ return bool(destinations)
+
+
def validate_destinations(destinations):
- if not isinstance(destinations,
- (value_lib.DistributedValues, six.string_types, list)):
+ if not isinstance(
+ destinations,
+ (value_lib.DistributedValues, resource_variable_ops.ResourceVariable,
+ six.string_types, list)):
raise ValueError("destinations must be one of a `DistributedValues` object,"
- " a device string, a list of device strings or None")
+ " a tf.Variable object, a device string, a list of device "
+ "strings or None")
- if not destinations:
+ if not check_destinations(destinations):
raise ValueError("destinations can not be empty")
@@ -59,6 +78,8 @@ def _validate_value_destination_pairs(value_destination_pairs):
def get_devices_from(destinations):
if isinstance(destinations, value_lib.DistributedValues):
return list(destinations.devices)
+ elif isinstance(destinations, resource_variable_ops.ResourceVariable):
+ return [destinations.device]
elif isinstance(destinations, six.string_types):
return [device_util.resolve(destinations)]
else:
@@ -225,7 +246,10 @@ class ReductionToOneDeviceCrossTowerOps(CrossTowerOps):
super(ReductionToOneDeviceCrossTowerOps, self).__init__()
def _reduce(self, aggregation, per_device_value, destinations):
- devices = get_devices_from(destinations or per_device_value)
+ if check_destinations(destinations):
+ devices = get_devices_from(destinations)
+ else:
+ devices = get_devices_from(per_device_value)
reduce_to_device = self.reduce_to_device or devices[0]
reduced = _simple_reduce(per_device_value, reduce_to_device,
self.accumulation_fn, aggregation)
@@ -508,7 +532,10 @@ class AllReduceCrossTowerOps(CrossTowerOps):
logging.WARN,
"Efficient allreduce is not supported for IndexedSlices.", 10)
- devices = get_devices_from(destinations or per_device_value)
+ if check_destinations(destinations):
+ devices = get_devices_from(destinations)
+ else:
+ devices = get_devices_from(per_device_value)
reduce_to_device = devices[0]
reduced = _simple_reduce(per_device_value, reduce_to_device,
math_ops.add_n, aggregation)
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index dcbc6b0878..eb2d102012 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -20,7 +20,6 @@ from __future__ import print_function
import contextlib
import threading
-import six
from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
from tensorflow.contrib.distribute.python import shared_variable_creator
@@ -60,6 +59,156 @@ class _RequestedStop(Exception):
pass
+# Make _call_for_each_tower and _reduce_non_distributed_value not members of
+# MirroredStrategy so that they are generally not allowed to use anything
+# specific to MirroredStrategy and thus can be shared with other distribution
+# strategies.
+
+
+# TODO(yuefengz): maybe create a common class for those who need to call this
+# _call_for_each_tower.
+def _call_for_each_tower(distribution, fn, *args, **kwargs):
+ """Run `fn` in separate threads, once per tower/worker device.
+
+ Args:
+ distribution: the DistributionStrategy object.
+ fn: function to run (will be run once per device, each in its own thread).
+ *args: positional arguments for `fn`
+ **kwargs: keyword arguments for `fn`.
+ `"run_concurrently"`: Boolean indicating whether executions of `fn`
+ can be run concurrently (under eager execution only), defaults to
+ `True`.
+
+ Returns:
+ Merged return value of `fn` across all towers.
+
+ Raises:
+ RuntimeError: If fn() calls get_tower_context().merge_call() a different
+ number of times from the available devices.
+ """
+ run_concurrently = kwargs.pop("run_concurrently", True)
+ if not context.executing_eagerly():
+ # Lots of TF library code isn't thread-safe in graph mode, and
+ # there is little to be gained by turning on multithreading when
+ # constructing a graph.
+ run_concurrently = False
+ # Needed for per-thread device, etc. contexts in graph mode.
+ ops.get_default_graph().switch_to_thread_local()
+ elif run_concurrently is None:
+ run_concurrently = True
+
+ coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,))
+
+ shared_variable_store = {}
+
+ # TODO(isaprykin): Create these threads once instead of during every run()
+ # call.
+ threads = []
+ for index, d in enumerate(distribution.worker_devices):
+ variable_creator_fn = shared_variable_creator.make_fn(
+ shared_variable_store, index)
+ t = MirroredStrategy._MirroredTowerThread( # pylint: disable=protected-access
+ distribution, coord, d, variable_creator_fn, fn,
+ *values.select_device(d, args), **values.select_device(d, kwargs))
+ threads.append(t)
+
+ for t in threads:
+ t.start()
+
+ # When `fn` starts `should_run` event is set on _MirroredTowerThread
+ # (`MTT`) threads. The execution waits until
+ # `MTT.has_paused` is set, which indicates that either `fn` is
+ # complete or a `get_tower_context().merge_call()` is called. If `fn` is
+ # complete, then `MTT.done` is set to True. Otherwise, arguments
+ # of `get_tower_context().merge_call` from all paused threads are grouped
+ # and the `merge_fn` is performed. Results of the
+ # `get_tower_context().merge_call` are then set to `MTT.merge_result`.
+ # Each such `get_tower_context().merge_call` call returns the
+ # `MTT.merge_result` for that thread when `MTT.should_run` event
+ # is reset again. Execution of `fn` resumes.
+
+ try:
+ with coord.stop_on_exception():
+ all_done = False
+ while not all_done and not coord.should_stop():
+ done = []
+ if run_concurrently:
+ for t in threads:
+ t.should_run.set()
+ for t in threads:
+ t.has_paused.wait()
+ t.has_paused.clear()
+ if coord.should_stop():
+ return None
+ done.append(t.done)
+ else:
+ for t in threads:
+ t.should_run.set()
+ t.has_paused.wait()
+ t.has_paused.clear()
+ if coord.should_stop():
+ return None
+ done.append(t.done)
+ if coord.should_stop():
+ return None
+ all_done = all(done)
+ if not all_done:
+ if any(done):
+ raise RuntimeError("Some towers made a different number of "
+ "tower_context().merge_call() calls.")
+ # get_tower_context().merge_call() case
+ merge_args = values.regroup({t.device: t.merge_args for t in threads})
+ merge_kwargs = values.regroup(
+ {t.device: t.merge_kwargs for t in threads})
+ # We capture the name_scope of the MTT when we call merge_fn
+ # to ensure that if we have opened a name scope in the MTT,
+ # it will be respected when executing the merge function. We only
+ # capture the name_scope from the first MTT and assume it is
+ # the same for all other MTTs.
+ mtt_captured_name_scope = threads[0].captured_name_scope
+ with ops.name_scope(mtt_captured_name_scope):
+ merge_result = threads[0].merge_fn(distribution, *merge_args,
+ **merge_kwargs)
+ for t in threads:
+ t.merge_result = values.select_device(t.device, merge_result)
+ finally:
+ for t in threads:
+ t.should_run.set()
+ coord.join(threads)
+
+ return values.regroup({t.device: t.main_result for t in threads})
+
+
+def _reduce_non_distributed_value(distribution, aggregation, value,
+ destinations):
+ """Reduce a non-DistributedValue `value` to `destinations`."""
+ if isinstance(value, values.DistributedValues):
+ raise ValueError("You are passing a `DistributedValue` to "
+ "`_reduce_non_distributed_value`, which is not allowed.")
+
+ if value == 0:
+ return 0
+ if aggregation == variable_scope.VariableAggregation.MEAN:
+ return distribution.broadcast(value, destinations)
+
+ cross_tower_ops_lib.validate_destinations(destinations)
+ if (len(distribution.worker_devices) != 1 or
+ not cross_tower_ops_lib.check_destinations(destinations)):
+ raise ValueError("A non-DistributedValues value cannot be reduced with the "
+ "given aggregation.")
+ # TODO(anjalisridhar): Moves these methods to a device utility file?
+ devices = cross_tower_ops_lib.get_devices_from(destinations)
+ if len(devices) == 1:
+ with ops.device(devices[0]):
+ return array_ops.identity(value)
+ else:
+ value_updates = {}
+ for d in devices:
+ with ops.device(d):
+ value_updates[d] = array_ops.identity(value)
+ return values.Mirrored(value_updates)
+
+
class MirroredStrategy(distribute_lib.DistributionStrategy):
"""Mirrors vars to distribute across multiple devices on a single machine.
@@ -198,116 +347,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
self._devices)
def _call_for_each_tower(self, fn, *args, **kwargs):
- """Run `fn` in separate threads, once per tower/worker device.
-
- Args:
- fn: function to run (will be run once per device, each in its own thread).
- *args: positional arguments for `fn`
- **kwargs: keyword arguments for `fn`.
- `"run_concurrently"`: Boolean indicating whether executions of `fn`
- can be run concurrently (under eager execution only), defaults to
- `True`.
-
- Returns:
- Merged return value of `fn` across all towers.
-
- Raises:
- RuntimeError: If fn() calls get_tower_context().merge_call() a different
- number of times for when called for different devices.
- """
- run_concurrently = kwargs.pop("run_concurrently", True)
- if not context.executing_eagerly():
- # Lots of TF library code isn't thread-safe in graph mode, and
- # there is little to be gained by turning on multithreading when
- # constructing a graph.
- run_concurrently = False
- # Needed for per-thread device, etc. contexts in graph mode.
- ops.get_default_graph().switch_to_thread_local()
- elif run_concurrently is None:
- run_concurrently = True
-
- coord = coordinator.Coordinator(
- clean_stop_exception_types=(_RequestedStop,))
-
- shared_variable_store = {}
-
- # TODO(isaprykin): Create these threads once instead of during every run()
- # call.
- threads = []
- for index, d in enumerate(self._devices):
- variable_creator_fn = shared_variable_creator.make_fn(
- shared_variable_store, index)
- t = MirroredStrategy._MirroredTowerThread(
- self, coord, d, variable_creator_fn, fn,
- *values.select_device(d, args), **values.select_device(d, kwargs))
- threads.append(t)
-
- for t in threads:
- t.start()
-
- # When `fn` starts `should_run` event is set on _MirroredTowerThread
- # (`MTT`) threads. The execution waits until
- # `MTT.has_paused` is set, which indicates that either `fn` is
- # complete or a `get_tower_context().merge_call()` is called. If `fn` is
- # complete, then `MTT.done` is set to True. Otherwise, arguments
- # of `get_tower_context().merge_call` from all paused threads are grouped
- # and the `merge_fn` is performed. Results of the
- # `get_tower_context().merge_call` are then set to `MTT.merge_result`.
- # Each such `get_tower_context().merge_call` call returns the
- # `MTT.merge_result` for that thread when `MTT.should_run` event
- # is reset again. Execution of `fn` resumes.
-
- try:
- with coord.stop_on_exception():
- all_done = False
- while not all_done and not coord.should_stop():
- done = []
- if run_concurrently:
- for t in threads:
- t.should_run.set()
- for t in threads:
- t.has_paused.wait()
- t.has_paused.clear()
- if coord.should_stop():
- return None
- done.append(t.done)
- else:
- for t in threads:
- t.should_run.set()
- t.has_paused.wait()
- t.has_paused.clear()
- if coord.should_stop():
- return None
- done.append(t.done)
- if coord.should_stop():
- return None
- all_done = all(done)
- if not all_done:
- if any(done):
- raise RuntimeError("Some towers made a different number of "
- "tower_context().merge_call() calls.")
- # get_tower_context().merge_call() case
- merge_args = values.regroup(
- {t.device: t.merge_args for t in threads})
- merge_kwargs = values.regroup(
- {t.device: t.merge_kwargs for t in threads})
- # We capture the name_scope of the MTT when we call merge_fn
- # to ensure that if we have opened a name scope in the MTT,
- # it will be respected when executing the merge function. We only
- # capture the name_scope from the first MTT and assume it is
- # the same for all other MTTs.
- mtt_captured_name_scope = threads[0].captured_name_scope
- with ops.name_scope(mtt_captured_name_scope):
- merge_result = threads[0].merge_fn(
- self, *merge_args, **merge_kwargs)
- for t in threads:
- t.merge_result = values.select_device(t.device, merge_result)
- finally:
- for t in threads:
- t.should_run.set()
- coord.join(threads)
-
- return values.regroup({t.device: t.main_result for t in threads})
+ return _call_for_each_tower(self, fn, *args, **kwargs)
def map(self, map_over, fn, *args, **kwargs):
# TODO(josh11b): In eager mode, use one thread per device.
@@ -337,29 +377,9 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
def _reduce(self, aggregation, value, destinations):
assert not isinstance(value, values.Mirrored)
- if not isinstance(value, values.PerDevice):
- if value == 0:
- return 0
- if aggregation == variable_scope.VariableAggregation.MEAN:
- return self._broadcast(value, destinations)
-
- cross_tower_ops_lib.validate_destinations(destinations)
- if len(self._devices) == 1:
- if destinations:
- # TODO(anjalisridhar): Moves these methods to a device utility file?
- devices = cross_tower_ops_lib.get_devices_from(destinations)
- if len(devices) == 1:
- with ops.device(devices[0]):
- return array_ops.identity(value)
- else:
- value_updates = {}
- for d in devices:
- with ops.device(d):
- value_updates[d] = array_ops.identity(value)
- return values.Mirrored(value_updates)
- raise ValueError("A non PerDevice value cannot be reduced with the given "
- "aggregation.")
-
+ if not isinstance(value, values.DistributedValues):
+ return _reduce_non_distributed_value(self, aggregation, value,
+ destinations)
return self._get_cross_tower_ops().reduce(
aggregation, value, destinations=destinations)
@@ -433,15 +453,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
def _get_devices_from(self, colocate_with=None):
if colocate_with is None:
return self._devices
- elif isinstance(colocate_with, values.DistributedValues):
- # pylint: disable=protected-access
- return list(colocate_with._index.keys())
- elif isinstance(colocate_with, six.string_types):
- return [device_util.resolve(colocate_with)]
- elif isinstance(colocate_with, list):
- return [device_util.resolve(d) for d in colocate_with]
else:
- return colocate_with
+ return cross_tower_ops_lib.get_devices_from(colocate_with)
class _MirroredTowerThread(threading.Thread):
"""A thread that runs() a function on a device."""
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index 9807ce4351..aab7119901 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -792,8 +792,8 @@ class MirroredVariableUpdateTest(test.TestCase):
return mirrored_var.assign(5.0)
with self.assertRaisesRegexp(
- ValueError, "A non PerDevice value cannot be reduced with the given "
- "aggregation."):
+ ValueError, "A non-DistributedValues value cannot be reduced with "
+ "the given aggregation."):
self.evaluate(dist.unwrap(dist.call_for_each_tower(model_fn)))
@test_util.run_in_graph_and_eager_modes(config=config)
diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py
index f659be5f42..fa479918bd 100644
--- a/tensorflow/contrib/distribute/python/multi_worker_test_base.py
+++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py
@@ -28,23 +28,39 @@ from tensorflow.python.eager import test
from tensorflow.python.framework import test_util
+def create_in_process_cluster(num_workers, num_ps):
+ """Create an in-process cluster that consists of only standard server."""
+ # Leave some memory for cuda runtime.
+ gpu_mem_frac = 0.7 / num_workers
+ worker_config = config_pb2.ConfigProto()
+ worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac
+
+ ps_config = config_pb2.ConfigProto()
+ ps_config.device_count['GPU'] = 0
+
+ # Create in-process servers. Once an in-process tensorflow server is created,
+ # there is no way to terminate it. So we create one cluster per test process.
+ # We could've started the server in another process, we could then kill that
+ # process to terminate the server. The reasons why we don't want multiple
+ # processes are
+ # 1) it is more difficult to manage these processes
+ # 2) there is something global in CUDA such that if we initialize CUDA in the
+ # parent process, the child process cannot initialize it again and thus cannot
+ # use GPUs (https://stackoverflow.com/questions/22950047).
+ return test_util.create_local_cluster(
+ num_workers,
+ num_ps=num_ps,
+ worker_config=worker_config,
+ ps_config=ps_config)
+
+
class MultiWorkerTestBase(test.TestCase):
"""Base class for testing multi node strategy and dataset."""
@classmethod
def setUpClass(cls):
"""Create a local cluster with 2 workers."""
- num_workers = 2
- # Leave some memory for cuda runtime.
- gpu_mem_frac = 0.7 / num_workers
- default_config = config_pb2.ConfigProto()
- default_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac
-
- # The local cluster takes some portion of the local GPUs and there is no way
- # for the cluster to terminate unless using multiple processes. Therefore,
- # we have to only create only one cluster throughout a test process.
- workers, _ = test_util.create_local_cluster(
- num_workers, num_ps=0, worker_config=default_config)
+ workers, _ = create_in_process_cluster(num_workers=2, num_ps=0)
cls._master_target = workers[0].target
@contextlib.contextmanager
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
new file mode 100644
index 0000000000..9bcf6f8bac
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
@@ -0,0 +1,355 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Classes implementing a multi-worker ps DistributionStrategy."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import os
+
+from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
+from tensorflow.contrib.distribute.python import mirrored_strategy
+from tensorflow.contrib.distribute.python import values
+from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.python.framework import device as tf_device
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.training import device_setter
+from tensorflow.python.training import device_util
+from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import server_lib
+from tensorflow.python.util import nest
+
+_LOCAL_CPU = "/device:CPU:0"
+_LOCAL_GPU_0 = "/device:GPU:0"
+
+
+def _normalize_cluster_spec(cluster_spec):
+ """Makes `cluster_spec` into a `ClusterSpec` object."""
+ if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)):
+ return server_lib.ClusterSpec(cluster_spec)
+ elif not isinstance(cluster_spec, server_lib.ClusterSpec):
+ raise ValueError(
+ "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a "
+ "`tf.train.ClusterDef` object")
+ return cluster_spec
+
+
+# TODO(yuefengz): maybe cache variables on local CPU.
+# TODO(yuefengz): we may want to set session options to disallow communication
+# between workers.
+class ParameterServerStrategy(distribute_lib.DistributionStrategy):
+ """A parameter server DistributionStrategy.
+
+ This strategy class works for both local training and between-graph replicated
+ training for multiple workers. If `cluster_spec` is specified, either passed
+ in to __init__() method or parsed from the
+ ["TF_CONFIG" environment
+ variable](https://www.tensorflow.org/api_docs/python/tf/estimator/RunConfig),
+ variables and updates to those variables are assigned to parameter servers and
+ other operations are assigned to workers. If `cluster_spec` is not set, it
+ becomes local training where variables are assigned to local CPU or the only
+ GPU. When each worker has more than one GPU, operations will be replicated on
+ these GPUs. In both cases, operations are replicated but variables are not and
+ these workers share a common view for which paramater server a variable is
+ assigned to.
+
+ This class assumes between-graph replication will be used and works on a graph
+ for a particular worker.
+
+ It is expected to call `call_for_each_tower(fn, *args, **kwargs)` for any
+ operations which potentially can be replicated across towers (i.e. multiple
+ GPUs) even if there is only CPU or one GPU. When defining the `fn`, extra
+ caution needs to be taken:
+
+ 1) Always use @{tf.get_variable} instead of @{tf.Variable} which is not able
+ to refer to the same variable on different towers.
+
+ 2) It is generally not recommended to open a device scope under the strategy's
+ scope. A device scope (i.e. calling @{tf.device}) will be merged with or
+ override the device for operations but will not change the device for
+ variables.
+
+ 3) It is also not recommended to open a colocation scope (i.e. calling
+ @{tf.colocate_with}) under the strategy's scope. For colocating variables,
+ use `distribution.colocate_vars_with` instead. Colocation of ops will possibly
+ create conflicts of device assignement.
+ """
+
+ def __init__(self,
+ num_gpus_per_worker=0,
+ cluster_spec=None,
+ task_type=None,
+ task_id=None):
+ """Initiailizes this strategy.
+
+ Args:
+ num_gpus_per_worker: number of local GPUs or GPUs per worker.
+ cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
+ cluster configurations.
+ task_type: the current task type.
+ task_id: the current task id.
+ """
+ super(ParameterServerStrategy, self).__init__()
+ self._num_gpus_per_worker = num_gpus_per_worker
+ if cluster_spec:
+ cluster_spec = _normalize_cluster_spec(cluster_spec)
+ self._cluster_spec = cluster_spec
+
+ # We typically don't need to do all-reduce in this strategy.
+ self._cross_tower_ops = (
+ cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps(
+ reduce_to_device=_LOCAL_CPU))
+
+ self._initialize_devices(num_gpus_per_worker, cluster_spec, task_type,
+ task_id)
+
+ def _initialize_devices(self, num_gpus_per_worker, cluster_spec, task_type,
+ task_id):
+ """Initialize internal devices.
+
+ It creates variable devices and compute devices. Variables and operations
+ will be assigned to them respectively. We have one compute device per tower.
+ The variable device is a device function or device string. The default
+ variable device assigns variables to parameter servers in a round-robin
+ fashion.
+
+ Args:
+ num_gpus_per_worker: number of local GPUs or GPUs per worker.
+ cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
+ cluster configurations.
+ task_type: the current task type.
+ task_id: the current task id.
+
+ Raises:
+ ValueError: if the cluster_spec doesn't have ps jobs.
+ """
+ self._task_type = task_type or "worker"
+ self._task_id = task_id or 0
+ self._worker_device = "/job:%s/task:%d" % (self._task_type, self._task_id)
+
+ # TODO(yuefengz): maybe clearer to split it into two classes, one for
+ # the distribuetd case and one for the local case, once we have the factory
+ # class/method.
+
+ # Define compute devices which is a list of device strings and one for each
+ # tower. When there are GPUs, replicate operations on these GPUs. Otherwise,
+ # place operations on CPU.
+ if cluster_spec is None:
+ # Local mode.
+ if num_gpus_per_worker > 0:
+ self._compute_devices = list(
+ map("/device:GPU:{}".format, range(num_gpus_per_worker)))
+ else:
+ self._compute_devices = [_LOCAL_CPU]
+ else:
+ # Distributed mode.
+ if num_gpus_per_worker > 0:
+ self._compute_devices = [
+ "%s/device:GPU:%d" % (self._worker_device, i)
+ for i in range(num_gpus_per_worker)
+ ]
+ else:
+ self._compute_devices = [self._worker_device]
+
+ self._compute_devices = list(
+ map(device_util.resolve, self._compute_devices))
+ self._canonical_compute_device_set = set(self._compute_devices)
+
+ # Define variable device which is a device string in the local case and a
+ # device function in the distributed case. It is used to open a device scope
+ # where varibles are defined.
+ # The `_parameter_devices` is needed for the `parameter_devices` property
+ # and is a list of all variable devices.
+ if cluster_spec is None:
+ # Local mode. If there is only one GPU, put everything on that GPU.
+ # Otherwise, place variables on CPU.
+ if num_gpus_per_worker == 1:
+ assert len(list(self._compute_devices)) == 1
+ self._variable_device = _LOCAL_GPU_0
+ self._parameter_devices = [_LOCAL_GPU_0]
+ else:
+ self._variable_device = _LOCAL_CPU
+ self._parameter_devices = [_LOCAL_CPU]
+ else:
+ # Distributed mode. Place variables on ps jobs in a round-robin fashion.
+ # Note that devices returned from `replica_device_setter` are not
+ # canonical and therefore we don't canonicalize all variable devices to
+ # make them consistent.
+ # TODO(yuefengz): support passing a strategy object to control variable
+ # assignment.
+ # TODO(yuefengz): merge the logic of replica_device_setter into this
+ # class.
+ num_ps_replicas = len(cluster_spec.as_dict().get("ps", []))
+ if num_ps_replicas == 0:
+ raise ValueError("The cluster spec needs to have `ps` jobs.")
+ self._variable_device = device_setter.replica_device_setter(
+ ps_tasks=num_ps_replicas,
+ worker_device=self._worker_device,
+ merge_devices=True,
+ cluster=cluster_spec)
+
+ # Parameter devices are all tasks of the "ps" job.
+ self._parameter_devices = map("/job:ps/task:{}".format,
+ range(num_ps_replicas))
+
+ # Define the default device in cross-tower mode. In the distributed case, we
+ # set the default device to the corresponding worker to prevent these ops
+ # from being placed on other workers.
+ if cluster_spec is None:
+ self._default_device = None
+ else:
+ self._default_device = self._worker_device
+
+ def distribute_dataset(self, dataset_fn):
+ """Distributes the dataset to each local GPU."""
+ return values.PerDeviceDataset(
+ self._call_dataset_fn(dataset_fn), self._compute_devices, True)
+
+ def _broadcast(self, tensor, destinations):
+ if not cross_tower_ops_lib.check_destinations(destinations):
+ destinations = self._compute_devices
+ return self._cross_tower_ops.broadcast(tensor, destinations)
+
+ # TODO(yuefengz): not all ops in device_setter.STANDARD_PS_OPS will go through
+ # this creator, such as "MutableHashTable".
+ def _create_variable(self, next_creator, *args, **kwargs):
+ if "colocate_with" in kwargs:
+ with ops.device(None):
+ with ops.colocate_with(kwargs["colocate_with"]):
+ return next_creator(*args, **kwargs)
+
+ with ops.colocate_with(None, ignore_existing=True):
+ with ops.device(self._variable_device):
+ return next_creator(*args, **kwargs)
+
+ def _call_for_each_tower(self, fn, *args, **kwargs):
+ # pylint: disable=protected-access
+ return mirrored_strategy._call_for_each_tower(self, fn, *args, **kwargs)
+
+ def _verify_destinations_not_different_worker(self, destinations):
+ if destinations is None:
+ return
+ for d in cross_tower_ops_lib.get_devices_from(destinations):
+ d_spec = tf_device.DeviceSpec.from_string(d)
+ if d_spec.job == self._task_type and d_spec.task != self._task_id:
+ raise ValueError(
+ "Cannot reduce to another worker: %r, current worker is %r" %
+ (d, self._worker_device))
+
+ def _reduce(self, aggregation, value, destinations):
+ self._verify_destinations_not_different_worker(destinations)
+ if not isinstance(value, values.DistributedValues):
+ # pylint: disable=protected-access
+ return mirrored_strategy._reduce_non_distributed_value(
+ self, aggregation, value, destinations)
+
+ return self._cross_tower_ops.reduce(
+ aggregation, value, destinations=destinations)
+
+ def _batch_reduce(self, aggregation, value_destination_pairs):
+ for _, destinations in value_destination_pairs:
+ self._verify_destinations_not_different_worker(destinations)
+ return self._cross_tower_ops.batch_reduce(aggregation,
+ value_destination_pairs)
+
+ def _select_single_value(self, structured):
+ """Select any single values in `structured`."""
+
+ def _select_fn(x): # pylint: disable=g-missing-docstring
+ if isinstance(x, values.Mirrored):
+ if len(x.devices) == 1:
+ return list(x._index.values())[0] # pylint: disable=protected-access
+ else:
+ raise ValueError(
+ "You cannot update variable with a Mirrored object with multiple "
+ "components %r when using ParameterServerStrategy. You must "
+ "specify a single value or a Mirrored with a single value." % x)
+ elif isinstance(x, values.PerDevice):
+ raise ValueError(
+ "You cannot update variable with a PerDevice object %r when using "
+ "ParameterServerStrategy. You must specify a single value or a "
+ "Mirrored with a single value" % x)
+ else:
+ return x
+
+ return nest.map_structure(_select_fn, structured)
+
+ def _update(self, var, fn, *args, **kwargs):
+ if not isinstance(var, resource_variable_ops.ResourceVariable):
+ raise ValueError(
+ "You can not update `var` %r. It must be a Variable." % var)
+ with ops.colocate_with(var), distribute_lib.UpdateContext(var.device):
+ return fn(var, *self._select_single_value(args),
+ **self._select_single_value(kwargs))
+
+ # TODO(yuefengz): does it need to call _select_single_value?
+ def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
+ with ops.device(
+ colocate_with.device), distribute_lib.UpdateContext(colocate_with):
+ return fn(*args, **kwargs)
+
+ def _unwrap(self, val):
+ if isinstance(val, values.DistributedValues):
+ # Return in a deterministic order.
+ if set(val.devices) == self._canonical_compute_device_set:
+ return [val.get(device=d) for d in self._compute_devices]
+ return [val.get(device=d) for d in sorted(val.devices)]
+ return [val]
+
+ def read_var(self, var):
+ # No need to distinguish between normal variables and tower-local variables.
+ return array_ops.identity(var)
+
+ def configure(self, session_config=None):
+ del session_config
+
+ # Use TF_CONFIG to get the cluster spec and the current job.
+ tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
+ cluster_spec = _normalize_cluster_spec(tf_config.get("cluster", {}))
+
+ task_env = tf_config.get("task", {})
+ if task_env:
+ task_type = task_env.get("type", "worker")
+ task_id = int(task_env.get("index", "0"))
+ else:
+ task_type = "worker"
+ task_id = None
+
+ # Set the devices if cluster_spec is defined in TF_CONFIG but not passed in
+ # the constructor.
+ if not self._cluster_spec and cluster_spec:
+ self._cluster_spec = cluster_spec
+ self._initialize_devices(self._num_gpus_per_worker, cluster_spec,
+ task_type, task_id)
+
+ @property
+ def num_towers(self):
+ return len(self._compute_devices)
+
+ @property
+ def worker_devices(self):
+ # Make a copy to prevent users from accidentally mutating our copy.
+ return list(self._compute_devices)
+
+ @property
+ def parameter_devices(self):
+ return list(self._parameter_devices)
+
+ def non_slot_devices(self, var_list):
+ return min(var_list, key=lambda x: x.name)
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
new file mode 100644
index 0000000000..ad538b9e8e
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
@@ -0,0 +1,455 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for ParameterServerStrategy."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import json
+import threading
+from absl.testing import parameterized
+
+from tensorflow.contrib.distribute.python import combinations
+from tensorflow.contrib.distribute.python import multi_worker_test_base
+from tensorflow.contrib.distribute.python import parameter_server_strategy
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session
+from tensorflow.python.eager import context
+from tensorflow.python.estimator import run_config
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.layers import core
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gradients
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import device_util
+from tensorflow.python.training import distribute as distribute_lib
+
+
+class ParameterServerStrategyTest(test.TestCase, parameterized.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster(
+ num_workers=3, num_ps=2)
+
+ def setUp(self):
+ self._result = 0
+ self._lock = threading.Lock()
+ self._init_condition = threading.Condition()
+ self._init_reached = 0
+ self._finish_condition = threading.Condition()
+ self._finish_reached = 0
+
+ def _get_ps_distribution_strategy(self, task_type, task_index, num_gpus=0):
+ tf_config = {
+ 'cluster': {
+ run_config.TaskType.WORKER: [
+ 'fake_worker_0', 'fake_worker_1', 'fake_worker_2'
+ ],
+ run_config.TaskType.PS: ['fake_ps_0', 'fake_ps_1']
+ },
+ 'task': {
+ 'type': task_type,
+ 'index': task_index
+ }
+ }
+ distribution = parameter_server_strategy.ParameterServerStrategy(
+ num_gpus_per_worker=num_gpus)
+ with self._lock:
+ # Accessing environment variables should be protected by locks because
+ # environment variables are shared by all threads.
+ with test.mock.patch.dict('os.environ',
+ {'TF_CONFIG': json.dumps(tf_config)}):
+ distribution.configure()
+ return distribution
+
+ @contextlib.contextmanager
+ def _test_session(self, target):
+ config = config_pb2.ConfigProto(allow_soft_placement=True)
+ config.graph_options.optimizer_options.opt_level = -1
+ with session.Session(graph=None, config=config, target=target) as sess:
+ yield sess
+
+ def _test_device_assignment_distributed(self, d, num_gpus=0):
+ with ops.Graph().as_default(), \
+ self._test_session(target=self._workers[0].target) as sess, \
+ d.scope():
+
+ # Define a variable outside the call_for_each_tower scope. This is not
+ # recommended.
+ n = variable_scope.get_variable('n', initializer=10.0)
+ self.assertEqual(n.device, '/job:ps/task:0')
+
+ def model_fn():
+ if num_gpus == 0:
+ last_part_device = 'device:CPU:0'
+ else:
+ last_part_device = (
+ 'device:GPU:%d' % distribute_lib.get_tower_context().tower_id)
+
+ a = constant_op.constant(1.0)
+ b = constant_op.constant(2.0)
+ c = a + b
+ self.assertEqual(a.device,
+ '/job:worker/replica:0/task:1/%s' % last_part_device)
+ self.assertEqual(b.device,
+ '/job:worker/replica:0/task:1/%s' % last_part_device)
+ self.assertEqual(c.device,
+ '/job:worker/replica:0/task:1/%s' % last_part_device)
+
+ # The device scope is ignored for variables but not for normal ops.
+ with ops.device('/job:worker/task:0'):
+ x = variable_scope.get_variable('x', initializer=10.0)
+ x_add = x.assign_add(c)
+ e = a + c
+ # The variable x is on the task 1 since the device_function has been
+ # called once before the model_fn.
+ self.assertEqual(x.device, '/job:ps/task:1')
+ self.assertEqual(x_add.device, x.device)
+ self.assertEqual(e.device,
+ '/job:worker/replica:0/task:0/%s' % last_part_device)
+
+ # The colocate_vars_with can override the distribution's device.
+ with d.colocate_vars_with(x):
+ y = variable_scope.get_variable('y', initializer=20.0)
+ y_add = y.assign_add(x_add)
+ self.assertEqual(y.device, '/job:ps/task:1')
+ self.assertEqual(y_add.device, y.device)
+ self.assertEqual(y.device, x.device)
+
+ z = variable_scope.get_variable('z', initializer=10.0)
+ self.assertEqual(z.device, '/job:ps/task:0')
+ self.assertNotEqual(z.device, x.device)
+
+ with ops.control_dependencies([y_add]):
+ z_add = z.assign_add(y)
+ with ops.control_dependencies([z_add]):
+ f = z + c
+ self.assertEqual(f.device,
+ '/job:worker/replica:0/task:1/%s' % last_part_device)
+
+ # The device scope would merge with the default worker device.
+ with ops.device('/CPU:1'):
+ g = e + 1.0
+ self.assertEqual(g.device, '/job:worker/replica:0/task:1/device:CPU:1')
+
+ # Ths ops.colocate_with will be ignored when defining a variale but not
+ # for a normal tensor.
+ with ops.colocate_with(x):
+ u = variable_scope.get_variable('u', initializer=30.0)
+ v = variable_scope.get_variable('v', initializer=30.0)
+ h = f + 1.0
+ self.assertIn('/job:ps/', u.device)
+ self.assertIn('/job:ps/', v.device)
+ # u and v are on different parameter servers.
+ self.assertTrue(u.device != x.device or v.device != x.device)
+ self.assertTrue(u.device == x.device or v.device == x.device)
+ # Here h is not on one worker. Note h.device is canonical while x.device
+ # is not but.
+ self.assertIn('/job:ps/', h.device)
+ return y_add, z_add, f
+
+ y, z, f = d.call_for_each_tower(model_fn)
+ self.assertNotEqual(y, None)
+ self.assertNotEqual(z, None)
+ self.assertNotEqual(f, None)
+
+ if context.num_gpus() >= 1 and num_gpus <= 1:
+ variables.global_variables_initializer().run()
+ y_val, z_val, f_val = sess.run([y, z, f])
+ self.assertEqual(y_val, 33.0)
+ self.assertEqual(z_val, 43.0)
+ self.assertEqual(f_val, 46.0)
+
+ @combinations.generate(
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
+ def testDeviceAssignmentDistributed(self, num_gpus):
+ d = self._get_ps_distribution_strategy('worker', 1, num_gpus=num_gpus)
+ self._test_device_assignment_distributed(d, num_gpus=num_gpus)
+
+ def _test_device_assignment_local(self,
+ d,
+ compute_device='CPU',
+ variable_device='CPU',
+ num_gpus=0):
+ with ops.Graph().as_default(), \
+ self._test_session(target=self._workers[0].target) as sess, \
+ d.scope():
+
+ def model_fn():
+ if 'CPU' in compute_device:
+ tower_compute_device = '/device:CPU:0'
+ else:
+ tower_compute_device = (
+ '/device:GPU:%d' % distribute_lib.get_tower_context().tower_id)
+ tower_compute_device = device_util.canonicalize(tower_compute_device)
+
+ if 'CPU' in variable_device:
+ tower_variable_device = '/device:CPU:0'
+ else:
+ tower_variable_device = (
+ '/device:GPU:%d' % distribute_lib.get_tower_context().tower_id)
+ tower_variable_device = device_util.canonicalize(tower_variable_device)
+
+ a = constant_op.constant(1.0)
+ b = constant_op.constant(2.0)
+ c = a + b
+ self.assertEqual(a.device, tower_compute_device)
+ self.assertEqual(b.device, tower_compute_device)
+ self.assertEqual(c.device, tower_compute_device)
+
+ # The device scope is ignored for variables but not for normal ops.
+ with ops.device('/device:GPU:2'):
+ x = variable_scope.get_variable('x', initializer=10.0)
+ x_add = x.assign_add(c)
+ e = a + c
+ self.assertEqual(
+ device_util.canonicalize(x.device), tower_variable_device)
+ self.assertEqual(x_add.device, x.device)
+ self.assertEqual(e.device, device_util.canonicalize('/device:GPU:2'))
+
+ # The colocate_vars_with can override the distribution's device.
+ with d.colocate_vars_with(x):
+ y = variable_scope.get_variable('y', initializer=20.0)
+ y_add = y.assign_add(x_add)
+ self.assertEqual(
+ device_util.canonicalize(y.device), tower_variable_device)
+ self.assertEqual(y_add.device, y.device)
+ self.assertEqual(y.device, x.device)
+
+ z = variable_scope.get_variable('z', initializer=10.0)
+ self.assertEqual(
+ device_util.canonicalize(z.device), tower_variable_device)
+
+ with ops.control_dependencies([y_add]):
+ z_add = z.assign_add(y)
+ with ops.control_dependencies([z_add]):
+ f = z + c
+ self.assertEqual(f.device, tower_compute_device)
+
+ # The device scope would merge with the default worker device.
+ with ops.device('/CPU:1'):
+ g = e + 1.0
+ self.assertEqual(g.device, device_util.canonicalize('/device:CPU:1'))
+
+ # Ths ops.colocate_with will be ignored when defining a variale but not
+ # for a normal tensor.
+ with ops.colocate_with(x):
+ u = variable_scope.get_variable('u', initializer=30.0)
+ h = f + 1.0
+ self.assertEqual(
+ device_util.canonicalize(u.device), tower_variable_device)
+ self.assertEqual(device_util.canonicalize(x.device), h.device)
+ return y_add, z_add, f
+
+ y, z, f = d.call_for_each_tower(model_fn)
+ self.assertNotEqual(y, None)
+ self.assertNotEqual(z, None)
+ self.assertNotEqual(f, None)
+
+ if context.num_gpus() >= 1 and num_gpus <= 1:
+ variables.global_variables_initializer().run()
+ y_val, z_val, f_val = sess.run([y, z, f])
+ self.assertEqual(y_val, 33.0)
+ self.assertEqual(z_val, 43.0)
+ self.assertEqual(f_val, 46.0)
+
+ def testDeviceAssignmentLocal(self):
+ distribution = parameter_server_strategy.ParameterServerStrategy(
+ num_gpus_per_worker=0)
+ self._test_device_assignment_local(
+ distribution, compute_device='CPU', variable_device='CPU', num_gpus=0)
+
+ distribution = parameter_server_strategy.ParameterServerStrategy(
+ num_gpus_per_worker=1)
+ self._test_device_assignment_local(
+ distribution, compute_device='GPU', variable_device='GPU', num_gpus=1)
+
+ distribution = parameter_server_strategy.ParameterServerStrategy(
+ num_gpus_per_worker=2)
+ self._test_device_assignment_local(
+ distribution, compute_device='GPU', variable_device='CPU', num_gpus=2)
+
+ def _test_simple_increment(self, d, task_type, task_index, master_target):
+ if hasattr(d, '_cluster_spec') and d._cluster_spec:
+ num_workers = len(d._cluster_spec.as_dict().get('worker',
+ ['dummy_worker']))
+ else:
+ num_workers = 1
+ with ops.Graph().as_default(), \
+ self._test_session(target=master_target) as sess, \
+ d.scope():
+
+ def model_fn():
+ x = variable_scope.get_variable('x', initializer=10.0)
+ y = variable_scope.get_variable('y', initializer=20.0)
+
+ x_add = x.assign_add(1.0, use_locking=True)
+ y_add = y.assign_add(1.0, use_locking=True)
+
+ train_op = control_flow_ops.group([x_add, y_add])
+ return x, y, train_op
+
+ x, y, train_op = d.call_for_each_tower(model_fn)
+ train_op = d.group(d.unwrap(train_op))
+
+ if context.num_gpus() < d._num_gpus_per_worker:
+ return True
+
+ if task_index == 0:
+ variables.global_variables_initializer().run()
+
+ # Workers waiting for chief worker's initializing variables.
+ self._init_condition.acquire()
+ self._init_reached += 1
+ while self._init_reached != num_workers:
+ self._init_condition.wait()
+ self._init_condition.notify_all()
+ self._init_condition.release()
+
+ sess.run(train_op)
+
+ # Wait for other workers to finish training.
+ self._finish_condition.acquire()
+ self._finish_reached += 1
+ while self._finish_reached != num_workers:
+ self._finish_condition.wait()
+ self._finish_condition.notify_all()
+ self._finish_condition.release()
+
+ x_val, y_val = sess.run([x, y])
+ self.assertEqual(x_val, 10.0 + 1.0 * num_workers * d.num_towers)
+ self.assertEqual(y_val, 20.0 + 1.0 * num_workers * d.num_towers)
+ return (x_val == 10.0 + 1.0 * num_workers * d.num_towers and
+ y_val == 20.0 + 1.0 * num_workers * d.num_towers)
+
+ def _test_minimize_loss_graph(self, d, task_type, task_index, master_target):
+ with ops.Graph().as_default(), \
+ self._test_session(target=master_target) as sess, \
+ d.scope():
+ l = core.Dense(1, use_bias=False)
+
+ def loss_fn(x):
+ y = array_ops.reshape(l(x), []) - constant_op.constant(1.)
+ return y * y
+
+ # TODO(yuefengz, apassos): eager.backprop.implicit_grad is not safe for
+ # multiple graphs (b/111216820).
+ def grad_fn(x):
+ loss = loss_fn(x)
+ var_list = (
+ variables.trainable_variables() + ops.get_collection(
+ ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
+ grads = gradients.gradients(loss, var_list)
+ ret = list(zip(grads, var_list))
+ return ret
+
+ def update(v, g):
+ return v.assign_sub(0.05 * g, use_locking=True)
+
+ one = d.broadcast(constant_op.constant([[1.]]))
+
+ def step():
+ """Perform one optimization step."""
+ # Run forward & backward to get gradients, variables list.
+ g_v = d.call_for_each_tower(grad_fn, one)
+ # Update the variables using the gradients and the update() function.
+ before_list = []
+ after_list = []
+ for g, v in g_v:
+ fetched = d.read_var(v)
+ before_list.append(fetched)
+ with ops.control_dependencies([fetched]):
+ # TODO(yuefengz): support non-Mirrored variable as destinations.
+ g = d.reduce(
+ variable_scope.VariableAggregation.SUM, g, destinations=v)
+ with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
+ after_list.append(d.read_var(v))
+ return before_list, after_list
+
+ before_out, after_out = step()
+
+ if context.num_gpus() < d._num_gpus_per_worker:
+ return True
+
+ if task_index == 0:
+ variables.global_variables_initializer().run()
+
+ # Workers waiting for chief worker's initializing variables.
+ self._init_condition.acquire()
+ self._init_reached += 1
+ while self._init_reached != 3:
+ self._init_condition.wait()
+ self._init_condition.notify_all()
+ self._init_condition.release()
+
+ for i in range(10):
+ b, a = sess.run((before_out, after_out))
+ if i == 0:
+ before, = b
+ after, = a
+
+ error_before = abs(before - 1)
+ error_after = abs(after - 1)
+ # Error should go down
+ self.assertLess(error_after, error_before)
+ return error_after < error_before
+
+ def _run_client(self, index, model_fn, num_gpus):
+ task_type = run_config.TaskType.WORKER
+ result = model_fn(
+ self._get_ps_distribution_strategy(task_type, index, num_gpus=num_gpus),
+ task_type, index, self._workers[index].target)
+ if result:
+ with self._lock:
+ self._result += 1
+
+ def _run_multiple_clients(self, num_clients, model_fn, num_gpus=0):
+ threads = []
+ for i in range(num_clients):
+ t = threading.Thread(
+ target=self._run_client, args=(i, model_fn, num_gpus))
+ t.start()
+ threads.append(t)
+ for t in threads:
+ t.join()
+
+ def testSimpleBetweenGraph(self):
+ self._run_multiple_clients(3, self._test_simple_increment)
+ self.assertEqual(self._result, 3)
+
+ @combinations.generate(
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
+ def testLocalSimpleIncrement(self, num_gpus):
+ d = parameter_server_strategy.ParameterServerStrategy(
+ num_gpus_per_worker=num_gpus)
+ self._test_simple_increment(d, 'dummy_worker', 0, '')
+
+ @combinations.generate(
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
+ def testMinimizeLossGraph(self, num_gpus):
+ self._run_multiple_clients(
+ 3, self._test_minimize_loss_graph, num_gpus=num_gpus)
+ self.assertEqual(self._result, 3)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 47dcf679c2..4018b1e023 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -210,6 +210,11 @@ class DistributedVariable(DistributedDelegate):
# without this it will use `__getattr__` which will delegate to a component
# variable.
self._keras_initialized = False
+ # Typically, a `DistributedVariable`'s initializer is composed of the
+ # initializers of the components variables. However, in some cases, such as
+ # when restoring from a checkpoint, we may set the _initializer_op
+ # property on the entire `DistributedVariable`.
+ self._initializer_op = None
super(DistributedVariable, self).__init__(index)
def is_initialized(self, name=None):
@@ -239,9 +244,14 @@ class DistributedVariable(DistributedDelegate):
@property
def initializer(self):
- # return grouped ops of all the var initializations of component values of
- # the mirrored variable
- return control_flow_ops.group([v.initializer for v in self._index.values()])
+ if self._initializer_op:
+ init_op = self._initializer_op
+ else:
+ # return grouped ops of all the var initializations of component values of
+ # the mirrored variable
+ init_op = control_flow_ops.group(
+ [v.initializer for v in self._index.values()])
+ return init_op
@property
def graph(self):
diff --git a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
index ef3bdfa75f..18a0f754e6 100644
--- a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
+++ b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
@@ -326,6 +326,21 @@ class QuantizedDistribution(distributions.Distribution):
graph_parents=graph_parents,
name=name)
+ @property
+ def distribution(self):
+ """Base distribution, p(x)."""
+ return self._dist
+
+ @property
+ def low(self):
+ """Lowest value that quantization returns."""
+ return self._low
+
+ @property
+ def high(self):
+ """Highest value that quantization returns."""
+ return self._high
+
def _batch_shape_tensor(self):
return self.distribution.batch_shape_tensor()
@@ -569,8 +584,3 @@ class QuantizedDistribution(distributions.Distribution):
dependencies = [distribution_util.assert_integer_form(
value, message="value has non-integer components.")]
return control_flow_ops.with_dependencies(dependencies, value)
-
- @property
- def distribution(self):
- """Base distribution, p(x)."""
- return self._dist
diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/README.md b/tensorflow/contrib/eager/python/examples/l2hmc/README.md
index d6a2ff7558..f171806e37 100644
--- a/tensorflow/contrib/eager/python/examples/l2hmc/README.md
+++ b/tensorflow/contrib/eager/python/examples/l2hmc/README.md
@@ -4,16 +4,15 @@ This folder contains an implementation of [L2HMC](https://arxiv.org/pdf/1711.092
With eager execution enabled, longer sample chains can be handled compared to graph mode, since no graph is explicitly stored. Moreover, with eager execution enabled, there is no need to use a `tf.while_loop`.
## What is L2HMC?
-L2HMC is an algorithm that learns a non-volume preserving transformation
-for an HMC-like sampling algorithm. More specifically, the non-volume preserving
+L2HMC is an adaptive Markov Chain Monte Carlo (MCMC) algorithm that learns a non-volume preserving transformation
+for a Hamiltonian Monte Carlo (HMC) sampling algorithm. More specifically, the non-volume preserving
transformation is learned with neural nets instantiated within Normalizing Flows
-(more precisely, real-NVPs).
+(real-NVPs).
## Content
- `l2hmc.py`: Dynamics definitions and example energy functions,
-including the 2D strongly correlated Gaussian, the rough well energy function,
-and a Gaussian mixture model.
+including the 2D strongly correlated Gaussian and the rough well energy function,
- `l2hmc_test.py`: Unit tests and benchmarks for training a sampler on the energy functions in both eager and graph mode.
- `neural_nets.py`: The neural net for learning the kernel on the 2D strongly correlated example.
- `main.py`: Run to train a samplers on 2D energy landscapes.
@@ -32,7 +31,7 @@ tensorboard and a plot of sampled chain from the trained sampler.
Specifying the optional argument `use_defun` will let the program use compiled
graphs when running specific sections and improve the overall speed.
-## Boosting Performance with `defun`
+## Boosting Performance with `tfe.defun`
Currently, some models may experience increased overhead with eager execution enabled.
To improve performance, we could wrap certain functions with the decorator `@tfe.defun`.
For example, we could wrap the function that does the sampling step:
diff --git a/tensorflow/contrib/eager/python/examples/revnet/README.md b/tensorflow/contrib/eager/python/examples/revnet/README.md
index 21fc44febc..2875d0ffb3 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/README.md
+++ b/tensorflow/contrib/eager/python/examples/revnet/README.md
@@ -1,18 +1,21 @@
# RevNet with TensorFlow eager execution
-This folder contains an TensorFlow eager implementation of the [Reversible Residual Network](https://arxiv.org/pdf/1707.04585.pdf) adapted from the released implementation by the authors. The presented implementation can be ran both in eager and graph mode. The code is considerably simplified with `tf.GradientTape`. Moreover, we reduce the step of reconstructing the outputs. This saves us from using `tf.stop_gradient` and makes the model run faster.
+This folder contains a TensorFlow eager implementation of the [Reversible Residual Network](https://arxiv.org/pdf/1707.04585.pdf) adapted from the released implementation by the authors. The presented implementation can be ran both in eager and graph mode. The code is considerably simplified with `tf.GradientTape`. Moreover, we reduce the step of reconstructing the outputs. This saves us from using `tf.stop_gradient` and makes the model run faster.
## Content
- `revnet.py`: The RevNet model.
- `blocks.py`: The relevant reversible blocks.
+- `ops.py`: Auxiliary downsampling operation.
- `cifar_tfrecords.py`: Script to generate the TFRecords for both CIFAR-10 and CIFAR-100.
- `cifar_input.py`: Script to read from TFRecords and generate dataset objects with the `tf.data` API.
- `config.py`: Configuration file for network architectures and training hyperparameters.
- `main.py`: Main training and evaluation script.
-- `ops.py`: Auxiliary downsampling operation.
+- `main_estimator.py`: Script to train RevNet models on CIFAR-10 and CIFAR-100 with the `tf.estimator` API.
+- `main_estimator_tpu.py`: Script to train RevNet models on ImageNet with TPU estimators on Cloud TPUs.
+- `resnet_preprocessing.py`, `imagenet_input.py`: Boilerplate to read ImageNet data from TFRecords.
-## To run
+## Train on CIFAR-10/CIFAR-100
- Make sure you have installed TensorFlow 1.9+ or the latest `tf-nightly`
or `tf-nightly-gpu` pip package in order to access the eager execution feature.
@@ -24,7 +27,7 @@ python cifar_tfrecords.py --data_dir ${PWD}/cifar
to download the cifar dataset and convert them
to TFRecords. This produces TFRecord files for both CIFAR-10 and CIFAR-100.
-- To train a model run
+- To train a model, run
```bash
python main.py --data_dir ${PWD}/cifar
@@ -34,8 +37,63 @@ python main.py --data_dir ${PWD}/cifar
- `train_dir`: Directory to store eventfiles and checkpoints.
- `restore`: Restore the latest checkpoint.
- `validate`: Use validation set for training monitoring.
- - `manual_grad`: Use the manually defined gradient map given by the authors.
- - `dataset`: Use either `cifar-10` or `cifar-100`
+ - `dataset`: Use either `cifar-10` or `cifar-100`.
+ - `config`: RevNet configuration.
+ - `use_defun`: Use `tfe.defun` to boost performance.
+
+- To train a model with estimators in graph-mode, run
+
+```bash
+python main_estimator.py --data_dir ${PWD}/cifar
+```
+
+- Optional arguments for `main.py` include
+ - `model_dir`: Directory to store eventfiles and checkpoints.
+ - `dataset`: Use either `cifar-10` or `cifar-100`.
+ - `config`: RevNet configuration.
+ - `export`: Export the model for serving if True.
+
+## Speed up with `tfe.defun`
+Even though the speed difference between pure eager execution and graph-mode execution is noticeable,
+the difference between fully "defunned" model training and graph-mode
+training is negligible.
+
+## Train on ImageNet with Cloud TPUs
+The standard way to train models on Cloud TPUs is via TPU estimators and graph-mode
+execution. Models built with the `tf.keras` API are fully compatible with TPU estimators.
+
+### Setup a Google Cloud project
+
+Follow the instructions at the [Quickstart Guide](https://cloud.google.com/tpu/docs/quickstart)
+to get a GCE VM with access to Cloud TPU.
+
+To run this model, you will need:
+
+* A GCE VM instance with an associated Cloud TPU resource
+* A GCS bucket to store your training checkpoints
+* (Optional): The ImageNet training and validation data preprocessed into
+ TFRecord format, and stored in GCS.
+
+### Format the data
+
+The data is expected to be formatted in TFRecord format, as generated by [this
+script](https://github.com/tensorflow/tpu/blob/master/tools/datasets/imagenet_to_gcs.py).
+
+If you do not have ImageNet dataset prepared, you can use a randomly generated
+fake dataset to test the model. It is located at
+`gs://cloud-tpu-test-datasets/fake_imagenet`.
+
+### Start training
+
+Train the model by executing the following command (substituting the appropriate
+values):
+
+```bash
+python main_estimator_tpu.py \
+ --tpu=$TPU_NAME \
+ --data_dir=$DATA_DIR \
+ --model_dir=$MODEL_DIR
+```
## Performance
- With the current implementation, RevNet-38 achieves >92% on CIFAR-10 and >71% on CIFAR-100.
diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks.py b/tensorflow/contrib/eager/python/examples/revnet/blocks.py
index 8a530b0d71..f61354bc38 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/blocks.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/blocks.py
@@ -91,32 +91,21 @@ class RevBlock(tf.keras.Model):
h = block(h, training=training)
return h
- def backward_grads_and_vars(self, x, y, dy, training=True):
+ def backward_grads(self, x, y, dy, training=True):
"""Apply reversible block backward to outputs."""
grads_all = []
- vars_all = []
-
for i in reversed(range(len(self.blocks))):
block = self.blocks[i]
if i == 0:
# First block usually contains downsampling that can't be reversed
- with tf.GradientTape() as tape:
- tape.watch(x)
- y = block(x, training=training)
-
- grads_combined = tape.gradient(
- y, [x] + block.trainable_variables, output_gradients=dy)
- dy = grads_combined[0]
- grads_all += grads_combined[1:]
- vars_all += block.trainable_variables
+ dy, grads = block.backward_grads_with_downsample(
+ x, y, dy, training=True)
else:
- y, dy, grads, vars_ = block.backward_grads_and_vars(
- y, dy, training=training)
- grads_all += grads
- vars_all += vars_
+ y, dy, grads = block.backward_grads(y, dy, training=training)
+ grads_all = grads + grads_all
- return dy, grads_all, vars_all
+ return dy, grads_all
class _Residual(tf.keras.Model):
@@ -178,10 +167,9 @@ class _Residual(tf.keras.Model):
fused=fused,
dtype=dtype)
- def call(self, x, training=True, concat=True):
+ def call(self, x, training=True):
"""Apply residual block to inputs."""
-
- x1, x2 = tf.split(x, num_or_size_splits=2, axis=self.axis)
+ x1, x2 = x
f_x2 = self.f(x2, training=training)
x1_down = ops.downsample(
x1, self.filters // 2, self.strides, axis=self.axis)
@@ -190,42 +178,81 @@ class _Residual(tf.keras.Model):
y1 = f_x2 + x1_down
g_y1 = self.g(y1, training=training)
y2 = g_y1 + x2_down
- if not concat: # For correct backward grads
- return y1, y2
- return tf.concat([y1, y2], axis=self.axis)
+ return y1, y2
- def backward_grads_and_vars(self, y, dy, training=True):
+ def backward_grads(self, y, dy, training=True):
"""Manually compute backward gradients given input and output grads."""
- dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=self.axis)
+ dy1, dy2 = dy
+ y1, y2 = y
- with tf.GradientTape(persistent=True) as tape:
- tape.watch(y)
- y1, y2 = tf.split(y, num_or_size_splits=2, axis=self.axis)
+ with tf.GradientTape() as gtape:
+ gtape.watch(y1)
gy1 = self.g(y1, training=training)
+ grads_combined = gtape.gradient(
+ gy1, [y1] + self.g.trainable_variables, output_gradients=dy2)
+ dg = grads_combined[1:]
+ dx1 = dy1 + grads_combined[0]
+ # This doesn't affect eager execution, but improves memory efficiency with
+ # graphs
+ with tf.control_dependencies(dg + [dx1]):
x2 = y2 - gy1
+
+ with tf.GradientTape() as ftape:
+ ftape.watch(x2)
fx2 = self.f(x2, training=training)
+ grads_combined = ftape.gradient(
+ fx2, [x2] + self.f.trainable_variables, output_gradients=dx1)
+ df = grads_combined[1:]
+ dx2 = dy2 + grads_combined[0]
+ # Same behavior as above
+ with tf.control_dependencies(df + [dx2]):
x1 = y1 - fx2
- grads_combined = tape.gradient(
+ x = x1, x2
+ dx = dx1, dx2
+ grads = df + dg
+
+ return x, dx, grads
+
+ def backward_grads_with_downsample(self, x, y, dy, training=True):
+ """Manually compute backward gradients given input and output grads."""
+ # Splitting this from `backward_grads` for better readability
+ x1, x2 = x
+ y1, _ = y
+ dy1, dy2 = dy
+
+ with tf.GradientTape() as gtape:
+ gtape.watch(y1)
+ gy1 = self.g(y1, training=training)
+ grads_combined = gtape.gradient(
gy1, [y1] + self.g.trainable_variables, output_gradients=dy2)
dg = grads_combined[1:]
- dx1 = dy1 + grads_combined[0]
+ dz1 = dy1 + grads_combined[0]
- grads_combined = tape.gradient(
- fx2, [x2] + self.f.trainable_variables, output_gradients=dx1)
- dx2 = dy2 + grads_combined[0]
- df = grads_combined[1:]
+ # dx1 need one more step to backprop through downsample
+ with tf.GradientTape() as x1tape:
+ x1tape.watch(x1)
+ z1 = ops.downsample(x1, self.filters // 2, self.strides, axis=self.axis)
+ dx1 = x1tape.gradient(z1, x1, output_gradients=dz1)
- del tape
+ with tf.GradientTape() as ftape:
+ ftape.watch(x2)
+ fx2 = self.f(x2, training=training)
+ grads_combined = ftape.gradient(
+ fx2, [x2] + self.f.trainable_variables, output_gradients=dz1)
+ dx2, df = grads_combined[0], grads_combined[1:]
- grads = df + dg
- vars_ = self.f.trainable_variables + self.g.trainable_variables
+ # dx2 need one more step to backprop through downsample
+ with tf.GradientTape() as x2tape:
+ x2tape.watch(x2)
+ z2 = ops.downsample(x2, self.filters // 2, self.strides, axis=self.axis)
+ dx2 += x2tape.gradient(z2, x2, output_gradients=dy2)
- x = tf.concat([x1, x2], axis=self.axis)
- dx = tf.concat([dx1, dx2], axis=self.axis)
+ dx = dx1, dx2
+ grads = df + dg
- return x, dx, grads, vars_
+ return dx, grads
# Ideally, the following should be wrapped in `tf.keras.Sequential`, however
@@ -422,7 +449,7 @@ class InitBlock(tf.keras.Model):
if self.config.init_max_pool:
net = self.max_pool(net)
- return net
+ return tf.split(net, num_or_size_splits=2, axis=self.axis)
class FinalBlock(tf.keras.Model):
@@ -468,7 +495,7 @@ class FinalBlock(tf.keras.Model):
self.config.n_classes, dtype=self.config.dtype)
def call(self, x, training=True):
- net = x
+ net = tf.concat(x, axis=self.axis)
net = self.batch_norm(net, training=training)
net = self.activation(net)
net = self.global_avg_pool(net)
diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py
index d74785c8fe..fda9020ddf 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py
@@ -116,70 +116,13 @@ def _validate_block_call_channels_first(block_factory, test):
class RevBlockTest(tf.test.TestCase):
- def test_call_channels_first(self):
- """Test `call` function with `channels_first` data format."""
- if not tf.test.is_gpu_available():
- self.skipTest("GPU not available")
-
- with tf.device("/gpu:0"): # Default NCHW format
- input_shape = (128, 8, 8)
- data_shape = (16,) + input_shape
- x = tf.random_normal(shape=data_shape)
-
- # Stride of 1
- block = blocks.RevBlock(
- n_res=3, filters=128, strides=(1, 1), input_shape=input_shape)
- y_tr, y_ev = block(x, training=True), block(x, training=False)
- self.assertEqual(y_tr.shape, y_ev.shape)
- self.assertEqual(y_ev.shape, (16, 128, 8, 8))
- self.assertNotAllClose(y_tr, y_ev)
-
- # Stride of 2
- block = blocks.RevBlock(
- n_res=3, filters=128, strides=(2, 2), input_shape=input_shape)
- y_tr, y_ev = block(x, training=True), block(x, training=False)
- self.assertEqual(y_tr.shape, y_ev.shape)
- self.assertEqual(y_ev.shape, [16, 128, 4, 4])
- self.assertNotAllClose(y_tr, y_ev)
-
- def test_call_channels_last(self):
- """Test `call` function with `channels_last` data format."""
- with tf.device("/cpu:0"): # NHWC format
- input_shape = (8, 8, 128)
- data_shape = (16,) + input_shape
- x = tf.random_normal(shape=data_shape)
-
- # Stride 1
- block = blocks.RevBlock(
- n_res=3,
- filters=128,
- strides=(1, 1),
- input_shape=input_shape,
- data_format="channels_last")
- y_tr, y_ev = block(x, training=True), block(x, training=False)
- self.assertEqual(y_tr.shape, y_ev.shape)
- self.assertEqual(y_ev.shape, (16, 8, 8, 128))
- self.assertNotAllClose(y_tr, y_ev)
-
- # Stride of 2
- block = blocks.RevBlock(
- n_res=3,
- filters=128,
- strides=(2, 2),
- input_shape=input_shape,
- data_format="channels_last")
- y_tr, y_ev = block(x, training=True), block(x, training=False)
- self.assertEqual(y_tr.shape, y_ev.shape)
- self.assertEqual(y_ev.shape, (16, 4, 4, 128))
- self.assertNotAllClose(y_tr, y_ev)
-
def _check_grad_angle(self, grads, grads_true, atol=1e0):
"""Check the angle between two list of vectors are all close."""
for g1, g2 in zip(grads, grads_true):
degree = compute_degree(g1, g2)
self.assertLessEqual(degree, atol)
- def test_backward_grads_and_vars_channels_first(self):
+ def test_backward_grads_channels_first(self):
"""Test `backward` function with `channels_first` data format."""
if not tf.test.is_gpu_available():
self.skipTest("GPU not available")
@@ -190,6 +133,7 @@ class RevBlockTest(tf.test.TestCase):
data_shape = (16,) + input_shape
x = tf.random_normal(shape=data_shape, dtype=tf.float64)
dy = tf.random_normal(shape=data_shape, dtype=tf.float64)
+ dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=1)
block = blocks.RevBlock(
n_res=3,
filters=128,
@@ -199,9 +143,14 @@ class RevBlockTest(tf.test.TestCase):
dtype=tf.float64)
with tf.GradientTape() as tape:
tape.watch(x)
- y = block(x, training=True)
+ x1, x2 = tf.split(x, num_or_size_splits=2, axis=1)
+ y1, y2 = block((x1, x2), training=True)
+ y = tf.concat((y1, y2), axis=1)
# Compute grads from reconstruction
- dx, dw, vars_ = block.backward_grads_and_vars(x, y, dy, training=True)
+ (dx1, dx2), dw = block.backward_grads(
+ x=(x1, x2), y=(y1, y2), dy=(dy1, dy2), training=True)
+ dx = tf.concat((dx1, dx2), axis=1)
+ vars_ = block.trainable_variables
# Compute true grads
grads = tape.gradient(y, [x] + vars_, output_gradients=dy)
dx_true, dw_true = grads[0], grads[1:]
@@ -213,6 +162,7 @@ class RevBlockTest(tf.test.TestCase):
# Stride 2
x = tf.random_normal(shape=data_shape, dtype=tf.float64)
dy = tf.random_normal(shape=(16, 128, 4, 4), dtype=tf.float64)
+ dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=1)
block = blocks.RevBlock(
n_res=3,
filters=128,
@@ -222,9 +172,14 @@ class RevBlockTest(tf.test.TestCase):
dtype=tf.float64)
with tf.GradientTape() as tape:
tape.watch(x)
- y = block(x, training=True)
+ x1, x2 = tf.split(x, num_or_size_splits=2, axis=1)
+ y1, y2 = block((x1, x2), training=True)
+ y = tf.concat((y1, y2), axis=1)
# Compute grads from reconstruction
- dx, dw, vars_ = block.backward_grads_and_vars(x, y, dy, training=True)
+ (dx1, dx2), dw = block.backward_grads(
+ x=(x1, x2), y=(y1, y2), dy=(dy1, dy2), training=True)
+ dx = tf.concat((dx1, dx2), axis=1)
+ vars_ = block.trainable_variables
# Compute true grads
grads = tape.gradient(y, [x] + vars_, output_gradients=dy)
dx_true, dw_true = grads[0], grads[1:]
@@ -236,16 +191,7 @@ class RevBlockTest(tf.test.TestCase):
class _ResidualTest(tf.test.TestCase):
- def test_call(self):
- """Test `call` function.
-
- Varying downsampling and data format options.
- """
-
- _validate_block_call_channels_first(blocks._Residual, self)
- _validate_block_call_channels_last(blocks._Residual, self)
-
- def test_backward_grads_and_vars_channels_first(self):
+ def test_backward_grads_channels_first(self):
"""Test `backward_grads` function with `channels_first` data format."""
if not tf.test.is_gpu_available():
self.skipTest("GPU not available")
@@ -256,6 +202,7 @@ class _ResidualTest(tf.test.TestCase):
# Use double precision for testing
x_true = tf.random_normal(shape=data_shape, dtype=tf.float64)
dy = tf.random_normal(shape=data_shape, dtype=tf.float64)
+ dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=1)
residual = blocks._Residual(
filters=128,
strides=(1, 1),
@@ -264,16 +211,19 @@ class _ResidualTest(tf.test.TestCase):
dtype=tf.float64)
with tf.GradientTape() as tape:
- x_true = tf.identity(x_true)
tape.watch(x_true)
- y = residual(x_true, training=True)
+ x1_true, x2_true = tf.split(x_true, num_or_size_splits=2, axis=1)
+ y1, y2 = residual((x1_true, x2_true), training=True)
+ y = tf.concat((y1, y2), axis=1)
# Gradients computed due to reversibility
- x, dx, dw, vars_ = residual.backward_grads_and_vars(
- y, dy=dy, training=True)
-
+ (x1, x2), (dx1, dx2), dw = residual.backward_grads(
+ y=(y1, y2), dy=(dy1, dy2), training=True)
+ x = tf.concat((x1, x2), axis=1)
+ dx = tf.concat((dx1, dx2), axis=1)
# True gradients computed by the tape
- grads = tape.gradient(y, [x_true] + vars_, output_gradients=dy)
+ grads = tape.gradient(
+ y, [x_true] + residual.trainable_variables, output_gradients=dy)
dx_true, dw_true = grads[0], grads[1:]
self.assertAllClose(x_true, x)
diff --git a/tensorflow/contrib/eager/python/examples/revnet/config.py b/tensorflow/contrib/eager/python/examples/revnet/config.py
index 821a4878c1..29f1db0e03 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/config.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/config.py
@@ -82,7 +82,8 @@ def get_hparams_cifar_38():
config.num_train_images // config.tpu_batch_size)
config.add_hparam("tpu_epochs",
config.max_train_iter // config.tpu_iters_per_epoch)
-
+ config.add_hparam("tpu_eval_steps",
+ config.num_eval_images // config.tpu_eval_batch_size)
return config
@@ -162,7 +163,8 @@ def get_hparams_imagenet_56():
config.num_train_images // config.tpu_batch_size)
config.add_hparam("tpu_epochs",
config.max_train_iter // config.tpu_iters_per_epoch)
-
+ config.add_hparam("tpu_eval_steps",
+ config.num_eval_images // config.tpu_eval_batch_size)
return config
diff --git a/tensorflow/contrib/eager/python/examples/revnet/main.py b/tensorflow/contrib/eager/python/examples/revnet/main.py
index dcd4e1697f..b702e91f92 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/main.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/main.py
@@ -29,6 +29,11 @@ from tensorflow.contrib.eager.python.examples.revnet import revnet
tfe = tf.contrib.eager
+def apply_gradients(optimizer, grads, vars_, global_step=None):
+ """Functional style apply_grads for `tfe.defun`."""
+ optimizer.apply_gradients(zip(grads, vars_), global_step=global_step)
+
+
def main(_):
"""Eager execution workflow with RevNet trained on CIFAR-10."""
tf.enable_eager_execution()
@@ -48,6 +53,11 @@ def main(_):
if FLAGS.use_defun:
model.call = tfe.defun(model.call)
+ model.compute_gradients = tfe.defun(model.compute_gradients)
+ model.get_moving_stats = tfe.defun(model.get_moving_stats)
+ model.restore_moving_stats = tfe.defun(model.restore_moving_stats)
+ global apply_gradients # pylint:disable=global-variable-undefined
+ apply_gradients = tfe.defun(apply_gradients)
if FLAGS.train_dir:
summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir)
@@ -197,9 +207,13 @@ def get_datasets(data_dir, config):
def train_one_iter(model, inputs, labels, optimizer, global_step=None):
"""Train for one iteration."""
- grads, vars_, logits, loss = model.compute_gradients(
- inputs, labels, training=True)
- optimizer.apply_gradients(zip(grads, vars_), global_step=global_step)
+ logits, saved_hiddens = model(inputs, training=True)
+ values = model.get_moving_stats()
+ grads, loss = model.compute_gradients(saved_hiddens, labels)
+ # Restore moving averages when executing eagerly to avoid updating twice
+ model.restore_moving_stats(values)
+ apply_gradients(
+ optimizer, grads, model.trainable_variables, global_step=global_step)
return logits, loss
diff --git a/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py b/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py
index 4868f1931f..3a17eb30da 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py
@@ -53,10 +53,11 @@ def model_fn(features, labels, mode, params):
global_step, config.lr_decay_steps, config.lr_list)
optimizer = tf.train.MomentumOptimizer(
learning_rate, momentum=config.momentum)
- grads, vars_, logits, loss = model.compute_gradients(
- inputs, labels, training=True)
- train_op = optimizer.apply_gradients(
- zip(grads, vars_), global_step=global_step)
+ logits, saved_hidden = model(inputs, training=True)
+ grads, loss = model.compute_gradients(saved_hidden, labels, training=True)
+ with tf.control_dependencies(model.get_updates_for(inputs)):
+ train_op = optimizer.apply_gradients(
+ zip(grads, model.trainable_variables), global_step=global_step)
return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
else:
@@ -130,8 +131,7 @@ def get_input_fn(config, data_dir, split):
return input_fn
-def main(argv):
- FLAGS = argv[0] # pylint:disable=invalid-name,redefined-outer-name
+def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
# RevNet specific configuration
@@ -139,7 +139,7 @@ def main(argv):
# Estimator specific configuration
run_config = tf.estimator.RunConfig(
- model_dir=FLAGS.train_dir, # Directory for storing checkpoints
+ model_dir=FLAGS.model_dir, # Directory for storing checkpoints
tf_random_seed=config.seed,
save_summary_steps=config.log_every,
save_checkpoints_steps=config.log_every,
@@ -153,7 +153,7 @@ def main(argv):
# Construct estimator
revnet_estimator = tf.estimator.Estimator(
model_fn=model_fn,
- model_dir=FLAGS.train_dir,
+ model_dir=FLAGS.model_dir,
config=run_config,
params={"config": config})
@@ -173,14 +173,14 @@ def main(argv):
input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
"image": inputs
})
- revnet_estimator.export_savedmodel(FLAGS.train_dir, input_fn)
+ revnet_estimator.export_savedmodel(FLAGS.model_dir, input_fn)
if __name__ == "__main__":
flags.DEFINE_string(
"data_dir", default=None, help="Directory to load tfrecords")
flags.DEFINE_string(
- "train_dir",
+ "model_dir",
default=None,
help="[Optional] Directory to store the training information")
flags.DEFINE_string(
@@ -197,4 +197,4 @@ if __name__ == "__main__":
help="[Optional] Architecture of network. "
"Other options include `revnet-110` and `revnet-164`")
FLAGS = flags.FLAGS
- tf.app.run(main=main, argv=[FLAGS])
+ tf.app.run()
diff --git a/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py b/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py
index d809bcd287..f0aad9b110 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py
@@ -47,7 +47,6 @@ def model_fn(features, labels, mode, params):
if isinstance(inputs, dict):
inputs = features["image"]
- FLAGS = params["FLAGS"] # pylint:disable=invalid-name,redefined-outer-name
config = params["config"]
model = revnet.RevNet(config=config)
@@ -61,14 +60,10 @@ def model_fn(features, labels, mode, params):
if FLAGS.use_tpu:
optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
- # Define gradients
- grads, vars_, logits, loss = model.compute_gradients(
- inputs, labels, training=True)
+ logits, saved_hidden = model(inputs, training=True)
+ grads, loss = model.compute_gradients(saved_hidden, labels, training=True)
train_op = optimizer.apply_gradients(
- zip(grads, vars_), global_step=global_step)
-
- names = [v.name for v in model.variables]
- tf.logging.warn("{}".format(names))
+ zip(grads, model.trainable_variables), global_step=global_step)
return tf.contrib.tpu.TPUEstimatorSpec(
mode=tf.estimator.ModeKeys.TRAIN, loss=loss, train_op=train_op)
@@ -141,8 +136,7 @@ def get_input_fn(config, data_dir, split):
return input_fn
-def main(argv):
- FLAGS = argv[0] # pylint:disable=invalid-name,redefined-outer-name
+def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
# RevNet specific configuration
@@ -177,10 +171,7 @@ def main(argv):
train_batch_size=config.tpu_batch_size,
eval_batch_size=config.tpu_eval_batch_size,
config=run_config,
- params={
- "FLAGS": FLAGS,
- "config": config,
- })
+ params={"config": config})
# Construct input functions
train_input_fn = get_input_fn(
@@ -325,4 +316,4 @@ if __name__ == "__main__":
" possible (i.e. up to --train_steps, which evaluates the model only"
" after finishing the entire training regime)."))
FLAGS = flags.FLAGS
- tf.app.run(main=main, argv=[FLAGS])
+ tf.app.run()
diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet.py b/tensorflow/contrib/eager/python/examples/revnet/revnet.py
index b1cb312b74..1f2cb14972 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/revnet.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/revnet.py
@@ -24,7 +24,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import six
import tensorflow as tf
from tensorflow.contrib.eager.python.examples.revnet import blocks
@@ -45,6 +44,7 @@ class RevNet(tf.keras.Model):
self._init_block = blocks.InitBlock(config=self.config)
self._final_block = blocks.FinalBlock(config=self.config)
self._block_list = self._construct_intermediate_blocks()
+ self._moving_average_variables = []
def _construct_intermediate_blocks(self):
# Precompute input shape after initial block
@@ -128,126 +128,90 @@ class RevNet(tf.keras.Model):
return tf.reduce_mean(cross_ent)
- def compute_gradients(self, inputs, labels, training=True, l2_reg=True):
+ def compute_gradients(self, saved_hidden, labels, training=True, l2_reg=True):
"""Manually computes gradients.
- When eager execution is enabled, this method also SILENTLY updates the
- running averages of batch normalization when `training` is set to True.
+ This method silently updates the running averages of batch normalization.
Args:
- inputs: Image tensor, either NHWC or NCHW, conforming to `data_format`
+ saved_hidden: List of hidden states Tensors
labels: One-hot labels for classification
training: Use the mini-batch stats in batch norm if set to True
l2_reg: Apply l2 regularization
Returns:
- A tuple with the first entry being a list of all gradients, the second
- entry being a list of respective variables, the third being the logits,
- and the forth being the loss
+ A tuple with the first entry being a list of all gradients and the second
+ being the loss
"""
- # Run forward pass to record hidden states
- vars_and_vals = self.get_moving_stats()
- _, saved_hidden = self(inputs, training=training) # pylint:disable=not-callable
- if tf.executing_eagerly():
- # Restore moving averages when executing eagerly to avoid updating twice
- self.restore_moving_stats(vars_and_vals)
- else:
- # Fetch batch norm updates in graph mode
- updates = self.get_updates_for(inputs)
-
- grads_all = []
- vars_all = []
+ def _defunable_pop(l):
+ """Functional style list pop that works with `tfe.defun`."""
+ t, l = l[-1], l[:-1]
+ return t, l
- # Manually backprop through last block
+ # Backprop through last block
x = saved_hidden[-1]
with tf.GradientTape() as tape:
tape.watch(x)
- # Running stats updated here
logits = self._final_block(x, training=training)
loss = self.compute_loss(logits, labels)
-
grads_combined = tape.gradient(loss,
[x] + self._final_block.trainable_variables)
- dy, grads_ = grads_combined[0], grads_combined[1:]
- grads_all += grads_
- vars_all += self._final_block.trainable_variables
+ dy, final_grads = grads_combined[0], grads_combined[1:]
- # Manually backprop through intermediate blocks
+ # Backprop through intermediate blocks
+ intermediate_grads = []
for block in reversed(self._block_list):
- y = saved_hidden.pop()
+ y, saved_hidden = _defunable_pop(saved_hidden)
x = saved_hidden[-1]
- # Running stats updated here
- dy, grads, vars_ = block.backward_grads_and_vars(
- x, y, dy, training=training)
- grads_all += grads
- vars_all += vars_
-
- # Manually backprop through first block
- saved_hidden.pop()
- x = saved_hidden.pop()
- assert not saved_hidden # Cleared after backprop
+ dy, grads = block.backward_grads(x, y, dy, training=training)
+ intermediate_grads = grads + intermediate_grads
+ # Backprop through first block
+ _, saved_hidden = _defunable_pop(saved_hidden)
+ x, saved_hidden = _defunable_pop(saved_hidden)
+ assert not saved_hidden
with tf.GradientTape() as tape:
- # Running stats updated here
y = self._init_block(x, training=training)
-
- grads_all += tape.gradient(
+ init_grads = tape.gradient(
y, self._init_block.trainable_variables, output_gradients=dy)
- vars_all += self._init_block.trainable_variables
- # Apply weight decay
+ # Ordering match up with `model.trainable_variables`
+ grads_all = init_grads + final_grads + intermediate_grads
if l2_reg:
- grads_all = self._apply_weight_decay(grads_all, vars_all)
-
- if not tf.executing_eagerly():
- # Force updates to be executed before gradient computation in graph mode
- # This does nothing when the function is wrapped in defun
- with tf.control_dependencies(updates):
- grads_all[0] = tf.identity(grads_all[0])
+ grads_all = self._apply_weight_decay(grads_all)
- return grads_all, vars_all, logits, loss
+ return grads_all, loss
- def _apply_weight_decay(self, grads, vars_):
+ def _apply_weight_decay(self, grads):
"""Update gradients to reflect weight decay."""
- # Don't decay bias
return [
g + self.config.weight_decay * v if v.name.endswith("kernel:0") else g
- for g, v in zip(grads, vars_)
+ for g, v in zip(grads, self.trainable_variables)
]
def get_moving_stats(self):
- """Get moving averages of batch normalization.
-
- This is needed to avoid updating the running average twice in one iteration.
-
- Returns:
- A dictionary mapping variables for batch normalization moving averages
- to their current values.
- """
- vars_and_vals = {}
-
- def _is_moving_var(v):
- n = v.name
- return n.endswith("moving_mean:0") or n.endswith("moving_variance:0")
+ """Get moving averages of batch normalization."""
+ device = "/gpu:0" if tf.test.is_gpu_available() else "/cpu:0"
+ with tf.device(device):
+ return [v.read_value() for v in self.moving_average_variables]
+ def restore_moving_stats(self, values):
+ """Restore moving averages of batch normalization."""
device = "/gpu:0" if tf.test.is_gpu_available() else "/cpu:0"
with tf.device(device):
- for v in filter(_is_moving_var, self.variables):
- vars_and_vals[v] = v.read_value()
+ for var_, val in zip(self.moving_average_variables, values):
+ var_.assign(val)
- return vars_and_vals
+ @property
+ def moving_average_variables(self):
+ """Get all variables that are batch norm moving averages."""
- def restore_moving_stats(self, vars_and_vals):
- """Restore moving averages of batch normalization.
+ def _is_moving_avg(v):
+ n = v.name
+ return n.endswith("moving_mean:0") or n.endswith("moving_variance:0")
- This is needed to avoid updating the running average twice in one iteration.
+ if not self._moving_average_variables:
+ self._moving_average_variables = filter(_is_moving_avg, self.variables)
- Args:
- vars_and_vals: The dictionary mapping variables to their previous values.
- """
- device = "/gpu:0" if tf.test.is_gpu_available() else "/cpu:0"
- with tf.device(device):
- for var_, val in six.iteritems(vars_and_vals):
- # `assign` causes a copy to GPU (if variable is already on GPU)
- var_.assign(val)
+ return self._moving_average_variables
diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
index 26b0847523..84b2ddf0de 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
@@ -31,9 +31,11 @@ tfe = tf.contrib.eager
def train_one_iter(model, inputs, labels, optimizer, global_step=None):
"""Train for one iteration."""
- grads, vars_, logits, loss = model.compute_gradients(
- inputs, labels, training=True)
- optimizer.apply_gradients(zip(grads, vars_), global_step=global_step)
+ logits, saved_hidden = model(inputs)
+ grads, loss = model.compute_gradients(
+ saved_hidden=saved_hidden, labels=labels)
+ optimizer.apply_gradients(
+ zip(grads, model.trainable_variables), global_step=global_step)
return logits, loss
@@ -96,9 +98,10 @@ class RevNetTest(tf.test.TestCase):
def test_compute_gradients(self):
"""Test `compute_gradients` function."""
- self.model(self.x, training=False) # Initialize model
- grads, vars_, logits, loss = self.model.compute_gradients(
- inputs=self.x, labels=self.t, training=True, l2_reg=True)
+ _, saved_hidden = self.model(self.x) # Initialize model
+ grads, loss = self.model.compute_gradients(
+ saved_hidden=saved_hidden, labels=self.t)
+ vars_ = self.model.trainable_variables
self.assertTrue(isinstance(grads, list))
self.assertTrue(isinstance(vars_, list))
self.assertEqual(len(grads), len(vars_))
@@ -107,7 +110,7 @@ class RevNetTest(tf.test.TestCase):
# Compare against the true gradient computed by the tape
with tf.GradientTape() as tape:
- logits, _ = self.model(self.x, training=True)
+ logits, _ = self.model(self.x)
loss_true = self.model.compute_loss(logits=logits, labels=self.t)
grads_true = tape.gradient(loss_true, vars_)
self.assertAllClose(loss, loss_true)
@@ -122,7 +125,9 @@ class RevNetTest(tf.test.TestCase):
def test_compute_gradients_defun(self):
"""Test `compute_gradients` function with defun."""
compute_gradients = tfe.defun(self.model.compute_gradients)
- grads, vars_, _, _ = compute_gradients(self.x, self.t, training=True)
+ _, saved_hidden = self.model(self.x)
+ grads, _ = compute_gradients(saved_hidden=saved_hidden, labels=self.t)
+ vars_ = self.model.trainable_variables
self.assertTrue(isinstance(grads, list))
self.assertTrue(isinstance(vars_, list))
self.assertEqual(len(grads), len(vars_))
@@ -146,10 +151,11 @@ class RevNetTest(tf.test.TestCase):
dtype=tf.int32)
global_step = tf.Variable(0., trainable=False)
model = revnet.RevNet(config=config)
- grads_all, vars_all, _, _ = model.compute_gradients(x, t, training=True)
+ _, saved_hidden = model(x)
+ grads, _ = model.compute_gradients(saved_hidden=saved_hidden, labels=t)
optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
train_op = optimizer.apply_gradients(
- zip(grads_all, vars_all), global_step=global_step)
+ zip(grads, model.trainable_variables), global_step=global_step)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py
index ca6430253b..2f0ab616e4 100644
--- a/tensorflow/contrib/eager/python/tfe.py
+++ b/tensorflow/contrib/eager/python/tfe.py
@@ -34,6 +34,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`.
@@run
@@enable_eager_execution
+@@enable_remote_eager_execution
@@custom_gradient
@@ -114,6 +115,7 @@ from tensorflow.python.eager.execution_callbacks import inf_nan_callback
from tensorflow.python.eager.execution_callbacks import nan_callback
from tensorflow.python.eager.execution_callbacks import seterr
from tensorflow.python.framework.ops import enable_eager_execution
+from tensorflow.python.framework.ops import enable_eager_execution_internal as enable_remote_eager_execution
from tensorflow.python.framework.ops import eager_run as run
from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes
from tensorflow.python.framework.test_util import run_all_in_graph_and_eager_modes as run_all_tests_in_graph_and_eager_modes
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
index 43bfcffd79..7ed77bcce6 100644
--- a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
+++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
@@ -50,7 +50,8 @@ class _BoostedTreesEstimator(estimator.Estimator):
tree_complexity=0.,
min_node_weight=0.,
config=None,
- center_bias=False):
+ center_bias=False,
+ pruning_mode='none'):
"""Initializes a `BoostedTreesEstimator` instance.
Args:
@@ -89,13 +90,18 @@ class _BoostedTreesEstimator(estimator.Estimator):
regression problems, the first node will return the mean of the labels.
For binary classification problems, it will return a logit for a prior
probability of label 1.
+ pruning_mode: one of 'none', 'pre', 'post' to indicate no pruning, pre-
+ pruning (do not split a node if not enough gain is observed) and post
+ pruning (build the tree up to a max depth and then prune branches with
+ negative gain). For pre and post pruning, you MUST provide
+ tree_complexity >0.
"""
# pylint:disable=protected-access
# HParams for the model.
tree_hparams = canned_boosted_trees._TreeHParams(
n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
- tree_complexity, min_node_weight, center_bias)
+ tree_complexity, min_node_weight, center_bias, pruning_mode)
def _model_fn(features, labels, mode, config):
return canned_boosted_trees._bt_model_fn(
@@ -129,7 +135,8 @@ def boosted_trees_classifier_train_in_memory(
min_node_weight=0.,
config=None,
train_hooks=None,
- center_bias=False):
+ center_bias=False,
+ pruning_mode='none'):
"""Trains a boosted tree classifier with in memory dataset.
Example:
@@ -208,6 +215,11 @@ def boosted_trees_classifier_train_in_memory(
regression problems, the first node will return the mean of the labels.
For binary classification problems, it will return a logit for a prior
probability of label 1.
+ pruning_mode: one of 'none', 'pre', 'post' to indicate no pruning, pre-
+ pruning (do not split a node if not enough gain is observed) and post
+ pruning (build the tree up to a max depth and then prune branches with
+ negative gain). For pre and post pruning, you MUST provide
+ tree_complexity >0.
Returns:
a `BoostedTreesClassifier` instance created with the given arguments and
@@ -228,7 +240,7 @@ def boosted_trees_classifier_train_in_memory(
# HParams for the model.
tree_hparams = canned_boosted_trees._TreeHParams(
n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
- tree_complexity, min_node_weight, center_bias)
+ tree_complexity, min_node_weight, center_bias, pruning_mode)
def _model_fn(features, labels, mode, config):
return canned_boosted_trees._bt_model_fn(
@@ -269,7 +281,8 @@ def boosted_trees_regressor_train_in_memory(
min_node_weight=0.,
config=None,
train_hooks=None,
- center_bias=False):
+ center_bias=False,
+ pruning_mode='none'):
"""Trains a boosted tree regressor with in memory dataset.
Example:
@@ -341,6 +354,11 @@ def boosted_trees_regressor_train_in_memory(
regression problems, the first node will return the mean of the labels.
For binary classification problems, it will return a logit for a prior
probability of label 1.
+ pruning_mode: one of 'none', 'pre', 'post' to indicate no pruning, pre-
+ pruning (do not split a node if not enough gain is observed) and post
+ pruning (build the tree up to a max depth and then prune branches with
+ negative gain). For pre and post pruning, you MUST provide
+ tree_complexity >0.
Returns:
a `BoostedTreesClassifier` instance created with the given arguments and
@@ -360,7 +378,7 @@ def boosted_trees_regressor_train_in_memory(
# HParams for the model.
tree_hparams = canned_boosted_trees._TreeHParams(
n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
- tree_complexity, min_node_weight, center_bias)
+ tree_complexity, min_node_weight, center_bias, pruning_mode)
def _model_fn(features, labels, mode, config):
return canned_boosted_trees._bt_model_fn(
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
index 999c2aa5e2..b1581f3750 100644
--- a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
@@ -136,6 +136,49 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
eval_res = est.evaluate(input_fn=input_fn, steps=1)
self.assertAllClose(eval_res['average_loss'], 0.614642)
+ def testTrainAndEvaluateEstimatorWithPrePruning(self):
+ input_fn = _make_train_input_fn(is_classification=False)
+
+ est = boosted_trees._BoostedTreesEstimator(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=2,
+ head=self._head,
+ max_depth=5,
+ tree_complexity=0.001,
+ pruning_mode='pre')
+
+ num_steps = 100
+ # Train for a few steps, and validate final checkpoint.
+ est.train(input_fn, steps=num_steps)
+ # We stop actually after 2*depth*n_trees steps (via a hook) because we still
+ # could not grow 2 trees of depth 5 (due to pre-pruning).
+ self._assert_checkpoint(
+ est.model_dir, global_step=21, finalized_trees=0, attempted_layers=21)
+ eval_res = est.evaluate(input_fn=input_fn, steps=1)
+ self.assertAllClose(eval_res['average_loss'], 3.83943)
+
+ def testTrainAndEvaluateEstimatorWithPostPruning(self):
+ input_fn = _make_train_input_fn(is_classification=False)
+
+ est = boosted_trees._BoostedTreesEstimator(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=2,
+ head=self._head,
+ max_depth=5,
+ tree_complexity=0.001,
+ pruning_mode='post')
+
+ # It will stop after 10 steps because of the max depth and num trees.
+ num_steps = 100
+ # Train for a few steps, and validate final checkpoint.
+ est.train(input_fn, steps=num_steps)
+ self._assert_checkpoint(
+ est.model_dir, global_step=10, finalized_trees=2, attempted_layers=10)
+ eval_res = est.evaluate(input_fn=input_fn, steps=1)
+ self.assertAllClose(eval_res['average_loss'], 2.37652)
+
def testInferEstimator(self):
train_input_fn = _make_train_input_fn(is_classification=False)
predict_input_fn = numpy_io.numpy_input_fn(
@@ -231,6 +274,31 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
self.assertAllClose([[0], [1], [1], [0], [0]],
[pred['class_ids'] for pred in predictions])
+ def testBinaryClassifierTrainInMemoryAndEvalAndInferWithPrePruning(self):
+ train_input_fn = _make_train_input_fn(is_classification=True)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+ est = boosted_trees.boosted_trees_classifier_train_in_memory(
+ train_input_fn=train_input_fn,
+ feature_columns=self._feature_columns,
+ n_trees=1,
+ max_depth=5,
+ pruning_mode='pre',
+ tree_complexity=0.01)
+ # We stop actually after 2*depth*n_trees steps (via a hook) because we still
+ # could not grow 1 trees of depth 5 (due to pre-pruning).
+ self._assert_checkpoint(
+ est.model_dir, global_step=11, finalized_trees=0, attempted_layers=11)
+
+ # Check evaluate and predict.
+ eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
+ self.assertAllClose(eval_res['accuracy'], 1.0)
+ # Validate predictions.
+ predictions = list(est.predict(input_fn=predict_input_fn))
+ self.assertAllClose([[0], [1], [1], [0], [0]],
+ [pred['class_ids'] for pred in predictions])
+
def testBinaryClassifierTrainInMemoryWithDataset(self):
train_input_fn = _make_train_input_fn_dataset(is_classification=True)
predict_input_fn = numpy_io.numpy_input_fn(
diff --git a/tensorflow/contrib/estimator/python/estimator/saved_model_estimator.py b/tensorflow/contrib/estimator/python/estimator/saved_model_estimator.py
index f3d0f6b047..b0082f7e55 100644
--- a/tensorflow/contrib/estimator/python/estimator/saved_model_estimator.py
+++ b/tensorflow/contrib/estimator/python/estimator/saved_model_estimator.py
@@ -46,6 +46,7 @@ class SavedModelEstimator(estimator_lib.Estimator):
Example with `tf.estimator.DNNClassifier`:
**Step 1: Create and train DNNClassifier.**
+
```python
feature1 = tf.feature_column.embedding_column(
tf.feature_column.categorical_column_with_vocabulary_list(
@@ -66,13 +67,14 @@ class SavedModelEstimator(estimator_lib.Estimator):
**Step 2: Export classifier.**
First, build functions that specify the expected inputs.
+
```python
# During train and evaluation, both the features and labels should be defined.
supervised_input_receiver_fn = (
tf.contrib.estimator.build_raw_supervised_input_receiver_fn(
- {'feature1': tf.placeholder(dtype=tf.string, shape=[None]),
- 'feature2': tf.placeholder(dtype=tf.float32, shape=[None])},
- tf.placeholder(dtype=tf.float32, shape=[None])))
+ {'feature1': tf.placeholder(dtype=tf.string, shape=[None]),
+ 'feature2': tf.placeholder(dtype=tf.float32, shape=[None])},
+ tf.placeholder(dtype=tf.float32, shape=[None])))
# During predict mode, expect to receive a `tf.Example` proto, so a parsing
# function is used.
@@ -83,6 +85,7 @@ class SavedModelEstimator(estimator_lib.Estimator):
Next, export the model as a SavedModel. A timestamped directory will be
created (for example `/tmp/export_all/1234567890`).
+
```python
# Option 1: Save all modes (train, eval, predict)
export_dir = tf.contrib.estimator.export_all_saved_models(
@@ -93,10 +96,11 @@ class SavedModelEstimator(estimator_lib.Estimator):
# Option 2: Only export predict mode
export_dir = classifier.export_savedmodel(
- '/tmp/export_predict', serving_input_receiver_fn)
+ '/tmp/export_predict', serving_input_receiver_fn)
```
**Step 3: Create a SavedModelEstimator from the exported SavedModel.**
+
```python
est = tf.contrib.estimator.SavedModelEstimator(export_dir)
@@ -108,7 +112,7 @@ class SavedModelEstimator(estimator_lib.Estimator):
est.train(input_fn=input_fn, steps=20)
def predict_input_fn():
- example = example_pb2.Example()
+ example = tf.train.Example()
example.features.feature['feature1'].bytes_list.value.extend(['yellow'])
example.features.feature['feature2'].float_list.value.extend([1.])
return {'inputs':tf.constant([example.SerializeToString()])}
diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD
index 7e6cb72485..053d4e3e97 100644
--- a/tensorflow/contrib/gan/BUILD
+++ b/tensorflow/contrib/gan/BUILD
@@ -196,11 +196,16 @@ py_test(
srcs = ["python/losses/python/tuple_losses_test.py"],
srcs_version = "PY2AND3",
deps = [
+ ":losses_impl",
":namedtuples",
":tuple_losses",
+ "//tensorflow/contrib/layers:layers_py",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//third_party/py/numpy",
],
diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py
index dcc3f94c2d..221c70c38b 100644
--- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py
+++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py
@@ -80,6 +80,9 @@ __all__ = [
'mutual_information_penalty',
'combine_adversarial_loss',
'cycle_consistency_loss',
+ 'stargan_generator_loss_wrapper',
+ 'stargan_discriminator_loss_wrapper',
+ 'stargan_gradient_penalty_wrapper'
]
@@ -277,3 +280,86 @@ def cycle_consistency_loss(cyclegan_model, scope=None, add_summaries=False):
cyclegan_model.model_x2y.generator_inputs, cyclegan_model.reconstructed_x,
cyclegan_model.model_y2x.generator_inputs, cyclegan_model.reconstructed_y,
scope, add_summaries)
+
+
+def stargan_generator_loss_wrapper(loss_fn):
+ """Convert a generator loss function to take a StarGANModel.
+
+ The new function has the same name as the original one.
+
+ Args:
+ loss_fn: A python function taking Discriminator's real/fake prediction for
+ generated data.
+
+ Returns:
+ A new function that takes a StarGANModel namedtuple and returns the same
+ loss.
+ """
+
+ def new_loss_fn(stargan_model, **kwargs):
+ return loss_fn(
+ stargan_model.discriminator_generated_data_source_predication, **kwargs)
+
+ new_docstring = """The stargan_model version of %s.""" % loss_fn.__name__
+ new_loss_fn.__docstring__ = new_docstring
+ new_loss_fn.__name__ = loss_fn.__name__
+ new_loss_fn.__module__ = loss_fn.__module__
+ return new_loss_fn
+
+
+def stargan_discriminator_loss_wrapper(loss_fn):
+ """Convert a discriminator loss function to take a StarGANModel.
+
+ The new function has the same name as the original one.
+
+ Args:
+ loss_fn: A python function taking Discriminator's real/fake prediction for
+ real data and generated data.
+
+ Returns:
+ A new function that takes a StarGANModel namedtuple and returns the same
+ loss.
+ """
+
+ def new_loss_fn(stargan_model, **kwargs):
+ return loss_fn(
+ stargan_model.discriminator_input_data_source_predication,
+ stargan_model.discriminator_generated_data_source_predication, **kwargs)
+
+ new_docstring = """The stargan_model version of %s.""" % loss_fn.__name__
+ new_loss_fn.__docstring__ = new_docstring
+ new_loss_fn.__name__ = loss_fn.__name__
+ new_loss_fn.__module__ = loss_fn.__module__
+ return new_loss_fn
+
+
+def stargan_gradient_penalty_wrapper(loss_fn):
+ """Convert a gradient penalty function to take a StarGANModel.
+
+ The new function has the same name as the original one.
+
+ Args:
+ loss_fn: A python function taking real_data, generated_data,
+ generator_inputs for Discriminator's condition (i.e. number of domains),
+ discriminator_fn, and discriminator_scope.
+
+ Returns:
+ A new function that takes a StarGANModel namedtuple and returns the same
+ loss.
+ """
+
+ def new_loss_fn(stargan_model, **kwargs):
+ num_domains = stargan_model.input_data_domain_label.shape.as_list()[-1]
+ return loss_fn(
+ real_data=stargan_model.input_data,
+ generated_data=stargan_model.generated_data,
+ generator_inputs=num_domains,
+ discriminator_fn=stargan_model.discriminator_fn,
+ discriminator_scope=stargan_model.discriminator_scope,
+ **kwargs)
+
+ new_docstring = """The stargan_model version of %s.""" % loss_fn.__name__
+ new_loss_fn.__docstring__ = new_docstring
+ new_loss_fn.__name__ = loss_fn.__name__
+ new_loss_fn.__module__ = loss_fn.__module__
+ return new_loss_fn
diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py
index aa1ef11172..a559bbfa11 100644
--- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py
+++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py
@@ -22,10 +22,15 @@ import collections
import numpy as np
+from tensorflow.contrib import layers
from tensorflow.contrib.gan.python import namedtuples
+from tensorflow.contrib.gan.python.losses.python import losses_impl as tfgan_losses_impl
from tensorflow.contrib.gan.python.losses.python import tuple_losses_impl as tfgan_losses
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -129,6 +134,9 @@ manual_tests = [
'mutual_information_penalty',
'wasserstein_gradient_penalty',
'cycle_consistency_loss',
+ 'stargan_generator_loss_wrapper',
+ 'stargan_discriminator_loss_wrapper',
+ 'stargan_gradient_penalty_wrapper'
]
discriminator_keyword_args = {
@@ -175,6 +183,112 @@ class CycleConsistencyLossTest(test.TestCase):
self.assertNear(5.0, loss.eval(), 1e-5)
+class StarGANLossWrapperTest(test.TestCase):
+
+ def setUp(self):
+
+ super(StarGANLossWrapperTest, self).setUp()
+
+ self.input_data = array_ops.ones([1, 2, 2, 3])
+ self.input_data_domain_label = constant_op.constant([[0, 1]])
+ self.generated_data = array_ops.ones([1, 2, 2, 3])
+ self.discriminator_input_data_source_predication = array_ops.ones([1])
+ self.discriminator_generated_data_source_predication = array_ops.ones([1])
+
+ def _discriminator_fn(inputs, num_domains):
+ """Differentiable dummy discriminator for StarGAN."""
+ hidden = layers.flatten(inputs)
+ output_src = math_ops.reduce_mean(hidden, axis=1)
+ output_cls = layers.fully_connected(
+ inputs=hidden,
+ num_outputs=num_domains,
+ activation_fn=None,
+ normalizer_fn=None,
+ biases_initializer=None)
+ return output_src, output_cls
+
+ with variable_scope.variable_scope('discriminator') as dis_scope:
+ pass
+
+ self.model = namedtuples.StarGANModel(
+ input_data=self.input_data,
+ input_data_domain_label=self.input_data_domain_label,
+ generated_data=self.generated_data,
+ generated_data_domain_target=None,
+ reconstructed_data=None,
+ discriminator_input_data_source_predication=self.
+ discriminator_input_data_source_predication,
+ discriminator_generated_data_source_predication=self.
+ discriminator_generated_data_source_predication,
+ discriminator_input_data_domain_predication=None,
+ discriminator_generated_data_domain_predication=None,
+ generator_variables=None,
+ generator_scope=None,
+ generator_fn=None,
+ discriminator_variables=None,
+ discriminator_scope=dis_scope,
+ discriminator_fn=_discriminator_fn)
+
+ self.discriminator_fn = _discriminator_fn
+ self.discriminator_scope = dis_scope
+
+ def test_stargan_generator_loss_wrapper(self):
+ """Test StarGAN generator loss wrapper."""
+ loss_fn = tfgan_losses_impl.wasserstein_generator_loss
+ wrapped_loss_fn = tfgan_losses.stargan_generator_loss_wrapper(loss_fn)
+
+ loss_result_tensor = loss_fn(
+ self.discriminator_generated_data_source_predication)
+ wrapped_loss_result_tensor = wrapped_loss_fn(self.model)
+
+ with self.test_session() as sess:
+ sess.run(variables.global_variables_initializer())
+ loss_result, wrapped_loss_result = sess.run(
+ [loss_result_tensor, wrapped_loss_result_tensor])
+ self.assertAlmostEqual(loss_result, wrapped_loss_result)
+
+ def test_stargan_discriminator_loss_wrapper(self):
+ """Test StarGAN discriminator loss wrapper."""
+ loss_fn = tfgan_losses_impl.wasserstein_discriminator_loss
+ wrapped_loss_fn = tfgan_losses.stargan_discriminator_loss_wrapper(loss_fn)
+
+ loss_result_tensor = loss_fn(
+ self.discriminator_generated_data_source_predication,
+ self.discriminator_generated_data_source_predication)
+ wrapped_loss_result_tensor = wrapped_loss_fn(self.model)
+
+ with self.test_session() as sess:
+ sess.run(variables.global_variables_initializer())
+ loss_result, wrapped_loss_result = sess.run(
+ [loss_result_tensor, wrapped_loss_result_tensor])
+ self.assertAlmostEqual(loss_result, wrapped_loss_result)
+
+ def test_stargan_gradient_penalty_wrapper(self):
+ """Test StaGAN gradient penalty wrapper.
+
+ Notes:
+ The random interpolates are handled by given setting the reconstruction to
+ be the same as the input.
+
+ """
+ loss_fn = tfgan_losses_impl.wasserstein_gradient_penalty
+ wrapped_loss_fn = tfgan_losses.stargan_gradient_penalty_wrapper(loss_fn)
+
+ loss_result_tensor = loss_fn(
+ real_data=self.input_data,
+ generated_data=self.generated_data,
+ generator_inputs=self.input_data_domain_label.shape.as_list()[-1],
+ discriminator_fn=self.discriminator_fn,
+ discriminator_scope=self.discriminator_scope)
+ wrapped_loss_result_tensor = wrapped_loss_fn(self.model)
+
+ with self.test_session() as sess:
+ sess.run(variables.global_variables_initializer())
+ loss_result, wrapped_loss_result = sess.run(
+ [loss_result_tensor, wrapped_loss_result_tensor])
+ self.assertAlmostEqual(loss_result, wrapped_loss_result)
+
+
if __name__ == '__main__':
for loss_name in tfgan_losses.__all__:
if loss_name in manual_tests: continue
diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py
index df603d1f18..03f52d214b 100644
--- a/tensorflow/contrib/gan/python/train.py
+++ b/tensorflow/contrib/gan/python/train.py
@@ -34,6 +34,7 @@ from __future__ import print_function
from tensorflow.contrib.framework.python.ops import variables as variables_lib
from tensorflow.contrib.gan.python import losses as tfgan_losses
from tensorflow.contrib.gan.python import namedtuples
+from tensorflow.contrib.gan.python.losses.python import losses_impl as tfgan_losses_impl
from tensorflow.contrib.slim.python.slim import learning as slim_learning
from tensorflow.contrib.training.python.training import training
from tensorflow.python.framework import dtypes
@@ -41,14 +42,17 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.distributions import distribution as ds
from tensorflow.python.ops.losses import losses
+from tensorflow.python.summary import summary
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import sync_replicas_optimizer
from tensorflow.python.training import training_util
+
__all__ = [
'gan_model',
'infogan_model',
@@ -751,6 +755,130 @@ def cyclegan_loss(
return namedtuples.CycleGANLoss(loss_x2y, loss_y2x)
+def stargan_loss(
+ model,
+ generator_loss_fn=tfgan_losses.stargan_generator_loss_wrapper(
+ tfgan_losses_impl.wasserstein_generator_loss),
+ discriminator_loss_fn=tfgan_losses.stargan_discriminator_loss_wrapper(
+ tfgan_losses_impl.wasserstein_discriminator_loss),
+ gradient_penalty_weight=10.0,
+ gradient_penalty_epsilon=1e-10,
+ gradient_penalty_target=1.0,
+ gradient_penalty_one_sided=False,
+ reconstruction_loss_fn=losses.absolute_difference,
+ reconstruction_loss_weight=10.0,
+ classification_loss_fn=losses.softmax_cross_entropy,
+ classification_loss_weight=1.0,
+ classification_one_hot=True,
+ add_summaries=True):
+ """StarGAN Loss.
+
+ The four major part can be found here: http://screen/tMRMBAohDYG.
+
+ Args:
+ model: (StarGAN) Model output of the stargan_model() function call.
+ generator_loss_fn: The loss function on the generator. Takes a
+ `StarGANModel` named tuple.
+ discriminator_loss_fn: The loss function on the discriminator. Takes a
+ `StarGANModel` namedtuple.
+ gradient_penalty_weight: (float) Gradient penalty weight. Default to 10 per
+ the original paper https://arxiv.org/abs/1711.09020. Set to 0 or None to
+ turn off gradient penalty.
+ gradient_penalty_epsilon: (float) A small positive number added for
+ numerical stability when computing the gradient norm.
+ gradient_penalty_target: (float, or tf.float `Tensor`) The target value of
+ gradient norm. Defaults to 1.0.
+ gradient_penalty_one_sided: (bool) If `True`, penalty proposed in
+ https://arxiv.org/abs/1709.08894 is used. Defaults to `False`.
+ reconstruction_loss_fn: The reconstruction loss function. Default to L1-norm
+ and the function must conform to the `tf.losses` API.
+ reconstruction_loss_weight: Reconstruction loss weight. Default to 10.0.
+ classification_loss_fn: The loss function on the discriminator's ability to
+ classify domain of the input. Default to one-hot softmax cross entropy
+ loss, and the function must conform to the `tf.losses` API.
+ classification_loss_weight: (float) Classification loss weight. Default to
+ 1.0.
+ classification_one_hot: (bool) If the label is one hot representation.
+ Default to True. If False, classification classification_loss_fn need to
+ be sigmoid cross entropy loss instead.
+ add_summaries: (bool) Add the loss to the summary
+
+ Returns:
+ GANLoss namedtuple where we have generator loss and discriminator loss.
+
+ Raises:
+ ValueError: If input StarGANModel.input_data_domain_label does not have rank
+ 2, or dimension 2 is not defined.
+ """
+
+ def _classification_loss_helper(true_labels, predict_logits, scope_name):
+ """Classification Loss Function Helper.
+
+ Args:
+ true_labels: Tensor of shape [batch_size, num_domains] representing the
+ label where each row is an one-hot vector.
+ predict_logits: Tensor of shape [batch_size, num_domains] representing the
+ predicted label logit, which is UNSCALED output from the NN.
+ scope_name: (string) Name scope of the loss component.
+
+ Returns:
+ Single scalar tensor representing the classification loss.
+ """
+
+ with ops.name_scope(scope_name, values=(true_labels, predict_logits)):
+
+ loss = classification_loss_fn(
+ onehot_labels=true_labels, logits=predict_logits)
+
+ if not classification_one_hot:
+ loss = math_ops.reduce_sum(loss, axis=1)
+ loss = math_ops.reduce_mean(loss)
+
+ if add_summaries:
+ summary.scalar(scope_name, loss)
+
+ return loss
+
+ # Check input shape.
+ model.input_data_domain_label.shape.assert_has_rank(2)
+ model.input_data_domain_label.shape[1:].assert_is_fully_defined()
+
+ # Adversarial Loss.
+ generator_loss = generator_loss_fn(model, add_summaries=add_summaries)
+ discriminator_loss = discriminator_loss_fn(model, add_summaries=add_summaries)
+
+ # Gradient Penalty.
+ if _use_aux_loss(gradient_penalty_weight):
+ gradient_penalty_fn = tfgan_losses.stargan_gradient_penalty_wrapper(
+ tfgan_losses_impl.wasserstein_gradient_penalty)
+ discriminator_loss += gradient_penalty_fn(
+ model,
+ epsilon=gradient_penalty_epsilon,
+ target=gradient_penalty_target,
+ one_sided=gradient_penalty_one_sided,
+ add_summaries=add_summaries) * gradient_penalty_weight
+
+ # Reconstruction Loss.
+ reconstruction_loss = reconstruction_loss_fn(model.input_data,
+ model.reconstructed_data)
+ generator_loss += reconstruction_loss * reconstruction_loss_weight
+ if add_summaries:
+ summary.scalar('reconstruction_loss', reconstruction_loss)
+
+ # Classification Loss.
+ generator_loss += _classification_loss_helper(
+ true_labels=model.generated_data_domain_target,
+ predict_logits=model.discriminator_generated_data_domain_predication,
+ scope_name='generator_classification_loss') * classification_loss_weight
+ discriminator_loss += _classification_loss_helper(
+ true_labels=model.input_data_domain_label,
+ predict_logits=model.discriminator_input_data_domain_predication,
+ scope_name='discriminator_classification_loss'
+ ) * classification_loss_weight
+
+ return namedtuples.GANLoss(generator_loss, discriminator_loss)
+
+
def _get_update_ops(kwargs, gen_scope, dis_scope, check_for_unused_ops=True):
"""Gets generator and discriminator update ops.
diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py
index df8e0041a9..58f348034f 100644
--- a/tensorflow/contrib/gan/python/train_test.py
+++ b/tensorflow/contrib/gan/python/train_test.py
@@ -666,6 +666,27 @@ class GANLossTest(test.TestCase, parameterized.TestCase):
self.assertTrue(np.isscalar(loss_y2x_dis_np))
@parameterized.named_parameters(
+ ('notcallable', create_stargan_model),
+ ('callable', create_callable_stargan_model),
+ )
+ def test_stargan(self, create_gan_model_fn):
+
+ model = create_gan_model_fn()
+ model_loss = train.stargan_loss(model)
+
+ self.assertIsInstance(model_loss, namedtuples.GANLoss)
+
+ with self.test_session() as sess:
+
+ sess.run(variables.global_variables_initializer())
+
+ gen_loss, disc_loss = sess.run(
+ [model_loss.generator_loss, model_loss.discriminator_loss])
+
+ self.assertTrue(np.isscalar(gen_loss))
+ self.assertTrue(np.isscalar(disc_loss))
+
+ @parameterized.named_parameters(
('gan', create_gan_model),
('callable_gan', create_callable_gan_model),
('infogan', create_infogan_model),
diff --git a/tensorflow/contrib/lite/allocation.cc b/tensorflow/contrib/lite/allocation.cc
index c42622ff02..ef6c14f085 100644
--- a/tensorflow/contrib/lite/allocation.cc
+++ b/tensorflow/contrib/lite/allocation.cc
@@ -99,7 +99,9 @@ FileCopyAllocation::FileCopyAllocation(const char* filename,
filename);
return;
}
- copied_buffer_ = std::move(buffer);
+ // Versions of GCC before 6.2.0 don't support std::move from non-const
+ // char[] to const char[] unique_ptrs.
+ copied_buffer_.reset(const_cast<char const*>(buffer.release()));
}
FileCopyAllocation::~FileCopyAllocation() {}
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index a8a49784c6..422584c0ea 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -2,8 +2,8 @@
load(
"//tensorflow:tensorflow.bzl",
- "tf_cc_test",
"tf_cc_shared_object",
+ "tf_cc_test",
)
def tflite_copts():
@@ -27,6 +27,9 @@ def tflite_copts():
str(Label("//tensorflow:ios_x86_64")): [
"-msse4.1",
],
+ str(Label("//tensorflow:windows")): [
+ "/DTF_COMPILE_LIBRARY",
+ ],
"//conditions:default": [],
}) + select({
str(Label("//tensorflow:with_default_optimizations")): [],
@@ -53,6 +56,7 @@ def tflite_linkopts_unstripped():
"-Wl,--gc-sections", # Eliminate unused code and data.
"-Wl,--as-needed", # Don't link unused libs.
],
+ "//tensorflow:darwin": [],
"//tensorflow/contrib/lite:mips": [],
"//tensorflow/contrib/lite:mips64": [],
"//conditions:default": [
@@ -74,6 +78,7 @@ def tflite_jni_linkopts_unstripped():
"-Wl,--gc-sections", # Eliminate unused code and data.
"-Wl,--as-needed", # Don't link unused libs.
],
+ "//tensorflow:darwin": [],
"//tensorflow/contrib/lite:mips": [],
"//tensorflow/contrib/lite:mips64": [],
"//conditions:default": [
@@ -122,19 +127,21 @@ def tflite_jni_binary(
linkopts = linkopts,
)
-def tflite_cc_shared_object(name,
- copts=tflite_copts(),
- linkopts=[],
- linkstatic=1,
- deps=[]):
- """Builds a shared object for TFLite."""
- tf_cc_shared_object(
- name=name,
- copts=copts,
- linkstatic=linkstatic,
- linkopts=linkopts + tflite_jni_linkopts(),
- framework_so=[],
- deps=deps)
+def tflite_cc_shared_object(
+ name,
+ copts = tflite_copts(),
+ linkopts = [],
+ linkstatic = 1,
+ deps = []):
+ """Builds a shared object for TFLite."""
+ tf_cc_shared_object(
+ name = name,
+ copts = copts,
+ linkstatic = linkstatic,
+ linkopts = linkopts + tflite_jni_linkopts(),
+ framework_so = [],
+ deps = deps,
+ )
def tf_to_tflite(name, src, options, out):
"""Convert a frozen tensorflow graphdef to TF Lite's flatbuffer.
@@ -240,6 +247,7 @@ def generated_test_models():
"local_response_norm",
"log_softmax",
"log",
+ "logical_or",
"lstm",
"max_pool",
"maximum",
@@ -248,6 +256,7 @@ def generated_test_models():
"mul",
"neg",
"not_equal",
+ "one_hot",
"pack",
"pad",
"padv2",
diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h
index fd16aa1063..70178b2faa 100644
--- a/tensorflow/contrib/lite/builtin_op_data.h
+++ b/tensorflow/contrib/lite/builtin_op_data.h
@@ -282,6 +282,10 @@ typedef struct {
int axis;
} TfLitePackParams;
+typedef struct {
+ int axis;
+} TfLiteOneHotParams;
+
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h
index 1ae73b9738..0b6568fd2f 100644
--- a/tensorflow/contrib/lite/builtin_ops.h
+++ b/tensorflow/contrib/lite/builtin_ops.h
@@ -110,6 +110,7 @@ typedef enum {
kTfLiteBuiltinReduceMax = 82,
kTfLiteBuiltinPack = 83,
kTfLiteBuiltinLogicalOr = 84,
+ kTfLiteBuiltinOneHot = 85,
} TfLiteBuiltinOperator;
#ifdef __cplusplus
diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h
index cbfce12d7e..5bc20106d3 100644
--- a/tensorflow/contrib/lite/context.h
+++ b/tensorflow/contrib/lite/context.h
@@ -29,6 +29,9 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_CONTEXT_H_
#define TENSORFLOW_CONTRIB_LITE_CONTEXT_H_
+#if defined(_MSC_VER)
+#include <complex.h>
+#endif
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
@@ -180,7 +183,11 @@ typedef union {
uint8_t* uint8;
bool* b;
int16_t* i16;
+#if defined(_MSC_VER)
+ _Fcomplex* c64;
+#else
_Complex float* c64;
+#endif
} TfLitePtrUnion;
// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped
diff --git a/tensorflow/contrib/lite/delegates/eager/BUILD b/tensorflow/contrib/lite/delegates/eager/BUILD
index 03a4b7bf1d..a28707382e 100644
--- a/tensorflow/contrib/lite/delegates/eager/BUILD
+++ b/tensorflow/contrib/lite/delegates/eager/BUILD
@@ -68,6 +68,43 @@ cc_test(
)
cc_library(
+ name = "kernel",
+ srcs = ["kernel.cc"],
+ hdrs = ["kernel.h"],
+ deps = [
+ ":delegate_data",
+ ":util",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:kernel_api",
+ "//tensorflow/contrib/lite/kernels:kernel_util",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/common_runtime/eager:context",
+ "//tensorflow/core/common_runtime/eager:execute",
+ "//tensorflow/core/common_runtime/eager:tensor_handle",
+ "@flatbuffers",
+ ],
+)
+
+cc_test(
+ name = "kernel_test",
+ size = "small",
+ srcs = ["kernel_test.cc"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable",
+ ],
+ deps = [
+ ":delegate_data",
+ ":kernel",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "//tensorflow/contrib/lite/testing:util",
+ "@com_google_absl//absl/memory",
+ "@com_google_googletest//:gtest",
+ "@flatbuffers",
+ ],
+)
+
+cc_library(
name = "util",
srcs = ["util.cc"],
hdrs = ["util.h"],
diff --git a/tensorflow/contrib/lite/delegates/eager/buffer_map.cc b/tensorflow/contrib/lite/delegates/eager/buffer_map.cc
index 1d6453f498..e5a19c3997 100644
--- a/tensorflow/contrib/lite/delegates/eager/buffer_map.cc
+++ b/tensorflow/contrib/lite/delegates/eager/buffer_map.cc
@@ -91,6 +91,10 @@ void BufferMap::SetFromTfLite(int tensor_index, const TfLiteTensor* tensor) {
for (int i = 0; i < num_dims; ++i) {
shape.AddDim(tensor->dims->data[i]);
}
+ // TODO(ahentz): we assume this is a new tensor and allocate a new buffer
+ // for it. This is not always the best approach. For example, this might
+ // be a reallocation after resizing tensors. In that case we would be
+ // preferable to somehow reuse the buffer.
auto* buf = new TfLiteTensorBuffer(tensor);
tensorflow::Tensor t = tensorflow::TensorCApi::MakeTensor(
GetTensorFlowDataType(tensor->type), shape, buf);
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data.cc b/tensorflow/contrib/lite/delegates/eager/delegate_data.cc
index 29687694bd..0fd5c976f8 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate_data.cc
+++ b/tensorflow/contrib/lite/delegates/eager/delegate_data.cc
@@ -23,7 +23,8 @@ tensorflow::Status DelegateData::Create(std::unique_ptr<DelegateData>* data) {
std::vector<tensorflow::Device*> devices;
TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices(
- tensorflow::SessionOptions(), "/device:cpu:*", &devices));
+ tensorflow::SessionOptions(), "/job:localhost/replica:0/task:0",
+ &devices));
std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
new tensorflow::DeviceMgr(devices));
diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.cc b/tensorflow/contrib/lite/delegates/eager/kernel.cc
new file mode 100644
index 0000000000..1727981807
--- /dev/null
+++ b/tensorflow/contrib/lite/delegates/eager/kernel.cc
@@ -0,0 +1,289 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/delegates/eager/kernel.h"
+
+#include "third_party/flatbuffers/include/flatbuffers/flexbuffers.h"
+#include "tensorflow/contrib/lite/builtin_ops.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/context_util.h"
+#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h"
+#include "tensorflow/contrib/lite/delegates/eager/util.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/core/common_runtime/eager/context.h"
+#include "tensorflow/core/common_runtime/eager/execute.h"
+#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+
+// Note: this is part of TF Lite's Eager delegation code which is to be
+// completed soon.
+
+// This is the TF Lite op that is created by the eager delegate to handle
+// execution of a supported subgraph. The usual flow is that the delegate
+// informs the interpreter of supported nodes in a graph, and each supported
+// subgraph is replaced with one instance of this kernel.
+//
+// The kernel is initialized with TfLiteDelegateParams from which we retrieve
+// the global EagerContext and BufferMap, as well as a list of inputs and
+// outputs to the subgraph. Those are used to build the OpData, with a list of
+// TensorFlow Ops that should be executed in order (which we call an OpNode).
+//
+// For each node included in the subgraph, we query the interpreter and
+// retrieve the associated NodeDef, which is then used to configure the
+// corresponding TensorFlow/Eager Op.
+
+namespace tflite {
+namespace eager {
+namespace kernel {
+
+// Controls the lifetime of tensor handles in a vector.
+class VectorOfHandles {
+ public:
+ explicit VectorOfHandles(int num_elements) : vector_(num_elements, nullptr) {}
+
+ ~VectorOfHandles() {
+ for (auto* handle : vector_) {
+ if (handle) handle->Unref();
+ }
+ }
+
+ tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2>* GetVector() {
+ return &vector_;
+ }
+
+ tensorflow::TensorHandle* GetHandle(int index) { return vector_[index]; }
+
+ private:
+ tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> vector_;
+};
+
+// Executes the TensorFlow op given by 'op_name', with the attributes specified
+// in 'nodedef'. Inputs and outputs are given as indices into the 'buffer_map'.
+tensorflow::Status ExecuteEagerOp(tensorflow::EagerContext* eager_context,
+ BufferMap* buffer_map, const string& op_name,
+ const tensorflow::NodeDef& nodedef,
+ const std::vector<int>& inputs,
+ const std::vector<int>& outputs) {
+ const tensorflow::AttrTypeMap* attr_types;
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ tensorflow::AttrTypeMapForOp(op_name.c_str(), &attr_types),
+ " (while processing attributes of '", op_name, "')");
+
+ tensorflow::EagerOperation op(eager_context, op_name.c_str(), attr_types);
+ for (const auto& attr : nodedef.attr()) {
+ op.MutableAttrs()->Set(attr.first, attr.second);
+ }
+
+ for (int input_index : inputs) {
+ if (!buffer_map->HasTensor(input_index)) {
+ return tensorflow::errors::Internal(
+ "Cannot read from invalid tensor index ", input_index);
+ }
+ auto* handle = new tensorflow::TensorHandle(
+ buffer_map->GetTensor(input_index), nullptr, nullptr, nullptr);
+ op.AddInput(handle);
+ handle->Unref();
+ }
+
+ int num_retvals = outputs.size();
+ VectorOfHandles retvals(num_retvals);
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ EagerExecute(&op, retvals.GetVector(), &num_retvals),
+ " (while executing '", op_name, "' via Eager)");
+
+ if (num_retvals != outputs.size()) {
+ return tensorflow::errors::Internal(
+ "Unexpected number of outputs from EagerExecute");
+ }
+
+ for (int i = 0; i < num_retvals; ++i) {
+ const tensorflow::Tensor* tensor = nullptr;
+ TF_RETURN_IF_ERROR(retvals.GetHandle(i)->Tensor(&tensor));
+ buffer_map->SetFromTensorFlow(outputs[i], *tensor);
+ }
+
+ return tensorflow::Status::OK();
+}
+
+// A single node within the larger 'op'. Note that this kernel executes many
+// TensorFlow ops within a single TF Lite op.
+struct OpNode {
+ // The name of the TensorFlow op to execute.
+ string name;
+ // The corresponding NodeDef, containing the attributes for the op.
+ tensorflow::NodeDef nodedef;
+ // List of inputs, as TF Lite tensor indices.
+ std::vector<int> inputs;
+ // List of outputs, as TF Lite tensor indices.
+ std::vector<int> outputs;
+};
+
+// The Larger 'op', which contains all the nodes in a supported subgraph.
+struct OpData {
+ tensorflow::EagerContext* eager_context;
+ BufferMap* buffer_map;
+ std::vector<OpNode> nodes;
+ std::vector<int> subgraph_inputs;
+ std::vector<int> subgraph_outputs;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* op_data = new OpData;
+
+ const TfLiteDelegateParams* params =
+ reinterpret_cast<const TfLiteDelegateParams*>(buffer);
+ CHECK(params);
+ CHECK(params->delegate);
+ CHECK(params->delegate->data_);
+ op_data->eager_context =
+ reinterpret_cast<DelegateData*>(params->delegate->data_)
+ ->GetEagerContext();
+ op_data->buffer_map =
+ reinterpret_cast<DelegateData*>(params->delegate->data_)->GetBufferMap();
+
+ CHECK(params->output_tensors);
+ for (auto tensor_index : TfLiteIntArrayView(params->output_tensors)) {
+ op_data->subgraph_outputs.push_back(tensor_index);
+ }
+
+ CHECK(params->input_tensors);
+ for (auto tensor_index : TfLiteIntArrayView(params->input_tensors)) {
+ op_data->subgraph_inputs.push_back(tensor_index);
+ }
+
+ CHECK(params->nodes_to_replace);
+ for (auto node_index : TfLiteIntArrayView(params->nodes_to_replace)) {
+ TfLiteNode* node;
+ TfLiteRegistration* reg;
+ context->GetNodeAndRegistration(context, node_index, &node, &reg);
+
+ op_data->nodes.push_back(OpNode());
+ OpNode& node_data = op_data->nodes.back();
+
+ node_data.name = "";
+ if (node->custom_initial_data) {
+ // The flexbuffer contains a vector where the first elements is the
+ // op name and the second is a serialized NodeDef.
+ const flexbuffers::Vector& v =
+ flexbuffers::GetRoot(
+ reinterpret_cast<const uint8_t*>(node->custom_initial_data),
+ node->custom_initial_data_size)
+ .AsVector();
+
+ node_data.name = v[0].AsString().str();
+ if (!node_data.nodedef.ParseFromString(v[1].AsString().str())) {
+ // We will just leave the nodedef empty and error out in Eval().
+ node_data.nodedef.Clear();
+ }
+ }
+
+ for (auto input_index : TfLiteIntArrayView(node->inputs)) {
+ node_data.inputs.push_back(input_index);
+ }
+ for (auto output_index : TfLiteIntArrayView(node->outputs)) {
+ node_data.outputs.push_back(output_index);
+ }
+ }
+
+ return op_data;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ const auto* op_data = reinterpret_cast<OpData*>(node->user_data);
+ TF_LITE_ENSURE_MSG(
+ context, op_data->eager_context != nullptr,
+ "Failed to initialize eager context. This often happens when a CPU "
+ "device has not been registered, presumably because some symbols from "
+ "tensorflow/core:core_cpu_impl were not linked into the binary.");
+
+ // Whenever we find a constant tensor, insert it in the buffer map.
+ BufferMap* buffer_map = op_data->buffer_map;
+ for (auto tensor_index : op_data->subgraph_inputs) {
+ TfLiteTensor* tensor = &context->tensors[tensor_index];
+ if (IsConstantTensor(tensor)) {
+ if (!buffer_map->HasTensor(tensor_index)) {
+ buffer_map->SetFromTfLite(tensor_index, tensor);
+ }
+ }
+ }
+
+ // All output tensors are allocated by TensorFlow/Eager, so we
+ // mark them as kTfLiteDynamic.
+ for (auto tensor_index : op_data->subgraph_outputs) {
+ SetTensorToDynamic(&context->tensors[tensor_index]);
+ }
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const auto* op_data = reinterpret_cast<OpData*>(node->user_data);
+ BufferMap* buffer_map = op_data->buffer_map;
+ tensorflow::EagerContext* eager_context = op_data->eager_context;
+
+ // Insert a tensor in the buffer map for all inputs that are not constant.
+ // Constants were handled in Prepare() already.
+ for (auto tensor_index : op_data->subgraph_inputs) {
+ TfLiteTensor* tensor = &context->tensors[tensor_index];
+ if (!IsConstantTensor(tensor)) {
+ buffer_map->SetFromTfLite(tensor_index, tensor);
+ }
+ }
+
+ // Execute the TensorFlow Ops sequentially.
+ for (const auto& node_data : op_data->nodes) {
+ if (node_data.nodedef.op().empty()) {
+ context->ReportError(context, "Invalid NodeDef in Eager op '%s'",
+ node_data.name.c_str());
+ return kTfLiteError;
+ }
+ auto status =
+ ExecuteEagerOp(eager_context, buffer_map, node_data.name,
+ node_data.nodedef, node_data.inputs, node_data.outputs);
+ TF_LITE_ENSURE_OK(context, ConvertStatus(context, status));
+ }
+
+ for (auto tensor_index : op_data->subgraph_outputs) {
+ if (!buffer_map->HasTensor(tensor_index)) {
+ context->ReportError(context, "Cannot write to invalid tensor index %d",
+ tensor_index);
+ return kTfLiteError;
+ }
+
+ TfLiteTensor* tensor = &context->tensors[tensor_index];
+ TF_LITE_ENSURE_OK(
+ context,
+ CopyShape(context, buffer_map->GetTensor(tensor_index), tensor));
+ tensor->buffer_handle = tensor_index;
+ tensor->data_is_stale = true;
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace kernel
+
+TfLiteRegistration GetKernel() {
+ TfLiteRegistration registration{&kernel::Init, &kernel::Free,
+ &kernel::Prepare, &kernel::Eval,
+ nullptr, kTfLiteBuiltinDelegate};
+ return registration;
+}
+
+} // namespace eager
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.h b/tensorflow/contrib/lite/delegates/eager/kernel.h
new file mode 100644
index 0000000000..100672c82d
--- /dev/null
+++ b/tensorflow/contrib/lite/delegates/eager/kernel.h
@@ -0,0 +1,34 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_
+#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_
+
+#include "tensorflow/contrib/lite/context.h"
+
+namespace tflite {
+namespace eager {
+
+// Return the registration object used to initialize and execute ops that will
+// be delegated to TensorFlow's Eager runtime. This TF Lite op is created by
+// the eager delegate to handle execution of a supported subgraph. The usual
+// flow is that the delegate informs the interpreter of supported nodes in a
+// graph, and each supported subgraph is replaced with one instance of this
+// kernel.
+TfLiteRegistration GetKernel();
+
+} // namespace eager
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_
diff --git a/tensorflow/contrib/lite/delegates/eager/kernel_test.cc b/tensorflow/contrib/lite/delegates/eager/kernel_test.cc
new file mode 100644
index 0000000000..7d9dddef93
--- /dev/null
+++ b/tensorflow/contrib/lite/delegates/eager/kernel_test.cc
@@ -0,0 +1,351 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/delegates/eager/kernel.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/memory/memory.h"
+#include "third_party/flatbuffers/include/flatbuffers/flexbuffers.h"
+#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/testing/util.h"
+
+namespace tflite {
+namespace eager {
+namespace {
+
+using tensorflow::protobuf::TextFormat;
+using ::testing::ContainsRegex;
+using ::testing::ElementsAre;
+
+// We will use these are custom_names, so they need to be static.
+static const char kIdentity[] = "Identity";
+static const char kUnpack[] = "Unpack";
+static const char kAdd[] = "Add";
+static const char kMul[] = "Mul";
+
+TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteDelegate* delegate,
+ const std::vector<int>& supported_nodes) {
+ TfLiteIntArray* size_and_nodes =
+ ConvertVectorToTfLiteIntArray(supported_nodes);
+ TF_LITE_ENSURE_STATUS(context->ReplaceSubgraphsWithDelegateKernels(
+ context, eager::GetKernel(), size_and_nodes, delegate));
+ TfLiteIntArrayFree(size_and_nodes);
+ return kTfLiteOk;
+}
+
+class KernelTest : public ::testing::Test {
+ public:
+ KernelTest() {
+ CHECK(DelegateData::Create(&delegate_data_).ok());
+ interpreter_.reset(new Interpreter(&error_reporter_));
+ }
+
+ bool Invoke() { return interpreter_->Invoke() == kTfLiteOk; }
+
+ void SetValues(int tensor_index, const std::vector<float>& values) {
+ float* v = interpreter_->typed_tensor<float>(tensor_index);
+ for (float f : values) {
+ *v++ = f;
+ }
+ }
+
+ std::vector<float> GetValues(int tensor_index) {
+ TfLiteTensor* o = interpreter_->tensor(tensor_index);
+ return std::vector<float>(o->data.f, o->data.f + o->bytes / sizeof(float));
+ }
+
+ void SetShape(int tensor_index, const std::vector<int>& values) {
+ ASSERT_EQ(interpreter_->ResizeInputTensor(tensor_index, values), kTfLiteOk);
+ ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
+ }
+
+ std::vector<int> GetShape(int tensor_index) {
+ std::vector<int> result;
+ auto* dims = interpreter_->tensor(tensor_index)->dims;
+ for (int i = 0; i < dims->size; ++i) {
+ result.push_back(dims->data[i]);
+ }
+ return result;
+ }
+
+ template <typename T>
+ void ConfigureDelegate(T prepare_function) {
+ delegate_.data_ = delegate_data_.get();
+ delegate_.FreeBufferHandle = nullptr;
+ delegate_.Prepare = prepare_function;
+ delegate_.CopyFromBufferHandle = [](TfLiteDelegate* delegate,
+ TfLiteBufferHandle buffer_handle,
+ void* data, size_t size) {
+ auto* delegate_data = reinterpret_cast<DelegateData*>(delegate->data_);
+ tensorflow::StringPiece values =
+ delegate_data->GetBufferMap()->GetTensor(buffer_handle).tensor_data();
+ memcpy(data, values.data(), values.size());
+ return kTfLiteOk;
+ };
+ CHECK(interpreter_->ModifyGraphWithDelegate(
+ &delegate_, /*allow_dynamic_tensors=*/true) == kTfLiteOk);
+ }
+
+ void AddOp(const char* name, const std::vector<int>& inputs,
+ const std::vector<int>& outputs) {
+ auto attr = [](const string& key, const string& value) {
+ return " attr{ key: '" + key + "' value {" + value + "}}";
+ };
+
+ string attributes;
+ if (name == string(kUnpack)) {
+ attributes = attr("T", "type: DT_FLOAT") + attr("num", "i: 2") +
+ attr("axis", "i: 0");
+ } else if (name == string(kIdentity)) {
+ attributes = attr("T", "type: DT_FLOAT");
+ } else if (name == string(kAdd)) {
+ attributes = attr("T", "type: DT_FLOAT");
+ } else if (name == string(kMul)) {
+ attributes = attr("T", "type: DT_FLOAT");
+ }
+ AddTfOp(name, attributes, inputs, outputs);
+ }
+
+ void AddTensors(int num_tensors, const std::vector<int>& inputs,
+ const std::vector<int>& outputs) {
+ interpreter_->AddTensors(num_tensors);
+ for (int i = 0; i < num_tensors; ++i) {
+ TfLiteQuantizationParams quant;
+ CHECK_EQ(interpreter_->SetTensorParametersReadWrite(i, kTfLiteFloat32,
+ /*name=*/"",
+ /*dims=*/{3}, quant),
+ kTfLiteOk);
+ }
+
+ CHECK_EQ(interpreter_->SetInputs(inputs), kTfLiteOk);
+ CHECK_EQ(interpreter_->SetOutputs(outputs), kTfLiteOk);
+ }
+
+ const TestErrorReporter& error_reporter() const { return error_reporter_; }
+
+ void AddTfLiteOp(const char* name, const std::vector<int>& inputs,
+ const std::vector<int>& outputs) {
+ CHECK_EQ(string(name), kMul); // can only add MUL
+ static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
+ reg.builtin_code = BuiltinOperator_MUL;
+ reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
+ auto* i0 = &context->tensors[node->inputs->data[0]];
+ auto* o = &context->tensors[node->outputs->data[0]];
+ return context->ResizeTensor(context, o, TfLiteIntArrayCopy(i0->dims));
+ };
+ reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
+ auto* i0 = &context->tensors[node->inputs->data[0]];
+ auto* i1 = &context->tensors[node->inputs->data[1]];
+ auto* o = &context->tensors[node->outputs->data[0]];
+ for (int i = 0; i < o->bytes / sizeof(float); ++i) {
+ o->data.f[i] = i0->data.f[i] * i1->data.f[i];
+ }
+ return kTfLiteOk;
+ };
+
+ CHECK_EQ(interpreter_->AddNodeWithParameters(inputs, outputs, nullptr, 0,
+ nullptr, &reg),
+ kTfLiteOk);
+ }
+
+ private:
+ void AddTfOp(const char* name, const string& nodedef_str,
+ const std::vector<int>& inputs,
+ const std::vector<int>& outputs) {
+ static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
+ reg.builtin_code = BuiltinOperator_CUSTOM;
+ reg.custom_name = name;
+
+ tensorflow::NodeDef nodedef;
+ CHECK(TextFormat::ParseFromString(nodedef_str + " op: '" + name + "'",
+ &nodedef));
+ string serialized_nodedef;
+ CHECK(nodedef.SerializeToString(&serialized_nodedef));
+ flexbuffers::Builder fbb;
+ fbb.Vector([&]() {
+ fbb.String(nodedef.op());
+ fbb.String(serialized_nodedef);
+ });
+ fbb.Finish();
+
+ flexbuffers_.push_back(fbb.GetBuffer());
+ auto& buffer = flexbuffers_.back();
+ CHECK_EQ(interpreter_->AddNodeWithParameters(
+ inputs, outputs, reinterpret_cast<const char*>(buffer.data()),
+ buffer.size(), nullptr, &reg),
+ kTfLiteOk);
+ }
+
+ std::unique_ptr<Interpreter> interpreter_;
+ std::unique_ptr<DelegateData> delegate_data_;
+ TfLiteDelegate delegate_;
+ std::vector<std::vector<uint8_t>> flexbuffers_;
+ TestErrorReporter error_reporter_;
+};
+
+TEST_F(KernelTest, FullGraph) {
+ // Define the graph.
+ AddTensors(9, {0, 3}, {8});
+
+ AddOp(kUnpack, {0}, {1, 2});
+ AddOp(kUnpack, {3}, {4, 5});
+ AddOp(kAdd, {1, 4}, {6});
+ AddOp(kAdd, {2, 5}, {7});
+ AddOp(kMul, {6, 7}, {8});
+
+ // Apply Delegate.
+ ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
+ return GenericPrepare(context, delegate, {0, 1, 2, 3, 4});
+ });
+
+ // Define inputs.
+ SetShape(0, {2, 2, 1});
+ SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
+ SetShape(3, {2, 2, 1});
+ SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f});
+
+ ASSERT_TRUE(Invoke());
+
+ ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
+ ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
+}
+
+TEST_F(KernelTest, BadTensorFlowOp) {
+ AddTensors(2, {0}, {1});
+ AddOp("NonExistentOp", {0}, {1});
+
+ ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
+ return GenericPrepare(context, delegate, {0});
+ });
+
+ SetShape(0, {2, 2, 1});
+ SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
+
+ ASSERT_FALSE(Invoke());
+ ASSERT_THAT(error_reporter().error_messages(),
+ ContainsRegex("while processing attributes of 'NonExistentOp'"));
+}
+
+TEST_F(KernelTest, BadNumberOfOutputs) {
+ AddTensors(3, {0}, {1, 2});
+ AddOp(kIdentity, {0}, {1, 2});
+
+ ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
+ return GenericPrepare(context, delegate, {0});
+ });
+
+ SetShape(0, {2, 2, 1});
+ SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
+
+ ASSERT_FALSE(Invoke());
+ ASSERT_THAT(error_reporter().error_messages(),
+ ContainsRegex("Unexpected number of outputs"));
+}
+
+TEST_F(KernelTest, IncompatibleNodeDef) {
+ AddTensors(2, {0}, {1});
+
+ // Cast is a TF op, but we don't add the proper nodedef to it in AddOp.
+ AddOp("Cast", {0}, {1});
+
+ ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
+ return GenericPrepare(context, delegate, {0});
+ });
+
+ SetShape(0, {2, 2, 1});
+ SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
+
+ ASSERT_FALSE(Invoke());
+ ASSERT_THAT(error_reporter().error_messages(),
+ ContainsRegex("while executing 'Cast' via Eager"));
+}
+
+TEST_F(KernelTest, WrongSetOfNodes) {
+ AddTensors(4, {0}, {3});
+ AddOp(kUnpack, {0}, {1, 2});
+ AddTfLiteOp(kMul, {1, 2}, {3});
+
+ // Specify that kMul (#1) is supported when it actually isn't.
+ ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
+ return GenericPrepare(context, delegate, {0, 1});
+ });
+
+ SetShape(0, {2, 2, 1});
+ SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
+
+ ASSERT_FALSE(Invoke());
+ ASSERT_THAT(error_reporter().error_messages(),
+ ContainsRegex("Invalid NodeDef in Eager op"));
+}
+
+TEST_F(KernelTest, MixedGraph) {
+ AddTensors(9, {0, 3}, {8});
+
+ AddOp(kUnpack, {0}, {1, 2});
+ AddOp(kUnpack, {3}, {4, 5});
+ AddOp(kAdd, {1, 4}, {6});
+ AddOp(kAdd, {2, 5}, {7});
+ AddTfLiteOp(kMul, {6, 7}, {8});
+
+ ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
+ return GenericPrepare(context, delegate, {0, 1, 2, 3});
+ });
+
+ SetShape(0, {2, 2, 1});
+ SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
+ SetShape(3, {2, 2, 1});
+ SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f});
+
+ ASSERT_TRUE(Invoke());
+
+ ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
+ ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
+}
+
+TEST_F(KernelTest, SplitGraph) {
+ AddTensors(10, {0}, {9});
+
+ AddOp(kUnpack, {0}, {1, 2});
+ AddOp(kAdd, {1, 2}, {3});
+ AddOp(kUnpack, {3}, {4, 5});
+
+ AddTfLiteOp(kMul, {4, 5}, {6});
+
+ AddOp(kUnpack, {6}, {7, 8});
+ AddOp(kAdd, {7, 8}, {9});
+
+ ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
+ return GenericPrepare(context, delegate, {0, 1, 2, 4, 5});
+ });
+
+ SetShape(0, {2, 2, 2, 1});
+ SetValues(0, {3.0f, 1.0f, 0.5f, -1.0f, 0.0f, 1.0f, 1.5f, 3.0f});
+
+ ASSERT_TRUE(Invoke());
+
+ ASSERT_THAT(GetShape(9), ElementsAre(1));
+ ASSERT_THAT(GetValues(9), ElementsAre(10.0f));
+}
+
+} // namespace
+} // namespace eager
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/experimental/c/BUILD b/tensorflow/contrib/lite/experimental/c/BUILD
index b09bb9ea10..50f8da66d0 100644
--- a/tensorflow/contrib/lite/experimental/c/BUILD
+++ b/tensorflow/contrib/lite/experimental/c/BUILD
@@ -5,6 +5,7 @@ licenses(["notice"]) # Apache 2.0
load(
"//tensorflow/contrib/lite:build_def.bzl",
"tflite_cc_shared_object",
+ "tflite_copts",
"tflite_jni_binary",
)
@@ -30,16 +31,11 @@ tflite_cc_shared_object(
],
)
-tflite_jni_binary(
- name = "libtensorflowlite_c_jni.so",
- linkscript = ":version_script.lds",
- deps = [":c_api"],
-)
-
cc_library(
name = "c_api",
srcs = ["c_api.cc"],
hdrs = ["c_api.h"],
+ copts = tflite_copts(),
deps = [
"//tensorflow/contrib/lite:context",
"//tensorflow/contrib/lite:framework",
diff --git a/tensorflow/contrib/lite/experimental/c/c_api.cc b/tensorflow/contrib/lite/experimental/c/c_api.cc
index add4c6813d..9d29e8b3e0 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api.cc
+++ b/tensorflow/contrib/lite/experimental/c/c_api.cc
@@ -27,6 +27,8 @@ struct _TFL_Interpreter {
std::unique_ptr<tflite::Interpreter> impl;
};
+// LINT.IfChange
+
TFL_Interpreter* TFL_NewInterpreter(const void* model_data,
int32_t model_size) {
auto model = tflite::FlatBufferModel::BuildFromBuffer(
@@ -113,6 +115,8 @@ TFL_Status TFL_TensorCopyToBuffer(const TFL_Tensor* tensor, void* output_data,
return kTfLiteOk;
}
+// LINT.ThenChange(//tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs)
+
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/.gitignore b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/.gitignore
new file mode 100644
index 0000000000..c72a5cae9e
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/.gitignore
@@ -0,0 +1,13 @@
+# Unity generated
+Builds/
+Temp/
+Library/
+obj/
+# Visual Studio / MonoDevelop generated
+*.csproj
+*.unityproj
+*.sln
+*.suo
+*.userprefs
+# OS generated
+.DS_Store
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite.meta
new file mode 100644
index 0000000000..ed9337b53e
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: 71d1b4219b1da4aeaa1cebbec324fc81
+folderAsset: yes
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples.meta
new file mode 100644
index 0000000000..edcce00939
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: d948aead14abd4c88947c9886d16f774
+folderAsset: yes
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite.meta
new file mode 100644
index 0000000000..36b35516f0
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: b810b85b794fa48fd93100acf5525e1f
+folderAsset: yes
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes.meta
new file mode 100644
index 0000000000..d4133da49a
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: 154f4201e2e454d4696fa5834eaa3ad3
+folderAsset: yes
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity
new file mode 100644
index 0000000000..9397d8f27a
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity
@@ -0,0 +1,242 @@
+%YAML 1.1
+%TAG !u! tag:unity3d.com,2011:
+--- !u!29 &1
+OcclusionCullingSettings:
+ m_ObjectHideFlags: 0
+ serializedVersion: 2
+ m_OcclusionBakeSettings:
+ smallestOccluder: 5
+ smallestHole: 0.25
+ backfaceThreshold: 100
+ m_SceneGUID: 00000000000000000000000000000000
+ m_OcclusionCullingData: {fileID: 0}
+--- !u!104 &2
+RenderSettings:
+ m_ObjectHideFlags: 0
+ serializedVersion: 8
+ m_Fog: 0
+ m_FogColor: {r: 0.5, g: 0.5, b: 0.5, a: 1}
+ m_FogMode: 3
+ m_FogDensity: 0.01
+ m_LinearFogStart: 0
+ m_LinearFogEnd: 300
+ m_AmbientSkyColor: {r: 0.212, g: 0.227, b: 0.259, a: 1}
+ m_AmbientEquatorColor: {r: 0.114, g: 0.125, b: 0.133, a: 1}
+ m_AmbientGroundColor: {r: 0.047, g: 0.043, b: 0.035, a: 1}
+ m_AmbientIntensity: 1
+ m_AmbientMode: 3
+ m_SubtractiveShadowColor: {r: 0.42, g: 0.478, b: 0.627, a: 1}
+ m_SkyboxMaterial: {fileID: 0}
+ m_HaloStrength: 0.5
+ m_FlareStrength: 1
+ m_FlareFadeSpeed: 3
+ m_HaloTexture: {fileID: 0}
+ m_SpotCookie: {fileID: 10001, guid: 0000000000000000e000000000000000, type: 0}
+ m_DefaultReflectionMode: 0
+ m_DefaultReflectionResolution: 128
+ m_ReflectionBounces: 1
+ m_ReflectionIntensity: 1
+ m_CustomReflection: {fileID: 0}
+ m_Sun: {fileID: 0}
+ m_IndirectSpecularColor: {r: 0, g: 0, b: 0, a: 1}
+--- !u!157 &3
+LightmapSettings:
+ m_ObjectHideFlags: 0
+ serializedVersion: 11
+ m_GIWorkflowMode: 1
+ m_GISettings:
+ serializedVersion: 2
+ m_BounceScale: 1
+ m_IndirectOutputScale: 1
+ m_AlbedoBoost: 1
+ m_TemporalCoherenceThreshold: 1
+ m_EnvironmentLightingMode: 0
+ m_EnableBakedLightmaps: 0
+ m_EnableRealtimeLightmaps: 0
+ m_LightmapEditorSettings:
+ serializedVersion: 9
+ m_Resolution: 2
+ m_BakeResolution: 40
+ m_TextureWidth: 1024
+ m_TextureHeight: 1024
+ m_AO: 0
+ m_AOMaxDistance: 1
+ m_CompAOExponent: 1
+ m_CompAOExponentDirect: 0
+ m_Padding: 2
+ m_LightmapParameters: {fileID: 0}
+ m_LightmapsBakeMode: 1
+ m_TextureCompression: 1
+ m_FinalGather: 0
+ m_FinalGatherFiltering: 1
+ m_FinalGatherRayCount: 256
+ m_ReflectionCompression: 2
+ m_MixedBakeMode: 2
+ m_BakeBackend: 0
+ m_PVRSampling: 1
+ m_PVRDirectSampleCount: 32
+ m_PVRSampleCount: 500
+ m_PVRBounces: 2
+ m_PVRFilterTypeDirect: 0
+ m_PVRFilterTypeIndirect: 0
+ m_PVRFilterTypeAO: 0
+ m_PVRFilteringMode: 1
+ m_PVRCulling: 1
+ m_PVRFilteringGaussRadiusDirect: 1
+ m_PVRFilteringGaussRadiusIndirect: 5
+ m_PVRFilteringGaussRadiusAO: 2
+ m_PVRFilteringAtrousPositionSigmaDirect: 0.5
+ m_PVRFilteringAtrousPositionSigmaIndirect: 2
+ m_PVRFilteringAtrousPositionSigmaAO: 1
+ m_ShowResolutionOverlay: 1
+ m_LightingDataAsset: {fileID: 0}
+ m_UseShadowmask: 1
+--- !u!196 &4
+NavMeshSettings:
+ serializedVersion: 2
+ m_ObjectHideFlags: 0
+ m_BuildSettings:
+ serializedVersion: 2
+ agentTypeID: 0
+ agentRadius: 0.5
+ agentHeight: 2
+ agentSlope: 45
+ agentClimb: 0.4
+ ledgeDropHeight: 0
+ maxJumpAcrossDistance: 0
+ minRegionArea: 2
+ manualCellSize: 0
+ cellSize: 0.16666667
+ manualTileSize: 0
+ tileSize: 256
+ accuratePlacement: 0
+ debug:
+ m_Flags: 0
+ m_NavMeshData: {fileID: 0}
+--- !u!1 &492081941
+GameObject:
+ m_ObjectHideFlags: 0
+ m_PrefabParentObject: {fileID: 0}
+ m_PrefabInternal: {fileID: 0}
+ serializedVersion: 5
+ m_Component:
+ - component: {fileID: 492081945}
+ - component: {fileID: 492081944}
+ - component: {fileID: 492081943}
+ - component: {fileID: 492081942}
+ m_Layer: 0
+ m_Name: Main Camera
+ m_TagString: MainCamera
+ m_Icon: {fileID: 0}
+ m_NavMeshLayer: 0
+ m_StaticEditorFlags: 0
+ m_IsActive: 1
+--- !u!81 &492081942
+AudioListener:
+ m_ObjectHideFlags: 0
+ m_PrefabParentObject: {fileID: 0}
+ m_PrefabInternal: {fileID: 0}
+ m_GameObject: {fileID: 492081941}
+ m_Enabled: 1
+--- !u!124 &492081943
+Behaviour:
+ m_ObjectHideFlags: 0
+ m_PrefabParentObject: {fileID: 0}
+ m_PrefabInternal: {fileID: 0}
+ m_GameObject: {fileID: 492081941}
+ m_Enabled: 1
+--- !u!20 &492081944
+Camera:
+ m_ObjectHideFlags: 0
+ m_PrefabParentObject: {fileID: 0}
+ m_PrefabInternal: {fileID: 0}
+ m_GameObject: {fileID: 492081941}
+ m_Enabled: 1
+ serializedVersion: 2
+ m_ClearFlags: 1
+ m_BackGroundColor: {r: 0.19215687, g: 0.3019608, b: 0.4745098, a: 0}
+ m_NormalizedViewPortRect:
+ serializedVersion: 2
+ x: 0
+ y: 0
+ width: 1
+ height: 1
+ near clip plane: 0.3
+ far clip plane: 1000
+ field of view: 60
+ orthographic: 1
+ orthographic size: 5
+ m_Depth: -1
+ m_CullingMask:
+ serializedVersion: 2
+ m_Bits: 4294967295
+ m_RenderingPath: -1
+ m_TargetTexture: {fileID: 0}
+ m_TargetDisplay: 0
+ m_TargetEye: 3
+ m_HDR: 1
+ m_AllowMSAA: 1
+ m_AllowDynamicResolution: 0
+ m_ForceIntoRT: 0
+ m_OcclusionCulling: 1
+ m_StereoConvergence: 10
+ m_StereoSeparation: 0.022
+--- !u!4 &492081945
+Transform:
+ m_ObjectHideFlags: 0
+ m_PrefabParentObject: {fileID: 0}
+ m_PrefabInternal: {fileID: 0}
+ m_GameObject: {fileID: 492081941}
+ m_LocalRotation: {x: 0, y: 0, z: 0, w: 1}
+ m_LocalPosition: {x: 0, y: 0, z: -10}
+ m_LocalScale: {x: 1, y: 1, z: 1}
+ m_Children:
+ - {fileID: 904015944}
+ m_Father: {fileID: 0}
+ m_RootOrder: 0
+ m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0}
+--- !u!1 &904015943
+GameObject:
+ m_ObjectHideFlags: 0
+ m_PrefabParentObject: {fileID: 0}
+ m_PrefabInternal: {fileID: 0}
+ serializedVersion: 5
+ m_Component:
+ - component: {fileID: 904015944}
+ - component: {fileID: 904015945}
+ m_Layer: 0
+ m_Name: HelloTFLite
+ m_TagString: Untagged
+ m_Icon: {fileID: 0}
+ m_NavMeshLayer: 0
+ m_StaticEditorFlags: 0
+ m_IsActive: 1
+--- !u!4 &904015944
+Transform:
+ m_ObjectHideFlags: 0
+ m_PrefabParentObject: {fileID: 0}
+ m_PrefabInternal: {fileID: 0}
+ m_GameObject: {fileID: 904015943}
+ m_LocalRotation: {x: 0, y: 0, z: 0, w: 1}
+ m_LocalPosition: {x: 0, y: 0, z: 0}
+ m_LocalScale: {x: 1, y: 1, z: 1}
+ m_Children: []
+ m_Father: {fileID: 492081945}
+ m_RootOrder: 0
+ m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0}
+--- !u!114 &904015945
+MonoBehaviour:
+ m_ObjectHideFlags: 0
+ m_PrefabParentObject: {fileID: 0}
+ m_PrefabInternal: {fileID: 0}
+ m_GameObject: {fileID: 904015943}
+ m_Enabled: 1
+ m_EditorHideFlags: 0
+ m_Script: {fileID: 11500000, guid: 899510441e0ca4be0879d3055e467878, type: 3}
+ m_Name:
+ m_EditorClassIdentifier:
+ model: {fileID: 4900000, guid: adff4e1dbdba344c199ee4fe7e84457e, type: 3}
+ inputs:
+ - 1
+ - 3
+ - 7
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity.meta
new file mode 100644
index 0000000000..e1e13efb66
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity.meta
@@ -0,0 +1,7 @@
+fileFormatVersion: 2
+guid: f8a8c37a396584bb7b21687f33d6d3f8
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/add.bytes b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/add.bytes
new file mode 100644
index 0000000000..aef0fe3d82
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/add.bytes
Binary files differ
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/add.bytes.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/add.bytes.meta
new file mode 100644
index 0000000000..ba24871413
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/add.bytes.meta
@@ -0,0 +1,7 @@
+fileFormatVersion: 2
+guid: adff4e1dbdba344c199ee4fe7e84457e
+TextScriptImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts.meta
new file mode 100644
index 0000000000..28fde68b8b
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: f7d1e2dec09b64acdb7b8f5aef9fcb44
+folderAsset: yes
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs
new file mode 100644
index 0000000000..abca814499
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs
@@ -0,0 +1,70 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Linq;
+using TensorFlowLite;
+using UnityEngine;
+
+/// <summary>
+/// Simple example demonstrating use of the experimental C# bindings for TensorFlowLite.
+/// </summary>
+public class HelloTFLite : MonoBehaviour {
+
+ [Tooltip("Configurable TFLite model.")]
+ public TextAsset model;
+
+ [Tooltip("Configurable TFLite input tensor data.")]
+ public float[] inputs;
+
+ private Interpreter interpreter;
+ private float[] outputs;
+
+ void Start () {
+ interpreter = new Interpreter(model.bytes);
+ Debug.LogFormat("InputCount: {0}, OutputCount: {1}",
+ interpreter.GetInputTensorCount(),
+ interpreter.GetOutputTensorCount());
+ }
+
+ void Update () {
+ if (inputs == null) {
+ return;
+ }
+
+ if (outputs == null || outputs.Length != inputs.Length) {
+ interpreter.ResizeInputTensor(0, new int[]{inputs.Length});
+ interpreter.AllocateTensors();
+ outputs = new float[inputs.Length];
+ }
+
+ interpreter.SetInputTensorData(0, inputs);
+ interpreter.Invoke();
+ interpreter.GetOutputTensorData(0, outputs);
+
+ Debug.LogFormat("Input: {0}, Output: {1}",
+ ArrayToString(inputs),
+ ArrayToString(outputs));
+ }
+
+ void OnDestroy() {
+ interpreter.Dispose();
+ }
+
+ private static string ArrayToString(float[] values) {
+ return string.Join(",", values.Select(x => x.ToString()).ToArray());
+ }
+}
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs.meta
new file mode 100644
index 0000000000..ba83f45084
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 899510441e0ca4be0879d3055e467878
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK.meta
new file mode 100644
index 0000000000..bf5ce15c6a
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: 16dad1655bcdc48f7b325a2a634b9c69
+folderAsset: yes
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts.meta
new file mode 100644
index 0000000000..22ed2c466b
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: d70863368f8904d509a9b73d3a555914
+folderAsset: yes
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs
new file mode 100644
index 0000000000..ab966bae2e
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs
@@ -0,0 +1,145 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+using System;
+using System.Runtime.InteropServices;
+
+using TFL_Interpreter = System.IntPtr;
+using TFL_Tensor = System.IntPtr;
+
+namespace TensorFlowLite
+{
+ /// <summary>
+ /// Simple C# bindings for the experimental TensorFlowLite C API.
+ /// </summary>
+ public class Interpreter : IDisposable
+ {
+ private const string TensorFlowLibrary = "tensorflowlite_c";
+
+ private TFL_Interpreter handle;
+
+ public Interpreter(byte[] modelData) {
+ GCHandle modelDataHandle = GCHandle.Alloc(modelData, GCHandleType.Pinned);
+ IntPtr modelDataPtr = modelDataHandle.AddrOfPinnedObject();
+ handle = TFL_NewInterpreter(modelDataPtr, modelData.Length);
+ if (handle == IntPtr.Zero) throw new Exception("Failed to create TensorFlowLite Interpreter");
+ }
+
+ ~Interpreter() {
+ Dispose();
+ }
+
+ public void Dispose() {
+ if (handle != IntPtr.Zero) TFL_DeleteInterpreter(handle);
+ handle = IntPtr.Zero;
+ }
+
+ public void Invoke() {
+ ThrowIfError(TFL_InterpreterInvoke(handle));
+ }
+
+ public int GetInputTensorCount() {
+ return TFL_InterpreterGetInputTensorCount(handle);
+ }
+
+ public void SetInputTensorData(int inputTensorIndex, Array inputTensorData) {
+ GCHandle tensorDataHandle = GCHandle.Alloc(inputTensorData, GCHandleType.Pinned);
+ IntPtr tensorDataPtr = tensorDataHandle.AddrOfPinnedObject();
+ TFL_Tensor tensor = TFL_InterpreterGetInputTensor(handle, inputTensorIndex);
+ ThrowIfError(TFL_TensorCopyFromBuffer(
+ tensor, tensorDataPtr, Buffer.ByteLength(inputTensorData)));
+ }
+
+ public void ResizeInputTensor(int inputTensorIndex, int[] inputTensorShape) {
+ ThrowIfError(TFL_InterpreterResizeInputTensor(
+ handle, inputTensorIndex, inputTensorShape, inputTensorShape.Length));
+ }
+
+ public void AllocateTensors() {
+ ThrowIfError(TFL_InterpreterAllocateTensors(handle));
+ }
+
+ public int GetOutputTensorCount() {
+ return TFL_InterpreterGetOutputTensorCount(handle);
+ }
+
+ public void GetOutputTensorData(int outputTensorIndex, Array outputTensorData) {
+ GCHandle tensorDataHandle = GCHandle.Alloc(outputTensorData, GCHandleType.Pinned);
+ IntPtr tensorDataPtr = tensorDataHandle.AddrOfPinnedObject();
+ TFL_Tensor tensor = TFL_InterpreterGetOutputTensor(handle, outputTensorIndex);
+ ThrowIfError(TFL_TensorCopyToBuffer(
+ tensor, tensorDataPtr, Buffer.ByteLength(outputTensorData)));
+ }
+
+ private static void ThrowIfError(int resultCode) {
+ if (resultCode != 0) throw new Exception("TensorFlowLite operation failed.");
+ }
+
+ #region Externs
+
+ [DllImport (TensorFlowLibrary)]
+ private static extern unsafe TFL_Interpreter TFL_NewInterpreter(
+ IntPtr model_data,
+ int model_size);
+
+ [DllImport (TensorFlowLibrary)]
+ private static extern unsafe void TFL_DeleteInterpreter(TFL_Interpreter interpreter);
+
+ [DllImport (TensorFlowLibrary)]
+ private static extern unsafe int TFL_InterpreterGetInputTensorCount(
+ TFL_Interpreter interpreter);
+
+ [DllImport (TensorFlowLibrary)]
+ private static extern unsafe TFL_Tensor TFL_InterpreterGetInputTensor(
+ TFL_Interpreter interpreter,
+ int input_index);
+
+ [DllImport (TensorFlowLibrary)]
+ private static extern unsafe int TFL_InterpreterResizeInputTensor(
+ TFL_Interpreter interpreter,
+ int input_index,
+ int[] input_dims,
+ int input_dims_size);
+
+ [DllImport (TensorFlowLibrary)]
+ private static extern unsafe int TFL_InterpreterAllocateTensors(
+ TFL_Interpreter interpreter);
+
+ [DllImport (TensorFlowLibrary)]
+ private static extern unsafe int TFL_InterpreterInvoke(TFL_Interpreter interpreter);
+
+ [DllImport (TensorFlowLibrary)]
+ private static extern unsafe int TFL_InterpreterGetOutputTensorCount(
+ TFL_Interpreter interpreter);
+
+ [DllImport (TensorFlowLibrary)]
+ private static extern unsafe TFL_Tensor TFL_InterpreterGetOutputTensor(
+ TFL_Interpreter interpreter,
+ int output_index);
+
+ [DllImport (TensorFlowLibrary)]
+ private static extern unsafe int TFL_TensorCopyFromBuffer(
+ TFL_Tensor tensor,
+ IntPtr input_data,
+ int input_data_size);
+
+ [DllImport (TensorFlowLibrary)]
+ private static extern unsafe int TFL_TensorCopyToBuffer(
+ TFL_Tensor tensor,
+ IntPtr output_data,
+ int output_data_size);
+
+ #endregion
+ }
+}
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs.meta
new file mode 100644
index 0000000000..5ec84ef7f7
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 0bbaf59e6ac914ed1b28174fb9008a09
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/AudioManager.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/AudioManager.asset
new file mode 100644
index 0000000000..da6112576a
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/AudioManager.asset
@@ -0,0 +1,17 @@
+%YAML 1.1
+%TAG !u! tag:unity3d.com,2011:
+--- !u!11 &1
+AudioManager:
+ m_ObjectHideFlags: 0
+ m_Volume: 1
+ Rolloff Scale: 1
+ Doppler Factor: 1
+ Default Speaker Mode: 2
+ m_SampleRate: 0
+ m_DSPBufferSize: 0
+ m_VirtualVoiceCount: 512
+ m_RealVoiceCount: 32
+ m_SpatializerPlugin:
+ m_AmbisonicDecoderPlugin:
+ m_DisableAudio: 0
+ m_VirtualizeEffects: 1
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ClusterInputManager.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ClusterInputManager.asset
new file mode 100644
index 0000000000..e7886b266a
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ClusterInputManager.asset
@@ -0,0 +1,6 @@
+%YAML 1.1
+%TAG !u! tag:unity3d.com,2011:
+--- !u!236 &1
+ClusterInputManager:
+ m_ObjectHideFlags: 0
+ m_Inputs: []
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/DynamicsManager.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/DynamicsManager.asset
new file mode 100644
index 0000000000..78992f08c7
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/DynamicsManager.asset
@@ -0,0 +1,29 @@
+%YAML 1.1
+%TAG !u! tag:unity3d.com,2011:
+--- !u!55 &1
+PhysicsManager:
+ m_ObjectHideFlags: 0
+ serializedVersion: 7
+ m_Gravity: {x: 0, y: -9.81, z: 0}
+ m_DefaultMaterial: {fileID: 0}
+ m_BounceThreshold: 2
+ m_SleepThreshold: 0.005
+ m_DefaultContactOffset: 0.01
+ m_DefaultSolverIterations: 6
+ m_DefaultSolverVelocityIterations: 1
+ m_QueriesHitBackfaces: 0
+ m_QueriesHitTriggers: 1
+ m_EnableAdaptiveForce: 0
+ m_ClothInterCollisionDistance: 0
+ m_ClothInterCollisionStiffness: 0
+ m_ContactsGeneration: 1
+ m_LayerCollisionMatrix: ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff
+ m_AutoSimulation: 1
+ m_AutoSyncTransforms: 1
+ m_ClothInterCollisionSettingsToggle: 0
+ m_ContactPairsMode: 0
+ m_BroadphaseType: 0
+ m_WorldBounds:
+ m_Center: {x: 0, y: 0, z: 0}
+ m_Extent: {x: 250, y: 250, z: 250}
+ m_WorldSubdivisions: 8
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/EditorBuildSettings.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/EditorBuildSettings.asset
new file mode 100644
index 0000000000..6dc24f7dfd
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/EditorBuildSettings.asset
@@ -0,0 +1,7 @@
+%YAML 1.1
+%TAG !u! tag:unity3d.com,2011:
+--- !u!1045 &1
+EditorBuildSettings:
+ m_ObjectHideFlags: 0
+ serializedVersion: 2
+ m_Scenes: []
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/EditorSettings.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/EditorSettings.asset
new file mode 100644
index 0000000000..fcd016402f
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/EditorSettings.asset
@@ -0,0 +1,21 @@
+%YAML 1.1
+%TAG !u! tag:unity3d.com,2011:
+--- !u!159 &1
+EditorSettings:
+ m_ObjectHideFlags: 0
+ serializedVersion: 7
+ m_ExternalVersionControlSupport: Visible Meta Files
+ m_SerializationMode: 2
+ m_LineEndingsForNewScripts: 1
+ m_DefaultBehaviorMode: 1
+ m_SpritePackerMode: 4
+ m_SpritePackerPaddingPower: 1
+ m_EtcTextureCompressorBehavior: 1
+ m_EtcTextureFastCompressor: 1
+ m_EtcTextureNormalCompressor: 2
+ m_EtcTextureBestCompressor: 4
+ m_ProjectGenerationIncludedExtensions: txt;xml;fnt;cd;asmdef;rsp
+ m_ProjectGenerationRootNamespace:
+ m_UserGeneratedProjectSuffix:
+ m_CollabEditorSettings:
+ inProgressEnabled: 1
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/GraphicsSettings.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/GraphicsSettings.asset
new file mode 100644
index 0000000000..74d7b532b0
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/GraphicsSettings.asset
@@ -0,0 +1,61 @@
+%YAML 1.1
+%TAG !u! tag:unity3d.com,2011:
+--- !u!30 &1
+GraphicsSettings:
+ m_ObjectHideFlags: 0
+ serializedVersion: 12
+ m_Deferred:
+ m_Mode: 1
+ m_Shader: {fileID: 69, guid: 0000000000000000f000000000000000, type: 0}
+ m_DeferredReflections:
+ m_Mode: 1
+ m_Shader: {fileID: 74, guid: 0000000000000000f000000000000000, type: 0}
+ m_ScreenSpaceShadows:
+ m_Mode: 1
+ m_Shader: {fileID: 64, guid: 0000000000000000f000000000000000, type: 0}
+ m_LegacyDeferred:
+ m_Mode: 1
+ m_Shader: {fileID: 63, guid: 0000000000000000f000000000000000, type: 0}
+ m_DepthNormals:
+ m_Mode: 1
+ m_Shader: {fileID: 62, guid: 0000000000000000f000000000000000, type: 0}
+ m_MotionVectors:
+ m_Mode: 1
+ m_Shader: {fileID: 75, guid: 0000000000000000f000000000000000, type: 0}
+ m_LightHalo:
+ m_Mode: 1
+ m_Shader: {fileID: 105, guid: 0000000000000000f000000000000000, type: 0}
+ m_LensFlare:
+ m_Mode: 1
+ m_Shader: {fileID: 102, guid: 0000000000000000f000000000000000, type: 0}
+ m_AlwaysIncludedShaders:
+ - {fileID: 7, guid: 0000000000000000f000000000000000, type: 0}
+ - {fileID: 15104, guid: 0000000000000000f000000000000000, type: 0}
+ - {fileID: 15105, guid: 0000000000000000f000000000000000, type: 0}
+ - {fileID: 15106, guid: 0000000000000000f000000000000000, type: 0}
+ - {fileID: 10753, guid: 0000000000000000f000000000000000, type: 0}
+ - {fileID: 10770, guid: 0000000000000000f000000000000000, type: 0}
+ m_PreloadedShaders: []
+ m_SpritesDefaultMaterial: {fileID: 10754, guid: 0000000000000000f000000000000000,
+ type: 0}
+ m_CustomRenderPipeline: {fileID: 0}
+ m_TransparencySortMode: 0
+ m_TransparencySortAxis: {x: 0, y: 0, z: 1}
+ m_DefaultRenderingPath: 1
+ m_DefaultMobileRenderingPath: 1
+ m_TierSettings: []
+ m_LightmapStripping: 0
+ m_FogStripping: 0
+ m_InstancingStripping: 0
+ m_LightmapKeepPlain: 1
+ m_LightmapKeepDirCombined: 1
+ m_LightmapKeepDynamicPlain: 1
+ m_LightmapKeepDynamicDirCombined: 1
+ m_LightmapKeepShadowMask: 1
+ m_LightmapKeepSubtractive: 1
+ m_FogKeepLinear: 1
+ m_FogKeepExp: 1
+ m_FogKeepExp2: 1
+ m_AlbedoSwatchInfos: []
+ m_LightsUseLinearIntensity: 0
+ m_LightsUseColorTemperature: 0
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/InputManager.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/InputManager.asset
new file mode 100644
index 0000000000..17c8f538e2
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/InputManager.asset
@@ -0,0 +1,295 @@
+%YAML 1.1
+%TAG !u! tag:unity3d.com,2011:
+--- !u!13 &1
+InputManager:
+ m_ObjectHideFlags: 0
+ serializedVersion: 2
+ m_Axes:
+ - serializedVersion: 3
+ m_Name: Horizontal
+ descriptiveName:
+ descriptiveNegativeName:
+ negativeButton: left
+ positiveButton: right
+ altNegativeButton: a
+ altPositiveButton: d
+ gravity: 3
+ dead: 0.001
+ sensitivity: 3
+ snap: 1
+ invert: 0
+ type: 0
+ axis: 0
+ joyNum: 0
+ - serializedVersion: 3
+ m_Name: Vertical
+ descriptiveName:
+ descriptiveNegativeName:
+ negativeButton: down
+ positiveButton: up
+ altNegativeButton: s
+ altPositiveButton: w
+ gravity: 3
+ dead: 0.001
+ sensitivity: 3
+ snap: 1
+ invert: 0
+ type: 0
+ axis: 0
+ joyNum: 0
+ - serializedVersion: 3
+ m_Name: Fire1
+ descriptiveName:
+ descriptiveNegativeName:
+ negativeButton:
+ positiveButton: left ctrl
+ altNegativeButton:
+ altPositiveButton: mouse 0
+ gravity: 1000
+ dead: 0.001
+ sensitivity: 1000
+ snap: 0
+ invert: 0
+ type: 0
+ axis: 0
+ joyNum: 0
+ - serializedVersion: 3
+ m_Name: Fire2
+ descriptiveName:
+ descriptiveNegativeName:
+ negativeButton:
+ positiveButton: left alt
+ altNegativeButton:
+ altPositiveButton: mouse 1
+ gravity: 1000
+ dead: 0.001
+ sensitivity: 1000
+ snap: 0
+ invert: 0
+ type: 0
+ axis: 0
+ joyNum: 0
+ - serializedVersion: 3
+ m_Name: Fire3
+ descriptiveName:
+ descriptiveNegativeName:
+ negativeButton:
+ positiveButton: left shift
+ altNegativeButton:
+ altPositiveButton: mouse 2
+ gravity: 1000
+ dead: 0.001
+ sensitivity: 1000
+ snap: 0
+ invert: 0
+ type: 0
+ axis: 0
+ joyNum: 0
+ - serializedVersion: 3
+ m_Name: Jump
+ descriptiveName:
+ descriptiveNegativeName:
+ negativeButton:
+ positiveButton: space
+ altNegativeButton:
+ altPositiveButton:
+ gravity: 1000
+ dead: 0.001
+ sensitivity: 1000
+ snap: 0
+ invert: 0
+ type: 0
+ axis: 0
+ joyNum: 0
+ - serializedVersion: 3
+ m_Name: Mouse X
+ descriptiveName:
+ descriptiveNegativeName:
+ negativeButton:
+ positiveButton:
+ altNegativeButton:
+ altPositiveButton:
+ gravity: 0
+ dead: 0
+ sensitivity: 0.1
+ snap: 0
+ invert: 0
+ type: 1
+ axis: 0
+ joyNum: 0
+ - serializedVersion: 3
+ m_Name: Mouse Y
+ descriptiveName:
+ descriptiveNegativeName:
+ negativeButton:
+ positiveButton:
+ altNegativeButton:
+ altPositiveButton:
+ gravity: 0
+ dead: 0
+ sensitivity: 0.1
+ snap: 0
+ invert: 0
+ type: 1
+ axis: 1
+ joyNum: 0
+ - serializedVersion: 3
+ m_Name: Mouse ScrollWheel
+ descriptiveName:
+ descriptiveNegativeName:
+ negativeButton:
+ positiveButton:
+ altNegativeButton:
+ altPositiveButton:
+ gravity: 0
+ dead: 0
+ sensitivity: 0.1
+ snap: 0
+ invert: 0
+ type: 1
+ axis: 2
+ joyNum: 0
+ - serializedVersion: 3
+ m_Name: Horizontal
+ descriptiveName:
+ descriptiveNegativeName:
+ negativeButton:
+ positiveButton:
+ altNegativeButton:
+ altPositiveButton:
+ gravity: 0
+ dead: 0.19
+ sensitivity: 1
+ snap: 0
+ invert: 0
+ type: 2
+ axis: 0
+ joyNum: 0
+ - serializedVersion: 3
+ m_Name: Vertical
+ descriptiveName:
+ descriptiveNegativeName:
+ negativeButton:
+ positiveButton:
+ altNegativeButton:
+ altPositiveButton:
+ gravity: 0
+ dead: 0.19
+ sensitivity: 1
+ snap: 0
+ invert: 1
+ type: 2
+ axis: 1
+ joyNum: 0
+ - serializedVersion: 3
+ m_Name: Fire1
+ descriptiveName:
+ descriptiveNegativeName:
+ negativeButton:
+ positiveButton: joystick button 0
+ altNegativeButton:
+ altPositiveButton:
+ gravity: 1000
+ dead: 0.001
+ sensitivity: 1000
+ snap: 0
+ invert: 0
+ type: 0
+ axis: 0
+ joyNum: 0
+ - serializedVersion: 3
+ m_Name: Fire2
+ descriptiveName:
+ descriptiveNegativeName:
+ negativeButton:
+ positiveButton: joystick button 1
+ altNegativeButton:
+ altPositiveButton:
+ gravity: 1000
+ dead: 0.001
+ sensitivity: 1000
+ snap: 0
+ invert: 0
+ type: 0
+ axis: 0
+ joyNum: 0
+ - serializedVersion: 3
+ m_Name: Fire3
+ descriptiveName:
+ descriptiveNegativeName:
+ negativeButton:
+ positiveButton: joystick button 2
+ altNegativeButton:
+ altPositiveButton:
+ gravity: 1000
+ dead: 0.001
+ sensitivity: 1000
+ snap: 0
+ invert: 0
+ type: 0
+ axis: 0
+ joyNum: 0
+ - serializedVersion: 3
+ m_Name: Jump
+ descriptiveName:
+ descriptiveNegativeName:
+ negativeButton:
+ positiveButton: joystick button 3
+ altNegativeButton:
+ altPositiveButton:
+ gravity: 1000
+ dead: 0.001
+ sensitivity: 1000
+ snap: 0
+ invert: 0
+ type: 0
+ axis: 0
+ joyNum: 0
+ - serializedVersion: 3
+ m_Name: Submit
+ descriptiveName:
+ descriptiveNegativeName:
+ negativeButton:
+ positiveButton: return
+ altNegativeButton:
+ altPositiveButton: joystick button 0
+ gravity: 1000
+ dead: 0.001
+ sensitivity: 1000
+ snap: 0
+ invert: 0
+ type: 0
+ axis: 0
+ joyNum: 0
+ - serializedVersion: 3
+ m_Name: Submit
+ descriptiveName:
+ descriptiveNegativeName:
+ negativeButton:
+ positiveButton: enter
+ altNegativeButton:
+ altPositiveButton: space
+ gravity: 1000
+ dead: 0.001
+ sensitivity: 1000
+ snap: 0
+ invert: 0
+ type: 0
+ axis: 0
+ joyNum: 0
+ - serializedVersion: 3
+ m_Name: Cancel
+ descriptiveName:
+ descriptiveNegativeName:
+ negativeButton:
+ positiveButton: escape
+ altNegativeButton:
+ altPositiveButton: joystick button 1
+ gravity: 1000
+ dead: 0.001
+ sensitivity: 1000
+ snap: 0
+ invert: 0
+ type: 0
+ axis: 0
+ joyNum: 0
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/NavMeshAreas.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/NavMeshAreas.asset
new file mode 100644
index 0000000000..3b0b7c3d18
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/NavMeshAreas.asset
@@ -0,0 +1,91 @@
+%YAML 1.1
+%TAG !u! tag:unity3d.com,2011:
+--- !u!126 &1
+NavMeshProjectSettings:
+ m_ObjectHideFlags: 0
+ serializedVersion: 2
+ areas:
+ - name: Walkable
+ cost: 1
+ - name: Not Walkable
+ cost: 1
+ - name: Jump
+ cost: 2
+ - name:
+ cost: 1
+ - name:
+ cost: 1
+ - name:
+ cost: 1
+ - name:
+ cost: 1
+ - name:
+ cost: 1
+ - name:
+ cost: 1
+ - name:
+ cost: 1
+ - name:
+ cost: 1
+ - name:
+ cost: 1
+ - name:
+ cost: 1
+ - name:
+ cost: 1
+ - name:
+ cost: 1
+ - name:
+ cost: 1
+ - name:
+ cost: 1
+ - name:
+ cost: 1
+ - name:
+ cost: 1
+ - name:
+ cost: 1
+ - name:
+ cost: 1
+ - name:
+ cost: 1
+ - name:
+ cost: 1
+ - name:
+ cost: 1
+ - name:
+ cost: 1
+ - name:
+ cost: 1
+ - name:
+ cost: 1
+ - name:
+ cost: 1
+ - name:
+ cost: 1
+ - name:
+ cost: 1
+ - name:
+ cost: 1
+ - name:
+ cost: 1
+ m_LastAgentTypeID: -887442657
+ m_Settings:
+ - serializedVersion: 2
+ agentTypeID: 0
+ agentRadius: 0.5
+ agentHeight: 2
+ agentSlope: 45
+ agentClimb: 0.75
+ ledgeDropHeight: 0
+ maxJumpAcrossDistance: 0
+ minRegionArea: 2
+ manualCellSize: 0
+ cellSize: 0.16666667
+ manualTileSize: 0
+ tileSize: 256
+ accuratePlacement: 0
+ debug:
+ m_Flags: 0
+ m_SettingNames:
+ - Humanoid
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/NetworkManager.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/NetworkManager.asset
new file mode 100644
index 0000000000..5dc6a831d9
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/NetworkManager.asset
@@ -0,0 +1,8 @@
+%YAML 1.1
+%TAG !u! tag:unity3d.com,2011:
+--- !u!149 &1
+NetworkManager:
+ m_ObjectHideFlags: 0
+ m_DebugLevel: 0
+ m_Sendrate: 15
+ m_AssetToPrefab: {}
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/Physics2DSettings.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/Physics2DSettings.asset
new file mode 100644
index 0000000000..132ee6bc86
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/Physics2DSettings.asset
@@ -0,0 +1,37 @@
+%YAML 1.1
+%TAG !u! tag:unity3d.com,2011:
+--- !u!19 &1
+Physics2DSettings:
+ m_ObjectHideFlags: 0
+ serializedVersion: 3
+ m_Gravity: {x: 0, y: -9.81}
+ m_DefaultMaterial: {fileID: 0}
+ m_VelocityIterations: 8
+ m_PositionIterations: 3
+ m_VelocityThreshold: 1
+ m_MaxLinearCorrection: 0.2
+ m_MaxAngularCorrection: 8
+ m_MaxTranslationSpeed: 100
+ m_MaxRotationSpeed: 360
+ m_BaumgarteScale: 0.2
+ m_BaumgarteTimeOfImpactScale: 0.75
+ m_TimeToSleep: 0.5
+ m_LinearSleepTolerance: 0.01
+ m_AngularSleepTolerance: 2
+ m_DefaultContactOffset: 0.01
+ m_AutoSimulation: 1
+ m_QueriesHitTriggers: 1
+ m_QueriesStartInColliders: 1
+ m_ChangeStopsCallbacks: 0
+ m_CallbacksOnDisable: 1
+ m_AutoSyncTransforms: 1
+ m_AlwaysShowColliders: 0
+ m_ShowColliderSleep: 1
+ m_ShowColliderContacts: 0
+ m_ShowColliderAABB: 0
+ m_ContactArrowScale: 0.2
+ m_ColliderAwakeColor: {r: 0.5686275, g: 0.95686275, b: 0.54509807, a: 0.7529412}
+ m_ColliderAsleepColor: {r: 0.5686275, g: 0.95686275, b: 0.54509807, a: 0.36078432}
+ m_ColliderContactColor: {r: 1, g: 0, b: 1, a: 0.6862745}
+ m_ColliderAABBColor: {r: 1, g: 1, b: 0, a: 0.2509804}
+ m_LayerCollisionMatrix: ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ProjectSettings.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ProjectSettings.asset
new file mode 100644
index 0000000000..3fbfab76c1
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ProjectSettings.asset
@@ -0,0 +1,641 @@
+%YAML 1.1
+%TAG !u! tag:unity3d.com,2011:
+--- !u!129 &1
+PlayerSettings:
+ m_ObjectHideFlags: 0
+ serializedVersion: 14
+ productGUID: a084943b991dd4597b140f4ce2b41c65
+ AndroidProfiler: 0
+ AndroidFilterTouchesWhenObscured: 0
+ defaultScreenOrientation: 4
+ targetDevice: 2
+ useOnDemandResources: 0
+ accelerometerFrequency: 60
+ companyName: DefaultCompany
+ productName: TensorFlowLitePlugin
+ defaultCursor: {fileID: 0}
+ cursorHotspot: {x: 0, y: 0}
+ m_SplashScreenBackgroundColor: {r: 0.13725491, g: 0.12156863, b: 0.1254902, a: 1}
+ m_ShowUnitySplashScreen: 1
+ m_ShowUnitySplashLogo: 1
+ m_SplashScreenOverlayOpacity: 1
+ m_SplashScreenAnimation: 1
+ m_SplashScreenLogoStyle: 1
+ m_SplashScreenDrawMode: 0
+ m_SplashScreenBackgroundAnimationZoom: 1
+ m_SplashScreenLogoAnimationZoom: 1
+ m_SplashScreenBackgroundLandscapeAspect: 1
+ m_SplashScreenBackgroundPortraitAspect: 1
+ m_SplashScreenBackgroundLandscapeUvs:
+ serializedVersion: 2
+ x: 0
+ y: 0
+ width: 1
+ height: 1
+ m_SplashScreenBackgroundPortraitUvs:
+ serializedVersion: 2
+ x: 0
+ y: 0
+ width: 1
+ height: 1
+ m_SplashScreenLogos: []
+ m_VirtualRealitySplashScreen: {fileID: 0}
+ m_HolographicTrackingLossScreen: {fileID: 0}
+ defaultScreenWidth: 1024
+ defaultScreenHeight: 768
+ defaultScreenWidthWeb: 960
+ defaultScreenHeightWeb: 600
+ m_StereoRenderingPath: 0
+ m_ActiveColorSpace: 0
+ m_MTRendering: 1
+ m_StackTraceTypes: 010000000100000001000000010000000100000001000000
+ iosShowActivityIndicatorOnLoading: -1
+ androidShowActivityIndicatorOnLoading: -1
+ tizenShowActivityIndicatorOnLoading: -1
+ iosAppInBackgroundBehavior: 0
+ displayResolutionDialog: 1
+ iosAllowHTTPDownload: 1
+ allowedAutorotateToPortrait: 1
+ allowedAutorotateToPortraitUpsideDown: 1
+ allowedAutorotateToLandscapeRight: 1
+ allowedAutorotateToLandscapeLeft: 1
+ useOSAutorotation: 1
+ use32BitDisplayBuffer: 1
+ preserveFramebufferAlpha: 0
+ disableDepthAndStencilBuffers: 0
+ androidBlitType: 0
+ defaultIsFullScreen: 1
+ defaultIsNativeResolution: 1
+ macRetinaSupport: 1
+ runInBackground: 0
+ captureSingleScreen: 0
+ muteOtherAudioSources: 0
+ Prepare IOS For Recording: 0
+ Force IOS Speakers When Recording: 0
+ deferSystemGesturesMode: 0
+ hideHomeButton: 0
+ submitAnalytics: 1
+ usePlayerLog: 1
+ bakeCollisionMeshes: 0
+ forceSingleInstance: 0
+ resizableWindow: 0
+ useMacAppStoreValidation: 0
+ macAppStoreCategory: public.app-category.games
+ gpuSkinning: 0
+ graphicsJobs: 0
+ xboxPIXTextureCapture: 0
+ xboxEnableAvatar: 0
+ xboxEnableKinect: 0
+ xboxEnableKinectAutoTracking: 0
+ xboxEnableFitness: 0
+ visibleInBackground: 1
+ allowFullscreenSwitch: 1
+ graphicsJobMode: 0
+ macFullscreenMode: 2
+ d3d11FullscreenMode: 1
+ xboxSpeechDB: 0
+ xboxEnableHeadOrientation: 0
+ xboxEnableGuest: 0
+ xboxEnablePIXSampling: 0
+ metalFramebufferOnly: 0
+ n3dsDisableStereoscopicView: 0
+ n3dsEnableSharedListOpt: 1
+ n3dsEnableVSync: 0
+ xboxOneResolution: 0
+ xboxOneSResolution: 0
+ xboxOneXResolution: 3
+ xboxOneMonoLoggingLevel: 0
+ xboxOneLoggingLevel: 1
+ xboxOneDisableEsram: 0
+ xboxOnePresentImmediateThreshold: 0
+ videoMemoryForVertexBuffers: 0
+ psp2PowerMode: 0
+ psp2AcquireBGM: 1
+ wiiUTVResolution: 0
+ wiiUGamePadMSAA: 1
+ wiiUSupportsNunchuk: 0
+ wiiUSupportsClassicController: 0
+ wiiUSupportsBalanceBoard: 0
+ wiiUSupportsMotionPlus: 0
+ wiiUSupportsProController: 0
+ wiiUAllowScreenCapture: 1
+ wiiUControllerCount: 0
+ m_SupportedAspectRatios:
+ 4:3: 1
+ 5:4: 1
+ 16:10: 1
+ 16:9: 1
+ Others: 1
+ bundleVersion: 1.0
+ preloadedAssets: []
+ metroInputSource: 0
+ wsaTransparentSwapchain: 0
+ m_HolographicPauseOnTrackingLoss: 1
+ xboxOneDisableKinectGpuReservation: 0
+ xboxOneEnable7thCore: 0
+ vrSettings:
+ cardboard:
+ depthFormat: 0
+ enableTransitionView: 0
+ daydream:
+ depthFormat: 0
+ useSustainedPerformanceMode: 0
+ enableVideoLayer: 0
+ useProtectedVideoMemory: 0
+ minimumSupportedHeadTracking: 0
+ maximumSupportedHeadTracking: 1
+ hololens:
+ depthFormat: 1
+ depthBufferSharingEnabled: 0
+ oculus:
+ sharedDepthBuffer: 0
+ dashSupport: 0
+ protectGraphicsMemory: 0
+ useHDRDisplay: 0
+ m_ColorGamuts: 00000000
+ targetPixelDensity: 30
+ resolutionScalingMode: 0
+ androidSupportedAspectRatio: 1
+ androidMaxAspectRatio: 2.1
+ applicationIdentifier: {}
+ buildNumber: {}
+ AndroidBundleVersionCode: 1
+ AndroidMinSdkVersion: 16
+ AndroidTargetSdkVersion: 0
+ AndroidPreferredInstallLocation: 1
+ aotOptions:
+ stripEngineCode: 1
+ iPhoneStrippingLevel: 0
+ iPhoneScriptCallOptimization: 0
+ ForceInternetPermission: 0
+ ForceSDCardPermission: 0
+ CreateWallpaper: 0
+ APKExpansionFiles: 0
+ keepLoadedShadersAlive: 0
+ StripUnusedMeshComponents: 0
+ VertexChannelCompressionMask:
+ serializedVersion: 2
+ m_Bits: 238
+ iPhoneSdkVersion: 988
+ iOSTargetOSVersionString: 7.0
+ tvOSSdkVersion: 0
+ tvOSRequireExtendedGameController: 0
+ tvOSTargetOSVersionString: 9.0
+ uIPrerenderedIcon: 0
+ uIRequiresPersistentWiFi: 0
+ uIRequiresFullScreen: 1
+ uIStatusBarHidden: 1
+ uIExitOnSuspend: 0
+ uIStatusBarStyle: 0
+ iPhoneSplashScreen: {fileID: 0}
+ iPhoneHighResSplashScreen: {fileID: 0}
+ iPhoneTallHighResSplashScreen: {fileID: 0}
+ iPhone47inSplashScreen: {fileID: 0}
+ iPhone55inPortraitSplashScreen: {fileID: 0}
+ iPhone55inLandscapeSplashScreen: {fileID: 0}
+ iPhone58inPortraitSplashScreen: {fileID: 0}
+ iPhone58inLandscapeSplashScreen: {fileID: 0}
+ iPadPortraitSplashScreen: {fileID: 0}
+ iPadHighResPortraitSplashScreen: {fileID: 0}
+ iPadLandscapeSplashScreen: {fileID: 0}
+ iPadHighResLandscapeSplashScreen: {fileID: 0}
+ appleTVSplashScreen: {fileID: 0}
+ appleTVSplashScreen2x: {fileID: 0}
+ tvOSSmallIconLayers: []
+ tvOSSmallIconLayers2x: []
+ tvOSLargeIconLayers: []
+ tvOSTopShelfImageLayers: []
+ tvOSTopShelfImageLayers2x: []
+ tvOSTopShelfImageWideLayers: []
+ tvOSTopShelfImageWideLayers2x: []
+ iOSLaunchScreenType: 0
+ iOSLaunchScreenPortrait: {fileID: 0}
+ iOSLaunchScreenLandscape: {fileID: 0}
+ iOSLaunchScreenBackgroundColor:
+ serializedVersion: 2
+ rgba: 0
+ iOSLaunchScreenFillPct: 100
+ iOSLaunchScreenSize: 100
+ iOSLaunchScreenCustomXibPath:
+ iOSLaunchScreeniPadType: 0
+ iOSLaunchScreeniPadImage: {fileID: 0}
+ iOSLaunchScreeniPadBackgroundColor:
+ serializedVersion: 2
+ rgba: 0
+ iOSLaunchScreeniPadFillPct: 100
+ iOSLaunchScreeniPadSize: 100
+ iOSLaunchScreeniPadCustomXibPath:
+ iOSUseLaunchScreenStoryboard: 0
+ iOSLaunchScreenCustomStoryboardPath:
+ iOSDeviceRequirements: []
+ iOSURLSchemes: []
+ iOSBackgroundModes: 0
+ iOSMetalForceHardShadows: 0
+ metalEditorSupport: 1
+ metalAPIValidation: 1
+ iOSRenderExtraFrameOnPause: 0
+ appleDeveloperTeamID:
+ iOSManualSigningProvisioningProfileID:
+ tvOSManualSigningProvisioningProfileID:
+ appleEnableAutomaticSigning: 0
+ clonedFromGUID: 00000000000000000000000000000000
+ AndroidTargetDevice: 0
+ AndroidSplashScreenScale: 0
+ androidSplashScreen: {fileID: 0}
+ AndroidKeystoreName:
+ AndroidKeyaliasName:
+ AndroidTVCompatibility: 1
+ AndroidIsGame: 1
+ AndroidEnableTango: 0
+ androidEnableBanner: 1
+ androidUseLowAccuracyLocation: 0
+ m_AndroidBanners:
+ - width: 320
+ height: 180
+ banner: {fileID: 0}
+ androidGamepadSupportLevel: 0
+ resolutionDialogBanner: {fileID: 0}
+ m_BuildTargetIcons: []
+ m_BuildTargetBatching: []
+ m_BuildTargetGraphicsAPIs: []
+ m_BuildTargetVRSettings: []
+ m_BuildTargetEnableVuforiaSettings: []
+ openGLRequireES31: 0
+ openGLRequireES31AEP: 0
+ m_TemplateCustomTags: {}
+ mobileMTRendering:
+ Android: 1
+ iPhone: 1
+ tvOS: 1
+ m_BuildTargetGroupLightmapEncodingQuality: []
+ wiiUTitleID: 0005000011000000
+ wiiUGroupID: 00010000
+ wiiUCommonSaveSize: 4096
+ wiiUAccountSaveSize: 2048
+ wiiUOlvAccessKey: 0
+ wiiUTinCode: 0
+ wiiUJoinGameId: 0
+ wiiUJoinGameModeMask: 0000000000000000
+ wiiUCommonBossSize: 0
+ wiiUAccountBossSize: 0
+ wiiUAddOnUniqueIDs: []
+ wiiUMainThreadStackSize: 3072
+ wiiULoaderThreadStackSize: 1024
+ wiiUSystemHeapSize: 128
+ wiiUTVStartupScreen: {fileID: 0}
+ wiiUGamePadStartupScreen: {fileID: 0}
+ wiiUDrcBufferDisabled: 0
+ wiiUProfilerLibPath:
+ playModeTestRunnerEnabled: 0
+ actionOnDotNetUnhandledException: 1
+ enableInternalProfiler: 0
+ logObjCUncaughtExceptions: 1
+ enableCrashReportAPI: 0
+ cameraUsageDescription:
+ locationUsageDescription:
+ microphoneUsageDescription:
+ switchNetLibKey:
+ switchSocketMemoryPoolSize: 6144
+ switchSocketAllocatorPoolSize: 128
+ switchSocketConcurrencyLimit: 14
+ switchScreenResolutionBehavior: 2
+ switchUseCPUProfiler: 0
+ switchApplicationID: 0x01004b9000490000
+ switchNSODependencies:
+ switchTitleNames_0:
+ switchTitleNames_1:
+ switchTitleNames_2:
+ switchTitleNames_3:
+ switchTitleNames_4:
+ switchTitleNames_5:
+ switchTitleNames_6:
+ switchTitleNames_7:
+ switchTitleNames_8:
+ switchTitleNames_9:
+ switchTitleNames_10:
+ switchTitleNames_11:
+ switchTitleNames_12:
+ switchTitleNames_13:
+ switchTitleNames_14:
+ switchPublisherNames_0:
+ switchPublisherNames_1:
+ switchPublisherNames_2:
+ switchPublisherNames_3:
+ switchPublisherNames_4:
+ switchPublisherNames_5:
+ switchPublisherNames_6:
+ switchPublisherNames_7:
+ switchPublisherNames_8:
+ switchPublisherNames_9:
+ switchPublisherNames_10:
+ switchPublisherNames_11:
+ switchPublisherNames_12:
+ switchPublisherNames_13:
+ switchPublisherNames_14:
+ switchIcons_0: {fileID: 0}
+ switchIcons_1: {fileID: 0}
+ switchIcons_2: {fileID: 0}
+ switchIcons_3: {fileID: 0}
+ switchIcons_4: {fileID: 0}
+ switchIcons_5: {fileID: 0}
+ switchIcons_6: {fileID: 0}
+ switchIcons_7: {fileID: 0}
+ switchIcons_8: {fileID: 0}
+ switchIcons_9: {fileID: 0}
+ switchIcons_10: {fileID: 0}
+ switchIcons_11: {fileID: 0}
+ switchIcons_12: {fileID: 0}
+ switchIcons_13: {fileID: 0}
+ switchIcons_14: {fileID: 0}
+ switchSmallIcons_0: {fileID: 0}
+ switchSmallIcons_1: {fileID: 0}
+ switchSmallIcons_2: {fileID: 0}
+ switchSmallIcons_3: {fileID: 0}
+ switchSmallIcons_4: {fileID: 0}
+ switchSmallIcons_5: {fileID: 0}
+ switchSmallIcons_6: {fileID: 0}
+ switchSmallIcons_7: {fileID: 0}
+ switchSmallIcons_8: {fileID: 0}
+ switchSmallIcons_9: {fileID: 0}
+ switchSmallIcons_10: {fileID: 0}
+ switchSmallIcons_11: {fileID: 0}
+ switchSmallIcons_12: {fileID: 0}
+ switchSmallIcons_13: {fileID: 0}
+ switchSmallIcons_14: {fileID: 0}
+ switchManualHTML:
+ switchAccessibleURLs:
+ switchLegalInformation:
+ switchMainThreadStackSize: 1048576
+ switchPresenceGroupId:
+ switchLogoHandling: 0
+ switchReleaseVersion: 0
+ switchDisplayVersion: 1.0.0
+ switchStartupUserAccount: 0
+ switchTouchScreenUsage: 0
+ switchSupportedLanguagesMask: 0
+ switchLogoType: 0
+ switchApplicationErrorCodeCategory:
+ switchUserAccountSaveDataSize: 0
+ switchUserAccountSaveDataJournalSize: 0
+ switchApplicationAttribute: 0
+ switchCardSpecSize: -1
+ switchCardSpecClock: -1
+ switchRatingsMask: 0
+ switchRatingsInt_0: 0
+ switchRatingsInt_1: 0
+ switchRatingsInt_2: 0
+ switchRatingsInt_3: 0
+ switchRatingsInt_4: 0
+ switchRatingsInt_5: 0
+ switchRatingsInt_6: 0
+ switchRatingsInt_7: 0
+ switchRatingsInt_8: 0
+ switchRatingsInt_9: 0
+ switchRatingsInt_10: 0
+ switchRatingsInt_11: 0
+ switchLocalCommunicationIds_0:
+ switchLocalCommunicationIds_1:
+ switchLocalCommunicationIds_2:
+ switchLocalCommunicationIds_3:
+ switchLocalCommunicationIds_4:
+ switchLocalCommunicationIds_5:
+ switchLocalCommunicationIds_6:
+ switchLocalCommunicationIds_7:
+ switchParentalControl: 0
+ switchAllowsScreenshot: 1
+ switchAllowsVideoCapturing: 1
+ switchAllowsRuntimeAddOnContentInstall: 0
+ switchDataLossConfirmation: 0
+ switchSupportedNpadStyles: 3
+ switchSocketConfigEnabled: 0
+ switchTcpInitialSendBufferSize: 32
+ switchTcpInitialReceiveBufferSize: 64
+ switchTcpAutoSendBufferSizeMax: 256
+ switchTcpAutoReceiveBufferSizeMax: 256
+ switchUdpSendBufferSize: 9
+ switchUdpReceiveBufferSize: 42
+ switchSocketBufferEfficiency: 4
+ switchSocketInitializeEnabled: 1
+ switchNetworkInterfaceManagerInitializeEnabled: 1
+ switchPlayerConnectionEnabled: 1
+ ps4NPAgeRating: 12
+ ps4NPTitleSecret:
+ ps4NPTrophyPackPath:
+ ps4ParentalLevel: 11
+ ps4ContentID: ED1633-NPXX51362_00-0000000000000000
+ ps4Category: 0
+ ps4MasterVersion: 01.00
+ ps4AppVersion: 01.00
+ ps4AppType: 0
+ ps4ParamSfxPath:
+ ps4VideoOutPixelFormat: 0
+ ps4VideoOutInitialWidth: 1920
+ ps4VideoOutBaseModeInitialWidth: 1920
+ ps4VideoOutReprojectionRate: 60
+ ps4PronunciationXMLPath:
+ ps4PronunciationSIGPath:
+ ps4BackgroundImagePath:
+ ps4StartupImagePath:
+ ps4StartupImagesFolder:
+ ps4IconImagesFolder:
+ ps4SaveDataImagePath:
+ ps4SdkOverride:
+ ps4BGMPath:
+ ps4ShareFilePath:
+ ps4ShareOverlayImagePath:
+ ps4PrivacyGuardImagePath:
+ ps4NPtitleDatPath:
+ ps4RemotePlayKeyAssignment: -1
+ ps4RemotePlayKeyMappingDir:
+ ps4PlayTogetherPlayerCount: 0
+ ps4EnterButtonAssignment: 1
+ ps4ApplicationParam1: 0
+ ps4ApplicationParam2: 0
+ ps4ApplicationParam3: 0
+ ps4ApplicationParam4: 0
+ ps4DownloadDataSize: 0
+ ps4GarlicHeapSize: 2048
+ ps4ProGarlicHeapSize: 2560
+ ps4Passcode: d3hjjul8UhK6ZnQCEBYYQPozR9sQV066
+ ps4pnSessions: 1
+ ps4pnPresence: 1
+ ps4pnFriends: 1
+ ps4pnGameCustomData: 1
+ playerPrefsSupport: 0
+ restrictedAudioUsageRights: 0
+ ps4UseResolutionFallback: 0
+ ps4ReprojectionSupport: 0
+ ps4UseAudio3dBackend: 0
+ ps4SocialScreenEnabled: 0
+ ps4ScriptOptimizationLevel: 0
+ ps4Audio3dVirtualSpeakerCount: 14
+ ps4attribCpuUsage: 0
+ ps4PatchPkgPath:
+ ps4PatchLatestPkgPath:
+ ps4PatchChangeinfoPath:
+ ps4PatchDayOne: 0
+ ps4attribUserManagement: 0
+ ps4attribMoveSupport: 0
+ ps4attrib3DSupport: 0
+ ps4attribShareSupport: 0
+ ps4attribExclusiveVR: 0
+ ps4disableAutoHideSplash: 0
+ ps4videoRecordingFeaturesUsed: 0
+ ps4contentSearchFeaturesUsed: 0
+ ps4attribEyeToEyeDistanceSettingVR: 0
+ ps4IncludedModules: []
+ monoEnv:
+ psp2Splashimage: {fileID: 0}
+ psp2NPTrophyPackPath:
+ psp2NPSupportGBMorGJP: 0
+ psp2NPAgeRating: 12
+ psp2NPTitleDatPath:
+ psp2NPCommsID:
+ psp2NPCommunicationsID:
+ psp2NPCommsPassphrase:
+ psp2NPCommsSig:
+ psp2ParamSfxPath:
+ psp2ManualPath:
+ psp2LiveAreaGatePath:
+ psp2LiveAreaBackroundPath:
+ psp2LiveAreaPath:
+ psp2LiveAreaTrialPath:
+ psp2PatchChangeInfoPath:
+ psp2PatchOriginalPackage:
+ psp2PackagePassword: 3onkgZsAECEn0fzCoWiCtWCKe4l74pE5
+ psp2KeystoneFile:
+ psp2MemoryExpansionMode: 0
+ psp2DRMType: 0
+ psp2StorageType: 0
+ psp2MediaCapacity: 0
+ psp2DLCConfigPath:
+ psp2ThumbnailPath:
+ psp2BackgroundPath:
+ psp2SoundPath:
+ psp2TrophyCommId:
+ psp2TrophyPackagePath:
+ psp2PackagedResourcesPath:
+ psp2SaveDataQuota: 10240
+ psp2ParentalLevel: 1
+ psp2ShortTitle: Not Set
+ psp2ContentID: IV0000-ABCD12345_00-0123456789ABCDEF
+ psp2Category: 0
+ psp2MasterVersion: 01.00
+ psp2AppVersion: 01.00
+ psp2TVBootMode: 0
+ psp2EnterButtonAssignment: 2
+ psp2TVDisableEmu: 0
+ psp2AllowTwitterDialog: 1
+ psp2Upgradable: 0
+ psp2HealthWarning: 0
+ psp2UseLibLocation: 0
+ psp2InfoBarOnStartup: 0
+ psp2InfoBarColor: 0
+ psp2ScriptOptimizationLevel: 0
+ psmSplashimage: {fileID: 0}
+ splashScreenBackgroundSourceLandscape: {fileID: 0}
+ splashScreenBackgroundSourcePortrait: {fileID: 0}
+ spritePackerPolicy:
+ webGLMemorySize: 256
+ webGLExceptionSupport: 1
+ webGLNameFilesAsHashes: 0
+ webGLDataCaching: 0
+ webGLDebugSymbols: 0
+ webGLEmscriptenArgs:
+ webGLModulesDirectory:
+ webGLTemplate: APPLICATION:Default
+ webGLAnalyzeBuildSize: 0
+ webGLUseEmbeddedResources: 0
+ webGLUseWasm: 0
+ webGLCompressionFormat: 1
+ scriptingDefineSymbols: {}
+ platformArchitecture: {}
+ scriptingBackend: {}
+ incrementalIl2cppBuild: {}
+ additionalIl2CppArgs:
+ scriptingRuntimeVersion: 0
+ apiCompatibilityLevelPerPlatform: {}
+ m_RenderingPath: 1
+ m_MobileRenderingPath: 1
+ metroPackageName: TensorFlowLitePlugin
+ metroPackageVersion:
+ metroCertificatePath:
+ metroCertificatePassword:
+ metroCertificateSubject:
+ metroCertificateIssuer:
+ metroCertificateNotAfter: 0000000000000000
+ metroApplicationDescription: TensorFlowLitePlugin
+ wsaImages: {}
+ metroTileShortName:
+ metroCommandLineArgsFile:
+ metroTileShowName: 0
+ metroMediumTileShowName: 0
+ metroLargeTileShowName: 0
+ metroWideTileShowName: 0
+ metroDefaultTileSize: 1
+ metroTileForegroundText: 2
+ metroTileBackgroundColor: {r: 0.13333334, g: 0.17254902, b: 0.21568628, a: 0}
+ metroSplashScreenBackgroundColor: {r: 0.12941177, g: 0.17254902, b: 0.21568628,
+ a: 1}
+ metroSplashScreenUseBackgroundColor: 0
+ platformCapabilities: {}
+ metroFTAName:
+ metroFTAFileTypes: []
+ metroProtocolName:
+ metroCompilationOverrides: 1
+ tizenProductDescription:
+ tizenProductURL:
+ tizenSigningProfileName:
+ tizenGPSPermissions: 0
+ tizenMicrophonePermissions: 0
+ tizenDeploymentTarget:
+ tizenDeploymentTargetType: -1
+ tizenMinOSVersion: 1
+ n3dsUseExtSaveData: 0
+ n3dsCompressStaticMem: 1
+ n3dsExtSaveDataNumber: 0x12345
+ n3dsStackSize: 131072
+ n3dsTargetPlatform: 2
+ n3dsRegion: 7
+ n3dsMediaSize: 0
+ n3dsLogoStyle: 3
+ n3dsTitle: GameName
+ n3dsProductCode:
+ n3dsApplicationId: 0xFF3FF
+ XboxOneProductId:
+ XboxOneUpdateKey:
+ XboxOneSandboxId:
+ XboxOneContentId:
+ XboxOneTitleId:
+ XboxOneSCId:
+ XboxOneGameOsOverridePath:
+ XboxOnePackagingOverridePath:
+ XboxOneAppManifestOverridePath:
+ XboxOnePackageEncryption: 0
+ XboxOnePackageUpdateGranularity: 2
+ XboxOneDescription:
+ XboxOneLanguage:
+ - enus
+ XboxOneCapability: []
+ XboxOneGameRating: {}
+ XboxOneIsContentPackage: 0
+ XboxOneEnableGPUVariability: 0
+ XboxOneSockets: {}
+ XboxOneSplashScreen: {fileID: 0}
+ XboxOneAllowedProductIds: []
+ XboxOnePersistentLocalStorageSize: 0
+ XboxOneXTitleMemory: 8
+ xboxOneScriptCompiler: 0
+ vrEditorSettings:
+ daydream:
+ daydreamIconForeground: {fileID: 0}
+ daydreamIconBackground: {fileID: 0}
+ cloudServicesEnabled: {}
+ facebookSdkVersion: 7.9.4
+ apiCompatibilityLevel: 2
+ cloudProjectId:
+ projectName:
+ organizationId:
+ cloudEnabled: 0
+ enableNativePlatformBackendsForNewInputSystem: 0
+ disableOldInputManagerSupport: 0
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ProjectVersion.txt b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ProjectVersion.txt
new file mode 100644
index 0000000000..4a9cfb61ab
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ProjectVersion.txt
@@ -0,0 +1 @@
+m_EditorVersion: 2017.4.6f1
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/QualitySettings.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/QualitySettings.asset
new file mode 100644
index 0000000000..05daac3c49
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/QualitySettings.asset
@@ -0,0 +1,191 @@
+%YAML 1.1
+%TAG !u! tag:unity3d.com,2011:
+--- !u!47 &1
+QualitySettings:
+ m_ObjectHideFlags: 0
+ serializedVersion: 5
+ m_CurrentQuality: 5
+ m_QualitySettings:
+ - serializedVersion: 2
+ name: Very Low
+ pixelLightCount: 0
+ shadows: 0
+ shadowResolution: 0
+ shadowProjection: 1
+ shadowCascades: 1
+ shadowDistance: 15
+ shadowNearPlaneOffset: 3
+ shadowCascade2Split: 0.33333334
+ shadowCascade4Split: {x: 0.06666667, y: 0.2, z: 0.46666667}
+ shadowmaskMode: 0
+ blendWeights: 1
+ textureQuality: 1
+ anisotropicTextures: 0
+ antiAliasing: 0
+ softParticles: 0
+ softVegetation: 0
+ realtimeReflectionProbes: 0
+ billboardsFaceCameraPosition: 0
+ vSyncCount: 0
+ lodBias: 0.3
+ maximumLODLevel: 0
+ particleRaycastBudget: 4
+ asyncUploadTimeSlice: 2
+ asyncUploadBufferSize: 4
+ resolutionScalingFixedDPIFactor: 1
+ excludedTargetPlatforms: []
+ - serializedVersion: 2
+ name: Low
+ pixelLightCount: 0
+ shadows: 0
+ shadowResolution: 0
+ shadowProjection: 1
+ shadowCascades: 1
+ shadowDistance: 20
+ shadowNearPlaneOffset: 3
+ shadowCascade2Split: 0.33333334
+ shadowCascade4Split: {x: 0.06666667, y: 0.2, z: 0.46666667}
+ shadowmaskMode: 0
+ blendWeights: 2
+ textureQuality: 0
+ anisotropicTextures: 0
+ antiAliasing: 0
+ softParticles: 0
+ softVegetation: 0
+ realtimeReflectionProbes: 0
+ billboardsFaceCameraPosition: 0
+ vSyncCount: 0
+ lodBias: 0.4
+ maximumLODLevel: 0
+ particleRaycastBudget: 16
+ asyncUploadTimeSlice: 2
+ asyncUploadBufferSize: 4
+ resolutionScalingFixedDPIFactor: 1
+ excludedTargetPlatforms: []
+ - serializedVersion: 2
+ name: Medium
+ pixelLightCount: 1
+ shadows: 1
+ shadowResolution: 0
+ shadowProjection: 1
+ shadowCascades: 1
+ shadowDistance: 20
+ shadowNearPlaneOffset: 3
+ shadowCascade2Split: 0.33333334
+ shadowCascade4Split: {x: 0.06666667, y: 0.2, z: 0.46666667}
+ shadowmaskMode: 0
+ blendWeights: 2
+ textureQuality: 0
+ anisotropicTextures: 1
+ antiAliasing: 0
+ softParticles: 0
+ softVegetation: 0
+ realtimeReflectionProbes: 0
+ billboardsFaceCameraPosition: 0
+ vSyncCount: 1
+ lodBias: 0.7
+ maximumLODLevel: 0
+ particleRaycastBudget: 64
+ asyncUploadTimeSlice: 2
+ asyncUploadBufferSize: 4
+ resolutionScalingFixedDPIFactor: 1
+ excludedTargetPlatforms: []
+ - serializedVersion: 2
+ name: High
+ pixelLightCount: 2
+ shadows: 2
+ shadowResolution: 1
+ shadowProjection: 1
+ shadowCascades: 2
+ shadowDistance: 40
+ shadowNearPlaneOffset: 3
+ shadowCascade2Split: 0.33333334
+ shadowCascade4Split: {x: 0.06666667, y: 0.2, z: 0.46666667}
+ shadowmaskMode: 1
+ blendWeights: 2
+ textureQuality: 0
+ anisotropicTextures: 1
+ antiAliasing: 0
+ softParticles: 0
+ softVegetation: 1
+ realtimeReflectionProbes: 1
+ billboardsFaceCameraPosition: 1
+ vSyncCount: 1
+ lodBias: 1
+ maximumLODLevel: 0
+ particleRaycastBudget: 256
+ asyncUploadTimeSlice: 2
+ asyncUploadBufferSize: 4
+ resolutionScalingFixedDPIFactor: 1
+ excludedTargetPlatforms: []
+ - serializedVersion: 2
+ name: Very High
+ pixelLightCount: 3
+ shadows: 2
+ shadowResolution: 2
+ shadowProjection: 1
+ shadowCascades: 2
+ shadowDistance: 70
+ shadowNearPlaneOffset: 3
+ shadowCascade2Split: 0.33333334
+ shadowCascade4Split: {x: 0.06666667, y: 0.2, z: 0.46666667}
+ shadowmaskMode: 1
+ blendWeights: 4
+ textureQuality: 0
+ anisotropicTextures: 2
+ antiAliasing: 2
+ softParticles: 1
+ softVegetation: 1
+ realtimeReflectionProbes: 1
+ billboardsFaceCameraPosition: 1
+ vSyncCount: 1
+ lodBias: 1.5
+ maximumLODLevel: 0
+ particleRaycastBudget: 1024
+ asyncUploadTimeSlice: 2
+ asyncUploadBufferSize: 4
+ resolutionScalingFixedDPIFactor: 1
+ excludedTargetPlatforms: []
+ - serializedVersion: 2
+ name: Ultra
+ pixelLightCount: 4
+ shadows: 2
+ shadowResolution: 2
+ shadowProjection: 1
+ shadowCascades: 4
+ shadowDistance: 150
+ shadowNearPlaneOffset: 3
+ shadowCascade2Split: 0.33333334
+ shadowCascade4Split: {x: 0.06666667, y: 0.2, z: 0.46666667}
+ shadowmaskMode: 1
+ blendWeights: 4
+ textureQuality: 0
+ anisotropicTextures: 2
+ antiAliasing: 2
+ softParticles: 1
+ softVegetation: 1
+ realtimeReflectionProbes: 1
+ billboardsFaceCameraPosition: 1
+ vSyncCount: 1
+ lodBias: 2
+ maximumLODLevel: 0
+ particleRaycastBudget: 4096
+ asyncUploadTimeSlice: 2
+ asyncUploadBufferSize: 4
+ resolutionScalingFixedDPIFactor: 1
+ excludedTargetPlatforms: []
+ m_PerPlatformDefaultQuality:
+ Android: 2
+ Nintendo 3DS: 5
+ Nintendo Switch: 5
+ PS4: 5
+ PSM: 5
+ PSP2: 2
+ Standalone: 5
+ Tizen: 2
+ WebGL: 3
+ WiiU: 5
+ Windows Store Apps: 5
+ XboxOne: 5
+ iPhone: 2
+ tvOS: 2
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/TagManager.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/TagManager.asset
new file mode 100644
index 0000000000..1c92a7840e
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/TagManager.asset
@@ -0,0 +1,43 @@
+%YAML 1.1
+%TAG !u! tag:unity3d.com,2011:
+--- !u!78 &1
+TagManager:
+ serializedVersion: 2
+ tags: []
+ layers:
+ - Default
+ - TransparentFX
+ - Ignore Raycast
+ -
+ - Water
+ - UI
+ -
+ -
+ -
+ -
+ -
+ -
+ -
+ -
+ -
+ -
+ -
+ -
+ -
+ -
+ -
+ -
+ -
+ -
+ -
+ -
+ -
+ -
+ -
+ -
+ -
+ -
+ m_SortingLayers:
+ - name: Default
+ uniqueID: 0
+ locked: 0
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/TimeManager.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/TimeManager.asset
new file mode 100644
index 0000000000..558a017e1f
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/TimeManager.asset
@@ -0,0 +1,9 @@
+%YAML 1.1
+%TAG !u! tag:unity3d.com,2011:
+--- !u!5 &1
+TimeManager:
+ m_ObjectHideFlags: 0
+ Fixed Timestep: 0.02
+ Maximum Allowed Timestep: 0.33333334
+ m_TimeScale: 1
+ Maximum Particle Timestep: 0.03
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/UnityConnectSettings.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/UnityConnectSettings.asset
new file mode 100644
index 0000000000..3da14d5baf
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/UnityConnectSettings.asset
@@ -0,0 +1,34 @@
+%YAML 1.1
+%TAG !u! tag:unity3d.com,2011:
+--- !u!310 &1
+UnityConnectSettings:
+ m_ObjectHideFlags: 0
+ m_Enabled: 0
+ m_TestMode: 0
+ m_TestEventUrl:
+ m_TestConfigUrl:
+ m_TestInitMode: 0
+ CrashReportingSettings:
+ m_EventUrl: https://perf-events.cloud.unity3d.com/api/events/crashes
+ m_NativeEventUrl: https://perf-events.cloud.unity3d.com/symbolicate
+ m_Enabled: 0
+ m_CaptureEditorExceptions: 1
+ UnityPurchasingSettings:
+ m_Enabled: 0
+ m_TestMode: 0
+ UnityAnalyticsSettings:
+ m_Enabled: 0
+ m_InitializeOnStartup: 1
+ m_TestMode: 0
+ m_TestEventUrl:
+ m_TestConfigUrl:
+ UnityAdsSettings:
+ m_Enabled: 0
+ m_InitializeOnStartup: 1
+ m_TestMode: 0
+ m_IosGameId:
+ m_AndroidGameId:
+ m_GameIds: {}
+ m_GameId:
+ PerformanceReportingSettings:
+ m_Enabled: 0
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/README.md b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/README.md
new file mode 100644
index 0000000000..0b3813fccb
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/README.md
@@ -0,0 +1,24 @@
+# TF Lite Experimental Unity Plugin
+
+This directoryy contains an experimental sample Unity (2017) Plugin, based on
+the experimental TF Lite C API. The sample demonstrates running inference within
+Unity by way of a C# `Interpreter` wrapper.
+
+Note that the native TF Lite plugin(s) *must* be built before using the Unity
+Plugin, and placed in Assets/TensorFlowLite/SDK/Plugins/. For the editor (note
+that this has only been tested on Linux; the syntax may differ on Mac/Windows):
+
+```sh
+bazel build -c opt --cxxopt=--std=c++11 \
+ //tensorflow/contrib/lite/experimental/c:libtensorflowlite_c.so
+```
+
+and for Android:
+
+```sh
+bazel build -c opt --cxxopt=--std=c++11 \
+ --crosstool_top=//external:android/crosstool \
+ --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
+ --cpu=armeabi-v7a \
+ //tensorflow/contrib/lite/experimental/c:libtensorflowlite_c.so
+```
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/UnityPackageManager/manifest.json b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/UnityPackageManager/manifest.json
new file mode 100644
index 0000000000..526aca6057
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/UnityPackageManager/manifest.json
@@ -0,0 +1,4 @@
+{
+ "dependencies": {
+ }
+}
diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
index 0e8f4339fc..aa65ec9988 100644
--- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
+++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
@@ -62,6 +62,7 @@ counterparts:
* [tf.nn.softmax](https://www.tensorflow.org/api_docs/python/tf/nn/softmax) -
*as long as tensors are 2D and axis is the last dimension*
* [tf.nn.top_k](https://www.tensorflow.org/api_docs/python/tf/nn/top_k)
+* [tf.one_hot](https://www.tensorflow.org/api_docs/python/tf/one_hot)
* [tf.pad](https://www.tensorflow.org/api_docs/python/tf/pad) - *as long as
mode and constant_values are not used*
* [tf.reduce_mean](https://www.tensorflow.org/api_docs/python/tf/reduce_mean) -
@@ -830,6 +831,18 @@ Outputs {
}
```
+**LOGICAL_OR**
+
+```
+Inputs {
+ 0: a list of tensors.
+ 1: a list of tensors.
+}
+Outputs {
+ 0: A tensor of logical_or output tensors.
+}
+```
+
And these are TensorFlow Lite operations that are present but not ready for
custom models yet:
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index c224132cae..329c98f91e 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -170,12 +170,14 @@ cc_library(
"hashtable_lookup.cc",
"l2norm.cc",
"local_response_norm.cc",
+ "logical.cc",
"lsh_projection.cc",
"lstm.cc",
"maximum_minimum.cc",
"mfcc.cc",
"mul.cc",
"neg.cc",
+ "one_hot.cc",
"pack.cc",
"pad.cc",
"pooling.cc",
@@ -1171,6 +1173,33 @@ tf_cc_test(
],
)
+tf_cc_test(
+ name = "one_hot_test",
+ size = "small",
+ srcs = ["one_hot_test.cc"],
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "logical_test",
+ size = "small",
+ srcs = ["logical_test.cc"],
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
index 45c9f65b64..63c89d1eee 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
@@ -115,10 +115,10 @@ void ClipVector(const float* vector, int v_size, float abs_limit,
}
void SymmetricQuantizeFloats(const float* values, const int size,
- int8_t* quantized_values, float* min, float* max,
- float* scaling_factor) {
- NEON_OR_PORTABLE(SymmetricQuantizeFloats, values, size, quantized_values, min,
- max, scaling_factor);
+ int8_t* quantized_values, float* min_value,
+ float* max_value, float* scaling_factor) {
+ NEON_OR_PORTABLE(SymmetricQuantizeFloats, values, size, quantized_values,
+ min_value, max_value, scaling_factor);
}
void VectorShiftLeft(float* vector, int v_size, float shift_value) {
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
index db7926df9a..010b40b901 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
@@ -19,6 +19,10 @@ limitations under the License.
// structure.
#include "tensorflow/contrib/lite/builtin_op_data.h"
+#if defined(_MSC_VER)
+#define __restrict__ __restrict
+#endif
+
#ifndef USE_NEON
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#define USE_NEON
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
index 7ead449ca8..6bd88b5596 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include <stdlib.h>
#include <string.h>
+#include <algorithm>
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
@@ -38,14 +39,14 @@ bool PortableIsZeroVector(const float* vector, int v_size) {
void PortableSymmetricQuantizeFloats(const float* values, const int size,
int8_t* quantized_values,
- float* __restrict__ min,
- float* __restrict__ max,
+ float* __restrict__ min_value,
+ float* __restrict__ max_value,
float* __restrict__ scaling_factor) {
auto minmax = std::minmax_element(values, values + size);
- *min = *minmax.first;
- *max = *minmax.second;
+ *min_value = *minmax.first;
+ *max_value = *minmax.second;
const int kScale = 127;
- const float range = std::max(std::abs(*min), std::abs(*max));
+ const float range = std::max(std::abs(*min_value), std::abs(*max_value));
if (range == 0) {
memset(quantized_values, 0, size * sizeof(int8_t));
*scaling_factor = 1;
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
index d3a4fa8507..a375aaffa6 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
@@ -19,6 +19,10 @@ limitations under the License.
// structure.
#include "tensorflow/contrib/lite/builtin_op_data.h"
+#if defined(_MSC_VER)
+#define __restrict__ __restrict
+#endif
+
namespace tflite {
namespace tensor_utils {
@@ -28,8 +32,8 @@ float PortableClip(float f, float abs_limit);
bool PortableIsZeroVector(const float* vector, int v_size);
void PortableSymmetricQuantizeFloats(const float* values, const int size,
- int8_t* quantized_values, float* min,
- float* max, float* scaling_factor);
+ int8_t* quantized_values, float* min_value,
+ float* max_value, float* scaling_factor);
// Multiply a matrix by a batch vector, and store results in a batch-size
// vector.
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 31a54c2b62..714613b96e 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -3284,7 +3284,8 @@ inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
const Dims<4>& block_shape_dims,
const int32* paddings_data,
const Dims<4>& paddings_dims, T* output_data,
- const Dims<4>& output_dims) {
+ const Dims<4>& output_dims,
+ const int32_t pad_value) {
const int output_batch_size = ArraySize(output_dims, 3);
const int output_height = ArraySize(output_dims, 2);
const int output_width = ArraySize(output_dims, 1);
@@ -3309,7 +3310,7 @@ inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
padding_top + input_height ||
out_w * block_shape_width + shift_w < padding_left ||
out_w * block_shape_width + shift_w >= padding_left + input_width) {
- memset(out, 0, depth * sizeof(T));
+ memset(out, pad_value, depth * sizeof(T));
} else {
const T* in =
input_data +
@@ -3325,6 +3326,17 @@ inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
}
template <typename T>
+inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
+ const int32* block_shape_data,
+ const Dims<4>& block_shape_dims,
+ const int32* paddings_data,
+ const Dims<4>& paddings_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ SpaceToBatchND(input_data, input_dims, block_shape_data, block_shape_dims,
+ paddings_data, paddings_dims, output_data, output_dims, 0);
+}
+
+template <typename T>
inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
const int32* block_shape_data,
const Dims<4>& block_shape_dims,
@@ -4243,6 +4255,38 @@ inline void BroadcastPow(const T* input1_data, const Dims<4>& input1_dims,
}
}
+inline void Logical(const bool* input1_data, const Dims<4>& input1_dims,
+ const bool* input2_data, const Dims<4>& input2_dims,
+ bool* output_data, const Dims<4>& output_dims,
+ const std::function<bool(bool, bool)>& func) {
+ const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = func(input1_data[i], input2_data[i]);
+ }
+}
+
+inline void BroadcastLogical(const bool* input1_data,
+ const Dims<4>& input1_dims,
+ const bool* input2_data,
+ const Dims<4>& input2_dims, bool* output_data,
+ const Dims<4>& output_dims,
+ const std::function<bool(bool, bool)>& func) {
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
+ for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
+ for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
+ for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] =
+ func(input1_data[SubscriptToIndex(desc1, c, x, y, b)],
+ input2_data[SubscriptToIndex(desc2, c, x, y, b)]);
+ }
+ }
+ }
+ }
+}
+
} // namespace reference_ops
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/spectrogram.cc b/tensorflow/contrib/lite/kernels/internal/spectrogram.cc
index 4eddf7bf0a..20abcb7258 100644
--- a/tensorflow/contrib/lite/kernels/internal/spectrogram.cc
+++ b/tensorflow/contrib/lite/kernels/internal/spectrogram.cc
@@ -43,13 +43,13 @@ bool Spectrogram::Initialize(int window_length, int step_length) {
return Initialize(window, step_length);
}
-inline int Log2Floor(uint n) {
+inline int Log2Floor(uint32_t n) {
if (n == 0) return -1;
int log = 0;
- uint value = n;
+ uint32_t value = n;
for (int i = 4; i >= 0; --i) {
int shift = (1 << i);
- uint x = value >> shift;
+ uint32_t x = value >> shift;
if (x != 0) {
value = x;
log += shift;
@@ -58,7 +58,7 @@ inline int Log2Floor(uint n) {
return log;
}
-inline int Log2Ceiling(uint n) {
+inline int Log2Ceiling(uint32_t n) {
int floor = Log2Floor(n);
if (n == (n & ~(n - 1))) // zero or a power of two
return floor;
@@ -66,7 +66,7 @@ inline int Log2Ceiling(uint n) {
return floor + 1;
}
-inline uint NextPowerOfTwo(uint value) {
+inline uint32_t NextPowerOfTwo(uint32_t value) {
int exponent = Log2Ceiling(value);
// DCHECK_LT(exponent, std::numeric_limits<uint32>::digits);
return 1 << exponent;
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
index 82f4503127..1ff8cfe39c 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
@@ -17,6 +17,10 @@ limitations under the License.
#include "tensorflow/contrib/lite/builtin_op_data.h"
+#if defined(_MSC_VER)
+#define __restrict__ __restrict
+#endif
+
namespace tflite {
namespace tensor_utils {
@@ -31,8 +35,8 @@ bool IsZeroVector(const float* vector, int v_size);
// It also outputs the range (min, max) of the floating point buffer, and the
// scaling factor used to quantize the values.
void SymmetricQuantizeFloats(const float* values, const int size,
- int8_t* quantized_values, float* min, float* max,
- float* scaling_factor);
+ int8_t* quantized_values, float* min_value,
+ float* max_value, float* scaling_factor);
// Multiplies a matrix by a "batched" vector (i.e. a matrix with a batch
// dimension composed by input vectors independent from each other). The result
diff --git a/tensorflow/contrib/lite/kernels/logical.cc b/tensorflow/contrib/lite/kernels/logical.cc
new file mode 100644
index 0000000000..3dc39bf79a
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/logical.cc
@@ -0,0 +1,121 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace logical {
+namespace {
+
+// Input/output tensor index.
+constexpr int kInputTensor1 = 0;
+constexpr int kInputTensor2 = 1;
+constexpr int kOutputTensor = 0;
+
+// Op data for logical op.
+struct OpData {
+ bool requires_broadcast;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* data = new OpData;
+ data->requires_broadcast = false;
+ return data;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ // Reinterprete the opaque data provided by user.
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
+
+ const TfLiteType type = input1->type;
+ if (type != kTfLiteBool) {
+ context->ReportError(context, "Logical ops only support bool type.");
+ return kTfLiteError;
+ }
+ output->type = type;
+
+ data->requires_broadcast = !HaveSameShapes(input1, input2);
+
+ TfLiteIntArray* output_size = nullptr;
+ if (data->requires_broadcast) {
+ TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
+ context, input1, input2, &output_size));
+ } else {
+ output_size = TfLiteIntArrayCopy(input1->dims);
+ }
+
+ return context->ResizeTensor(context, output, output_size);
+}
+
+TfLiteStatus LogicalImpl(TfLiteContext* context, TfLiteNode* node,
+ const std::function<bool(bool, bool)>& func) {
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ if (data->requires_broadcast) {
+ reference_ops::BroadcastLogical(
+ GetTensorData<bool>(input1), GetTensorDims(input1),
+ GetTensorData<bool>(input2), GetTensorDims(input2),
+ GetTensorData<bool>(output), GetTensorDims(output), func);
+ } else {
+ reference_ops::Logical(GetTensorData<bool>(input1), GetTensorDims(input1),
+ GetTensorData<bool>(input2), GetTensorDims(input2),
+ GetTensorData<bool>(output), GetTensorDims(output),
+ func);
+ }
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus LogicalOrEval(TfLiteContext* context, TfLiteNode* node) {
+ const auto logical_or_func = std::logical_or<bool>();
+ return LogicalImpl(context, node, logical_or_func);
+}
+
+} // namespace
+} // namespace logical
+
+TfLiteRegistration* Register_LOGICAL_OR() {
+ // Init, Free, Prepare, Eval are satisfying the Interface required by
+ // TfLiteRegistration.
+ static TfLiteRegistration r = {logical::Init, logical::Free, logical::Prepare,
+ logical::LogicalOrEval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/logical_test.cc b/tensorflow/contrib/lite/kernels/logical_test.cc
new file mode 100644
index 0000000000..382008245b
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/logical_test.cc
@@ -0,0 +1,87 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAre;
+
+class LogicalOpModel : public SingleOpModel {
+ public:
+ LogicalOpModel(std::initializer_list<int> input1_shape,
+ std::initializer_list<int> input2_shape, BuiltinOperator op) {
+ input1_ = AddInput(TensorType_BOOL);
+ input2_ = AddInput(TensorType_BOOL);
+ output_ = AddOutput(TensorType_BOOL);
+ ConfigureBuiltinOp(op);
+ BuildInterpreter({input1_shape, input2_shape});
+ }
+
+ int input1() { return input1_; }
+ int input2() { return input2_; }
+
+ std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input1_;
+ int input2_;
+ int output_;
+
+ void ConfigureBuiltinOp(BuiltinOperator op) {
+ switch (op) {
+ case BuiltinOperator_LOGICAL_OR: {
+ SetBuiltinOp(op, BuiltinOptions_LogicalOrOptions,
+ CreateLogicalOrOptions(builder_).Union());
+ break;
+ }
+ default: { FAIL() << "We shouldn't get here."; }
+ }
+ }
+};
+
+TEST(LogicalTest, LogicalOr) {
+ LogicalOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, BuiltinOperator_LOGICAL_OR);
+ model.PopulateTensor<bool>(model.input1(), {true, false, false, true});
+ model.PopulateTensor<bool>(model.input2(), {true, false, true, false});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, true, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
+
+TEST(LogicalTest, BroadcastLogicalOr) {
+ LogicalOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, BuiltinOperator_LOGICAL_OR);
+ model.PopulateTensor<bool>(model.input1(), {true, false, false, true});
+ model.PopulateTensor<bool>(model.input2(), {false});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/one_hot.cc b/tensorflow/contrib/lite/kernels/one_hot.cc
new file mode 100644
index 0000000000..9ff3dca932
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/one_hot.cc
@@ -0,0 +1,199 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace one_hot {
+
+constexpr int kIndicesTensor = 0;
+constexpr int kDepthTensor = 1;
+constexpr int kOnValueTensor = 2;
+constexpr int kOffValueTensor = 3;
+constexpr int kOutputTensor = 0;
+
+// Convenience utility for destructuring a node into the appropriate tensors and
+// data for the op. Note that this destructuring is quite cheap, so we can avoid
+// allocating op-specific, persistent data on the heap.
+struct OneHotContext {
+ OneHotContext(TfLiteContext* context, TfLiteNode* node) {
+ indices = GetInput(context, node, kIndicesTensor);
+ depth = GetInput(context, node, kDepthTensor);
+ on_value = GetInput(context, node, kOnValueTensor);
+ off_value = GetInput(context, node, kOffValueTensor);
+ output = GetOutput(context, node, kOutputTensor);
+
+ const auto* params =
+ reinterpret_cast<TfLiteOneHotParams*>(node->builtin_data);
+ const int indices_dims = indices->dims->size;
+ axis = (params->axis == -1) ? indices_dims : params->axis;
+ output_dims = indices_dims + 1;
+ dtype = on_value->type;
+ }
+
+ const TfLiteTensor* indices;
+ const TfLiteTensor* depth;
+ const TfLiteTensor* on_value;
+ const TfLiteTensor* off_value;
+ TfLiteTensor* output;
+ int axis;
+ int output_dims;
+ TfLiteType dtype;
+};
+
+template <typename T, typename TI>
+void OneHotComputeImpl(const OneHotContext& op_context) {
+ // prefix_dim_size == # of elements before the axis
+ // depth == # of elements per axis
+ // suffix_dim_size == # of elements after the axis
+ int prefix_dim_size = 1;
+ for (int i = 0; i < op_context.axis; ++i) {
+ prefix_dim_size *= op_context.indices->dims->data[i];
+ }
+ const int suffix_dim_size = NumElements(op_context.indices) / prefix_dim_size;
+ const int depth = *op_context.depth->data.i32;
+
+ const T on_value = *GetTensorData<T>(op_context.on_value);
+ const T off_value = *GetTensorData<T>(op_context.off_value);
+
+ // View the indices as a matrix of size:
+ // prefix_dim_size x suffix_dim_size
+ // View the output as a matrix of size:
+ // prefix_dim_size x depth x suffix_dim_size
+ // Then the output is:
+ // output(i, j, k) == (indices(i, k) == j) ? on : off
+ T* output = GetTensorData<T>(op_context.output);
+ const TI* indices = GetTensorData<TI>(op_context.indices);
+ for (int i = 0; i < prefix_dim_size; ++i) {
+ for (int j = 0; j < depth; ++j) {
+ for (int k = 0; k < suffix_dim_size; ++k, ++output) {
+ *output = static_cast<int>(indices[i * suffix_dim_size + k]) == j
+ ? on_value
+ : off_value;
+ }
+ }
+ }
+}
+
+template <typename T>
+void OneHotCompute(const OneHotContext& op_context) {
+ if (op_context.indices->type == kTfLiteInt64) {
+ OneHotComputeImpl<T, int64_t>(op_context);
+ } else {
+ OneHotComputeImpl<T, int>(op_context);
+ }
+}
+
+TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
+ const OneHotContext& op_context) {
+ TF_LITE_ENSURE(context, *op_context.depth->data.i32 >= 0);
+ TfLiteIntArray* output_size = TfLiteIntArrayCreate(op_context.output_dims);
+ for (int i = 0; i < op_context.output_dims; ++i) {
+ if (i < op_context.axis) {
+ output_size->data[i] = op_context.indices->dims->data[i];
+ } else if (i == op_context.axis) {
+ output_size->data[i] = *op_context.depth->data.i32;
+ } else {
+ output_size->data[i] = op_context.indices->dims->data[i - 1];
+ }
+ }
+ return context->ResizeTensor(context, op_context.output, output_size);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 4);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ OneHotContext op_context{context, node};
+ switch (op_context.dtype) {
+ // TODO(b/111744875): Support uint8 and quantization.
+ case kTfLiteFloat32:
+ case kTfLiteInt16:
+ case kTfLiteInt32:
+ case kTfLiteInt64:
+ case kTfLiteBool:
+ op_context.output->type = op_context.dtype;
+ break;
+ default:
+ context->ReportError(context, "Unknown output data type: %d",
+ op_context.dtype);
+ return kTfLiteError;
+ }
+
+ TF_LITE_ENSURE(context, op_context.indices->type == kTfLiteInt32 ||
+ op_context.indices->type == kTfLiteInt64);
+ TF_LITE_ENSURE(context, op_context.axis >= 0 &&
+ op_context.axis < op_context.output_dims);
+ TF_LITE_ENSURE_EQ(context, NumElements(op_context.depth), 1);
+ TF_LITE_ENSURE_EQ(context, NumElements(op_context.on_value), 1);
+ TF_LITE_ENSURE_EQ(context, NumElements(op_context.off_value), 1);
+ TF_LITE_ENSURE_EQ(context, op_context.on_value->type, op_context.dtype);
+ TF_LITE_ENSURE_EQ(context, op_context.off_value->type, op_context.dtype);
+
+ if (!IsConstantTensor(op_context.depth)) {
+ SetTensorToDynamic(op_context.output);
+ return kTfLiteOk;
+ }
+
+ return ResizeOutputTensor(context, op_context);
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ OneHotContext op_context{context, node};
+
+ if (IsDynamicTensor(op_context.output)) {
+ ResizeOutputTensor(context, op_context);
+ }
+
+ switch (op_context.output->type) {
+ case kTfLiteFloat32:
+ OneHotCompute<float>(op_context);
+ break;
+ case kTfLiteInt32:
+ OneHotCompute<int>(op_context);
+ break;
+ case kTfLiteInt64:
+ OneHotCompute<int64_t>(op_context);
+ break;
+ case kTfLiteBool:
+ OneHotCompute<bool>(op_context);
+ break;
+ default:
+ return kTfLiteError;
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace one_hot
+
+TfLiteRegistration* Register_ONE_HOT() {
+ static TfLiteRegistration r = {
+ nullptr,
+ nullptr,
+ one_hot::Prepare,
+ one_hot::Eval,
+ };
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/one_hot_test.cc b/tensorflow/contrib/lite/kernels/one_hot_test.cc
new file mode 100644
index 0000000000..6b604ec7a7
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/one_hot_test.cc
@@ -0,0 +1,182 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <initializer_list>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+template <typename T>
+class OneHotOpModel : public SingleOpModel {
+ public:
+ OneHotOpModel(std::initializer_list<int> input_shape, int depth_value,
+ TensorType dtype, int axis = -1, T on_value = 1,
+ T off_value = 0, TensorType indices_type = TensorType_INT32) {
+ indices_ = AddInput(indices_type);
+ int depth = AddInput(TensorType_INT32);
+ int on = AddInput(dtype);
+ int off = AddInput(dtype);
+ output_ = AddOutput(dtype);
+ SetBuiltinOp(BuiltinOperator_ONE_HOT, BuiltinOptions_OneHotOptions,
+ CreateOneHotOptions(builder_, axis).Union());
+ BuildInterpreter({input_shape});
+
+ PopulateTensor<int>(depth, {depth_value});
+ PopulateTensor<T>(on, {on_value});
+ PopulateTensor<T>(off, {off_value});
+ }
+
+ template <typename TI>
+ void SetIndices(std::initializer_list<TI> data) {
+ PopulateTensor<TI>(indices_, data);
+ }
+
+ TfLiteStatus InvokeWithResult() { return interpreter_->Invoke(); }
+
+ int32_t GetOutputSize() { return GetTensorSize(output_); }
+ std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int indices_;
+ int output_;
+};
+
+TEST(OneHotOpTest, BasicFloat) {
+ const int depth = 3;
+ OneHotOpModel<float> model({3}, depth, TensorType_FLOAT32);
+ model.SetIndices({0, 1, 2});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3}));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray({1.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 1.f}));
+}
+
+TEST(OneHotOpTest, BasicInt) {
+ const int depth = 3;
+ OneHotOpModel<int> model({3}, depth, TensorType_INT32);
+ model.SetIndices({0, 1, 2});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3}));
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, 0, 1, 0, 0, 0, 1}));
+}
+
+TEST(OneHotOpTest, BasicBool) {
+ const int depth = 3;
+ OneHotOpModel<bool> model({3}, depth, TensorType_BOOL);
+ model.SetIndices({0, 1, 2});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3}));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray({true, false, false, false, true, false, false,
+ false, true}));
+}
+
+TEST(OneHotOpTest, SmallDepth) {
+ const int depth = 1;
+ OneHotOpModel<int> model({3}, depth, TensorType_INT32);
+ model.SetIndices({0, 1, 2});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 1}));
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0}));
+}
+
+TEST(OneHotOpTest, BigDepth) {
+ const int depth = 4;
+ OneHotOpModel<int> model({2}, depth, TensorType_INT32);
+ model.SetIndices({0, 1});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, 0, 0, 1, 0, 0}));
+}
+
+TEST(OneHotOpTest, OnOffValues) {
+ const int depth = 3;
+ const int axis = -1;
+ const int on = 5;
+ const int off = 0;
+ OneHotOpModel<int> model({4}, depth, TensorType_INT32, axis, on, off);
+ model.SetIndices({0, 2, -1, 1});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({4, 3}));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray({5, 0, 0, 0, 0, 5, 0, 0, 0, 0, 5, 0}));
+}
+
+TEST(OneHotOpTest, ZeroAxis) {
+ const int depth = 3;
+ const int axis = 0;
+ const int on = 5;
+ const int off = 0;
+ OneHotOpModel<int> model({4}, depth, TensorType_INT32, axis, on, off);
+ model.SetIndices({0, 2, -1, 1});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 4}));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray({5, 0, 0, 0, 0, 0, 0, 5, 0, 5, 0, 0}));
+}
+
+TEST(OneHotOpTest, MultiDimensionalIndices) {
+ const int depth = 3;
+ const int axis = -1;
+ const float on = 2;
+ const float off = 0;
+ OneHotOpModel<float> model({2, 2}, depth, TensorType_FLOAT32, axis, on, off);
+ model.SetIndices({0, 2, 1, -1});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 3}));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray({2, 0, 0, 0, 0, 2, 0, 2, 0, 0, 0, 0}));
+}
+
+TEST(OneHotOpTest, Int64Indices) {
+ const int depth = 3;
+ const int axis = -1;
+ const int on = 1;
+ const int off = 0;
+ OneHotOpModel<int> model({3}, depth, TensorType_INT32, axis, on, off,
+ TensorType_INT64);
+ std::initializer_list<int64_t> indices = {0, 1, 2};
+ model.SetIndices(indices);
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3}));
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, 0, 1, 0, 0, 0, 1}));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index 0b70bed308..e632728841 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -107,6 +107,8 @@ TfLiteRegistration* Register_SHAPE();
TfLiteRegistration* Register_POW();
TfLiteRegistration* Register_FAKE_QUANT();
TfLiteRegistration* Register_PACK();
+TfLiteRegistration* Register_ONE_HOT();
+TfLiteRegistration* Register_LOGICAL_OR();
BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_RELU, Register_RELU());
@@ -197,6 +199,8 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_POW, Register_POW());
AddBuiltin(BuiltinOperator_FAKE_QUANT, Register_FAKE_QUANT(), 1, 2);
AddBuiltin(BuiltinOperator_PACK, Register_PACK());
+ AddBuiltin(BuiltinOperator_ONE_HOT, Register_ONE_HOT());
+ AddBuiltin(BuiltinOperator_LOGICAL_OR, Register_LOGICAL_OR());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
diff --git a/tensorflow/contrib/lite/kernels/reshape.cc b/tensorflow/contrib/lite/kernels/reshape.cc
index 99ecc16093..49ba0571e2 100644
--- a/tensorflow/contrib/lite/kernels/reshape.cc
+++ b/tensorflow/contrib/lite/kernels/reshape.cc
@@ -37,10 +37,7 @@ TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node,
// special -1 value, meaning it will be calculated automatically based on the
// input. Here we calculate what that dimension should be so that the number
// of output elements in the same as the number of input elements.
- int num_input_elements = 1;
- for (int i = 0; i < NumDimensions(input); ++i) {
- num_input_elements *= SizeOfDimension(input, i);
- }
+ int num_input_elements = NumElements(input);
int num_output_elements = 1;
int stretch_dim = -1;
@@ -96,9 +93,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
// The function is returned above this line if the shape tensor is usable.
// Now fallback to the shape parameter in `TfLiteReshapeParams`.
-
- TfLiteIntArray* output_shape = TfLiteIntArrayCreate(params->num_dimensions);
- for (int i = 0; i < params->num_dimensions; ++i) {
+ int num_dimensions = params->num_dimensions;
+ if (num_dimensions == 1 && params->shape[0] == 0) {
+ // Legacy tflite models use a shape parameter of [0] to indicate scalars,
+ // so adjust accordingly. TODO(b/111614235): Allow zero-sized buffers during
+ // toco conversion.
+ num_dimensions = 0;
+ }
+ TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions);
+ for (int i = 0; i < num_dimensions; ++i) {
output_shape->data[i] = params->shape[i];
}
return ResizeOutput(context, node, output_shape);
diff --git a/tensorflow/contrib/lite/kernels/reshape_test.cc b/tensorflow/contrib/lite/kernels/reshape_test.cc
index aecbd0399f..52d71350d3 100644
--- a/tensorflow/contrib/lite/kernels/reshape_test.cc
+++ b/tensorflow/contrib/lite/kernels/reshape_test.cc
@@ -22,18 +22,27 @@ namespace tflite {
namespace {
using ::testing::ElementsAreArray;
+using ::testing::IsEmpty;
class ReshapeOpModel : public SingleOpModel {
public:
ReshapeOpModel(std::initializer_list<int> input_shape,
- std::initializer_list<int> new_shape) {
+ std::initializer_list<int> new_shape,
+ bool use_shape_input_tensor = false) {
input_ = AddInput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
+ int shape_input_tensor =
+ use_shape_input_tensor ? AddInput(TensorType_INT32) : -1;
SetBuiltinOp(
BuiltinOperator_RESHAPE, BuiltinOptions_ReshapeOptions,
CreateReshapeOptions(builder_, builder_.CreateVector<int>(new_shape))
.Union());
- BuildInterpreter({input_shape});
+ if (use_shape_input_tensor) {
+ BuildInterpreter({input_shape, GetShape(shape_input_tensor)});
+ PopulateTensor<int>(shape_input_tensor, new_shape);
+ } else {
+ BuildInterpreter({input_shape});
+ }
}
void SetInput(std::initializer_list<float> data) {
@@ -71,6 +80,14 @@ TEST(ReshapeOpTest, SimpleTest) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2}));
}
+TEST(ReshapeOpTest, ShapeTensorInput) {
+ ReshapeOpModel m({1, 2, 4, 1}, {2, 2, 2}, /*use_shape_input_tensor=*/true);
+ m.SetInput({1, 2, 3, 4, 5, 6, 7, 8});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2}));
+}
+
TEST(ReshapeOpTest, WithStretchDimension) {
ReshapeOpModel m({1, 2, 4, 1}, {2, 1, -1});
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8});
@@ -79,6 +96,22 @@ TEST(ReshapeOpTest, WithStretchDimension) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 4}));
}
+TEST(ReshapeOpTest, ScalarOutput) {
+ ReshapeOpModel m({1}, {});
+ m.SetInput({3});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({3}));
+ EXPECT_THAT(m.GetOutputShape(), IsEmpty());
+}
+
+TEST(ReshapeOpTest, LegacyScalarOutput) {
+ ReshapeOpModel m({1}, {0});
+ m.SetInput({3});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({3}));
+ EXPECT_THAT(m.GetOutputShape(), IsEmpty());
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc
index 10caffea03..f4289105f7 100644
--- a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc
+++ b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc
@@ -247,7 +247,7 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches8Bit) {
3, 6, //
9, 12, //
4, 10, //
- 10, 16 //
+ 12, 16 //
});
m.SetSize({3, 3});
m.Invoke();
@@ -256,8 +256,8 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches8Bit) {
7, 9, 10, //
9, 11, 12, //
4, 8, 10, //
- 8, 12, 14, //
- 10, 13, 16, //
+ 9, 12, 14, //
+ 12, 14, 16, //
})));
ResizeBilinearOpModel const_m({TensorType_UINT8, {2, 2, 2, 1}}, {3, 3});
@@ -265,7 +265,7 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches8Bit) {
3, 6, //
9, 12, //
4, 10, //
- 10, 16 //
+ 12, 16 //
});
const_m.Invoke();
EXPECT_THAT(const_m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
@@ -273,35 +273,35 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches8Bit) {
7, 9, 10, //
9, 11, 12, //
4, 8, 10, //
- 8, 12, 14, //
- 10, 13, 16, //
+ 9, 12, 14, //
+ 12, 14, 16, //
})));
}
TEST(ResizeBilinearOpTest, ThreeDimensionalResize8Bit) {
ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 2, 2}});
m.SetInput<uint8>({
- 3, 4, 6, 10, //
- 9, 10, 12, 16, //
+ 3, 4, 6, 10, //
+ 10, 12, 14, 16, //
});
m.SetSize({3, 3});
m.Invoke();
EXPECT_THAT(m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
- 3, 4, 5, 8, 6, 10, //
- 7, 8, 9, 12, 10, 14, //
- 9, 10, 11, 13, 12, 16, //
+ 3, 4, 5, 8, 6, 10, //
+ 7, 9, 10, 12, 11, 14, //
+ 10, 12, 12, 14, 14, 16, //
})));
ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 2, 2}}, {3, 3});
const_m.SetInput<uint8>({
- 3, 4, 6, 10, //
- 9, 10, 12, 16, //
+ 3, 4, 6, 10, //
+ 10, 12, 14, 16, //
});
const_m.Invoke();
EXPECT_THAT(const_m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
- 3, 4, 5, 8, 6, 10, //
- 7, 8, 9, 12, 10, 14, //
- 9, 10, 11, 13, 12, 16, //
+ 3, 4, 5, 8, 6, 10, //
+ 7, 9, 10, 12, 11, 14, //
+ 10, 12, 12, 14, 14, 16, //
})));
}
} // namespace
diff --git a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
index c9269599e5..03079f1c3b 100644
--- a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
+++ b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
@@ -113,7 +113,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
}
-#define TF_LITE_SPACE_TO_BATCH_ND(type, scalar) \
+#define TF_LITE_SPACE_TO_BATCH_ND(type, scalar, pad_value) \
type::SpaceToBatchND(GetTensorData<scalar>(op_context.input), \
GetTensorDims(op_context.input), \
GetTensorData<int32_t>(op_context.block_shape), \
@@ -121,34 +121,36 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetTensorData<int32_t>(op_context.paddings), \
GetTensorDims(op_context.paddings), \
GetTensorData<scalar>(op_context.output), \
- GetTensorDims(op_context.output))
+ GetTensorDims(op_context.output), pad_value)
switch (op_context.input->type) { // Already know in/out types are same.
case kTfLiteFloat32:
if (kernel_type == kReference) {
- TF_LITE_SPACE_TO_BATCH_ND(reference_ops, float);
+ TF_LITE_SPACE_TO_BATCH_ND(reference_ops, float, 0);
} else {
- TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, float);
+ TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, float, 0);
}
break;
case kTfLiteUInt8:
if (kernel_type == kReference) {
- TF_LITE_SPACE_TO_BATCH_ND(reference_ops, uint8_t);
+ TF_LITE_SPACE_TO_BATCH_ND(reference_ops, uint8_t,
+ op_context.output->params.zero_point);
} else {
- TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, uint8_t);
+ TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, uint8_t,
+ op_context.output->params.zero_point);
}
break;
case kTfLiteInt32:
if (kernel_type == kReference) {
- TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int32_t);
+ TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int32_t, 0);
} else {
- TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int32_t);
+ TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int32_t, 0);
}
break;
case kTfLiteInt64:
if (kernel_type == kReference) {
- TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int64_t);
+ TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int64_t, 0);
} else {
- TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int64_t);
+ TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int64_t, 0);
}
break;
default:
diff --git a/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc b/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc
index 92a4a037d5..5756573629 100644
--- a/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc
+++ b/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc
@@ -23,6 +23,7 @@ namespace tflite {
namespace {
using ::testing::ElementsAreArray;
+using ::testing::Matcher;
class SpaceToBatchNDOpModel : public SingleOpModel {
public:
@@ -30,6 +31,10 @@ class SpaceToBatchNDOpModel : public SingleOpModel {
PopulateTensor<float>(input_, data);
}
+ void SetQuantizedInput(std::initializer_list<float> data) {
+ QuantizeAndPopulate<uint8_t>(input_, data);
+ }
+
void SetBlockShape(std::initializer_list<int> data) {
PopulateTensor<int>(block_shape_, data);
}
@@ -41,6 +46,11 @@ class SpaceToBatchNDOpModel : public SingleOpModel {
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+ std::vector<float> GetDequantizedOutput() {
+ return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
+ GetScale(output_), GetZeroPoint(output_));
+ }
+
protected:
int input_;
int block_shape_;
@@ -56,18 +66,19 @@ class SpaceToBatchNDOpModel : public SingleOpModel {
// m.Invoke();
class SpaceToBatchNDOpConstModel : public SpaceToBatchNDOpModel {
public:
- SpaceToBatchNDOpConstModel(std::initializer_list<int> input_shape,
+ SpaceToBatchNDOpConstModel(const TensorData& input,
std::initializer_list<int> block_shape,
- std::initializer_list<int> paddings) {
- input_ = AddInput(TensorType_FLOAT32);
+ std::initializer_list<int> paddings,
+ const TensorData& output) {
+ input_ = AddInput(input);
block_shape_ = AddConstInput(TensorType_INT32, block_shape, {2});
paddings_ = AddConstInput(TensorType_INT32, paddings, {2, 2});
- output_ = AddOutput(TensorType_FLOAT32);
+ output_ = AddOutput(output);
SetBuiltinOp(BuiltinOperator_SPACE_TO_BATCH_ND,
BuiltinOptions_SpaceToBatchNDOptions,
CreateSpaceToBatchNDOptions(builder_).Union());
- BuildInterpreter({input_shape});
+ BuildInterpreter({input.shape});
}
};
@@ -81,26 +92,30 @@ class SpaceToBatchNDOpConstModel : public SpaceToBatchNDOpModel {
// m.Invoke();
class SpaceToBatchNDOpDynamicModel : public SpaceToBatchNDOpModel {
public:
- SpaceToBatchNDOpDynamicModel(std::initializer_list<int> input_shape) {
- input_ = AddInput(TensorType_FLOAT32);
+ SpaceToBatchNDOpDynamicModel(const TensorData& input,
+ const TensorData& output) {
+ input_ = AddInput(input);
block_shape_ = AddInput(TensorType_INT32);
paddings_ = AddInput(TensorType_INT32);
- output_ = AddOutput(TensorType_FLOAT32);
+ output_ = AddOutput(output);
SetBuiltinOp(BuiltinOperator_SPACE_TO_BATCH_ND,
BuiltinOptions_SpaceToBatchNDOptions,
CreateSpaceToBatchNDOptions(builder_).Union());
- BuildInterpreter({input_shape, {2}, {2, 2}});
+ BuildInterpreter({input.shape, {2}, {2, 2}});
}
};
TEST(SpaceToBatchNDOpTest, InvalidShapeTest) {
- EXPECT_DEATH(SpaceToBatchNDOpConstModel({1, 3, 3, 1}, {2, 2}, {0, 0, 0, 0}),
- "Cannot allocate tensors");
+ EXPECT_DEATH(
+ SpaceToBatchNDOpConstModel({TensorType_FLOAT32, {1, 3, 3, 1}}, {2, 2},
+ {0, 0, 0, 0}, {TensorType_FLOAT32}),
+ "Cannot allocate tensors");
}
TEST(SpaceToBatchNDOpTest, SimpleConstTest) {
- SpaceToBatchNDOpConstModel m({1, 4, 4, 1}, {2, 2}, {0, 0, 0, 0});
+ SpaceToBatchNDOpConstModel m({TensorType_FLOAT32, {1, 4, 4, 1}}, {2, 2},
+ {0, 0, 0, 0}, {TensorType_FLOAT32});
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 2, 1}));
@@ -109,7 +124,8 @@ TEST(SpaceToBatchNDOpTest, SimpleConstTest) {
}
TEST(SpaceToBatchNDOpTest, SimpleDynamicTest) {
- SpaceToBatchNDOpDynamicModel m({1, 4, 4, 1});
+ SpaceToBatchNDOpDynamicModel m({TensorType_FLOAT32, {1, 4, 4, 1}},
+ {TensorType_FLOAT32});
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
m.SetBlockShape({2, 2});
m.SetPaddings({0, 0, 0, 0});
@@ -120,7 +136,8 @@ TEST(SpaceToBatchNDOpTest, SimpleDynamicTest) {
}
TEST(SpaceToBatchNDOpTest, MultipleInputBatchesConstTest) {
- SpaceToBatchNDOpConstModel m({2, 2, 4, 1}, {2, 2}, {0, 0, 0, 0});
+ SpaceToBatchNDOpConstModel m({TensorType_FLOAT32, {2, 2, 4, 1}}, {2, 2},
+ {0, 0, 0, 0}, {TensorType_FLOAT32});
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({8, 1, 2, 1}));
@@ -129,7 +146,8 @@ TEST(SpaceToBatchNDOpTest, MultipleInputBatchesConstTest) {
}
TEST(SpaceToBatchNDOpTest, MultipleInputBatchesDynamicTest) {
- SpaceToBatchNDOpDynamicModel m({2, 2, 4, 1});
+ SpaceToBatchNDOpDynamicModel m({TensorType_FLOAT32, {2, 2, 4, 1}},
+ {TensorType_FLOAT32});
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
m.SetBlockShape({2, 2});
m.SetPaddings({0, 0, 0, 0});
@@ -140,7 +158,8 @@ TEST(SpaceToBatchNDOpTest, MultipleInputBatchesDynamicTest) {
}
TEST(SpaceToBatchNDOpTest, SimplePaddingConstTest) {
- SpaceToBatchNDOpConstModel m({1, 5, 2, 1}, {3, 2}, {1, 0, 2, 0});
+ SpaceToBatchNDOpConstModel m({TensorType_FLOAT32, {1, 5, 2, 1}}, {3, 2},
+ {1, 0, 2, 0}, {TensorType_FLOAT32});
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 2, 1}));
@@ -151,7 +170,8 @@ TEST(SpaceToBatchNDOpTest, SimplePaddingConstTest) {
}
TEST(SpaceToBatchNDOpTest, SimplePaddingDynamicTest) {
- SpaceToBatchNDOpDynamicModel m({1, 5, 2, 1});
+ SpaceToBatchNDOpDynamicModel m({TensorType_FLOAT32, {1, 5, 2, 1}},
+ {TensorType_FLOAT32});
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
m.SetBlockShape({3, 2});
m.SetPaddings({1, 0, 2, 0});
@@ -164,7 +184,8 @@ TEST(SpaceToBatchNDOpTest, SimplePaddingDynamicTest) {
}
TEST(SpaceToBatchNDOpTest, ComplexPaddingConstTest) {
- SpaceToBatchNDOpConstModel m({1, 4, 2, 1}, {3, 2}, {1, 1, 2, 4});
+ SpaceToBatchNDOpConstModel m({TensorType_FLOAT32, {1, 4, 2, 1}}, {3, 2},
+ {1, 1, 2, 4}, {TensorType_FLOAT32});
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 4, 1}));
@@ -176,7 +197,8 @@ TEST(SpaceToBatchNDOpTest, ComplexPaddingConstTest) {
}
TEST(SpaceToBatchNDOpTest, ComplexPaddingDynamicTest) {
- SpaceToBatchNDOpDynamicModel m({1, 4, 2, 1});
+ SpaceToBatchNDOpDynamicModel m({TensorType_FLOAT32, {1, 4, 2, 1}},
+ {TensorType_FLOAT32});
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8});
m.SetBlockShape({3, 2});
m.SetPaddings({1, 1, 2, 4});
@@ -189,6 +211,88 @@ TEST(SpaceToBatchNDOpTest, ComplexPaddingDynamicTest) {
}));
}
+class QuantizedSpaceToBatchNDOpTest : public ::testing::Test {
+ protected:
+ std::vector<Matcher<float>> DequantizedArrayNear(
+ const std::vector<float>& values, const float min, const float max) {
+ const float quantization_tolerance = (max - min) / 255.0;
+ return ArrayFloatNear(values, quantization_tolerance);
+ }
+};
+
+TEST_F(QuantizedSpaceToBatchNDOpTest, ZeroNotInQuantizationRange) {
+ // The test_util and actual quantization code currently ensure that the range
+ // must include zero, but if that ever changes, this test will catch it.
+ EXPECT_DEATH(SpaceToBatchNDOpConstModel m(
+ {TensorType_UINT8, {1, 2, 2, 1}, 1.0, 2.0}, {4, 2},
+ {0, 0, 1, 1, 1, 1, 0, 0}, {TensorType_UINT8, {}, 1.0, 2.0}),
+ ".*Check failed: f_min <= 0.*");
+}
+
+TEST_F(QuantizedSpaceToBatchNDOpTest, SimplePaddingConstTest) {
+ SpaceToBatchNDOpConstModel m({TensorType_UINT8, {1, 5, 2, 1}, -1.0, 1.0},
+ {3, 2}, {1, 0, 2, 0},
+ {TensorType_UINT8, {}, -1.0, 1.0});
+ m.SetQuantizedInput({-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8, -0.9, 0.1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 2, 1}));
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(DequantizedArrayNear(
+ {0, 0, 0, -0.5, 0, 0, 0, 0.6, 0, -0.1, 0, -0.7,
+ 0, 0.2, 0, 0.8, 0, -0.3, 0, -0.9, 0, 0.4, 0, 0.1},
+ -1.0, 1.0)));
+}
+
+TEST_F(QuantizedSpaceToBatchNDOpTest, SimplePaddingDynamicTest) {
+ SpaceToBatchNDOpDynamicModel m({TensorType_UINT8, {1, 5, 2, 1}, -1.0, 1.0},
+ {TensorType_UINT8, {}, -1.0, 1.0});
+ m.SetQuantizedInput({-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8, -0.9, 0.1});
+ m.SetBlockShape({3, 2});
+ m.SetPaddings({1, 0, 2, 0});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 2, 1}));
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(DequantizedArrayNear(
+ {0, 0, 0, -0.5, 0, 0, 0, 0.6, 0, -0.1, 0, -0.7,
+ 0, 0.2, 0, 0.8, 0, -0.3, 0, -0.9, 0, 0.4, 0, 0.1},
+ -1.0, 1.0)));
+}
+
+TEST_F(QuantizedSpaceToBatchNDOpTest, ComplexPaddingConstTest) {
+ SpaceToBatchNDOpConstModel m({TensorType_UINT8, {1, 4, 2, 1}, -1.0, 1.0},
+ {3, 2}, {1, 1, 2, 4},
+ {TensorType_UINT8, {}, -1.0, 1.0});
+ m.SetQuantizedInput({-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 4, 1}));
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(DequantizedArrayNear(
+ {
+ 0, 0, 0, 0, 0, -0.5, 0, 0, 0, 0, 0, 0, 0, 0.6, 0, 0,
+ 0, -0.1, 0, 0, 0, -0.7, 0, 0, 0, 0.2, 0, 0, 0, 0.8, 0, 0,
+ 0, -0.3, 0, 0, 0, 0, 0, 0, 0, 0.4, 0, 0, 0, 0, 0, 0,
+ },
+ -1.0, 1.0)));
+}
+
+TEST_F(QuantizedSpaceToBatchNDOpTest, ComplexPaddingDynamicTest) {
+ SpaceToBatchNDOpDynamicModel m({TensorType_UINT8, {1, 4, 2, 1}, -1.0, 1.0},
+ {TensorType_UINT8, {}, -1.0, 1.0});
+ m.SetQuantizedInput({-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8});
+ m.SetBlockShape({3, 2});
+ m.SetPaddings({1, 1, 2, 4});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 4, 1}));
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(DequantizedArrayNear(
+ {
+ 0, 0, 0, 0, 0, -0.5, 0, 0, 0, 0, 0, 0, 0, 0.6, 0, 0,
+ 0, -0.1, 0, 0, 0, -0.7, 0, 0, 0, 0.2, 0, 0, 0, 0.8, 0, 0,
+ 0, -0.3, 0, 0, 0, 0, 0, 0, 0, 0.4, 0, 0, 0, 0, 0, 0,
+ },
+ -1.0, 1.0)));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index c6869feb16..5814cddc5b 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -730,6 +730,14 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = static_cast<void*>(params);
break;
}
+ case BuiltinOperator_ONE_HOT: {
+ auto* params = MallocPOD<TfLiteOneHotParams>();
+ if (auto* schema_params = op->builtin_options_as_OneHotOptions()) {
+ params->axis = schema_params->axis();
+ }
+ *builtin_data = static_cast<void*>(params);
+ break;
+ }
// Below are the ops with no builtin_data strcture.
case BuiltinOperator_BATCH_TO_SPACE_ND:
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index 551e8ed320..1c06b29deb 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -623,6 +623,7 @@ TfLiteStatus AddOpsAndParams(
case tflite::BuiltinOperator_FAKE_QUANT:
case tflite::BuiltinOperator_PACK:
case tflite::BuiltinOperator_LOGICAL_OR:
+ case tflite::BuiltinOperator_ONE_HOT:
logError("Op code %d is currently not delegated to NNAPI", builtin);
return kTfLiteError;
break;
diff --git a/tensorflow/contrib/lite/profiling/time.cc b/tensorflow/contrib/lite/profiling/time.cc
index 446660bb74..875ddb02bc 100644
--- a/tensorflow/contrib/lite/profiling/time.cc
+++ b/tensorflow/contrib/lite/profiling/time.cc
@@ -14,16 +14,34 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/profiling/time.h"
+#if defined(_MSC_VER)
+#include <chrono> // NOLINT(build/c++11)
+#else
#include <sys/time.h>
+#endif
namespace tflite {
namespace profiling {
namespace time {
+
+#if defined(_MSC_VER)
+
+uint64_t NowMicros() {
+ return std::chrono::duration_cast<std::chrono::microseconds>(
+ std::chrono::system_clock::now().time_since_epoch())
+ .count();
+}
+
+#else
+
uint64_t NowMicros() {
struct timeval tv;
gettimeofday(&tv, nullptr);
return static_cast<uint64_t>(tv.tv_sec) * 1000000 + tv.tv_usec;
}
+
+#endif // defined(_MSC_VER)
+
} // namespace time
} // namespace profiling
} // namespace tflite
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index a285bf9919..8ed98ddaf4 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -166,6 +166,7 @@ enum BuiltinOperator : byte {
REDUCE_MAX = 82,
PACK = 83,
LOGICAL_OR = 84,
+ ONE_HOT = 85,
}
// Options for the builtin operators.
@@ -230,6 +231,7 @@ union BuiltinOptions {
FakeQuantOptions,
PackOptions,
LogicalOrOptions,
+ OneHotOptions,
}
enum Padding : byte { SAME, VALID }
@@ -549,6 +551,10 @@ table PackOptions {
table LogicalOrOptions {
}
+table OneHotOptions {
+ axis:int;
+}
+
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index 8c1d6d6a36..4402f89b85 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -211,6 +211,9 @@ struct PackOptionsT;
struct LogicalOrOptions;
struct LogicalOrOptionsT;
+struct OneHotOptions;
+struct OneHotOptionsT;
+
struct OperatorCode;
struct OperatorCodeT;
@@ -361,11 +364,12 @@ enum BuiltinOperator {
BuiltinOperator_REDUCE_MAX = 82,
BuiltinOperator_PACK = 83,
BuiltinOperator_LOGICAL_OR = 84,
+ BuiltinOperator_ONE_HOT = 85,
BuiltinOperator_MIN = BuiltinOperator_ADD,
- BuiltinOperator_MAX = BuiltinOperator_LOGICAL_OR
+ BuiltinOperator_MAX = BuiltinOperator_ONE_HOT
};
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[84] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[85] {
static BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@@ -450,7 +454,8 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[84] {
BuiltinOperator_REDUCE_PROD,
BuiltinOperator_REDUCE_MAX,
BuiltinOperator_PACK,
- BuiltinOperator_LOGICAL_OR
+ BuiltinOperator_LOGICAL_OR,
+ BuiltinOperator_ONE_HOT
};
return values;
}
@@ -542,6 +547,7 @@ inline const char **EnumNamesBuiltinOperator() {
"REDUCE_MAX",
"PACK",
"LOGICAL_OR",
+ "ONE_HOT",
nullptr
};
return names;
@@ -614,11 +620,12 @@ enum BuiltinOptions {
BuiltinOptions_FakeQuantOptions = 58,
BuiltinOptions_PackOptions = 59,
BuiltinOptions_LogicalOrOptions = 60,
+ BuiltinOptions_OneHotOptions = 61,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_LogicalOrOptions
+ BuiltinOptions_MAX = BuiltinOptions_OneHotOptions
};
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[61] {
+inline BuiltinOptions (&EnumValuesBuiltinOptions())[62] {
static BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -680,7 +687,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[61] {
BuiltinOptions_ArgMinOptions,
BuiltinOptions_FakeQuantOptions,
BuiltinOptions_PackOptions,
- BuiltinOptions_LogicalOrOptions
+ BuiltinOptions_LogicalOrOptions,
+ BuiltinOptions_OneHotOptions
};
return values;
}
@@ -748,6 +756,7 @@ inline const char **EnumNamesBuiltinOptions() {
"FakeQuantOptions",
"PackOptions",
"LogicalOrOptions",
+ "OneHotOptions",
nullptr
};
return names;
@@ -1002,6 +1011,10 @@ template<> struct BuiltinOptionsTraits<LogicalOrOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_LogicalOrOptions;
};
+template<> struct BuiltinOptionsTraits<OneHotOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_OneHotOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -1513,6 +1526,14 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_LogicalOrOptions ?
reinterpret_cast<const LogicalOrOptionsT *>(value) : nullptr;
}
+ OneHotOptionsT *AsOneHotOptions() {
+ return type == BuiltinOptions_OneHotOptions ?
+ reinterpret_cast<OneHotOptionsT *>(value) : nullptr;
+ }
+ const OneHotOptionsT *AsOneHotOptions() const {
+ return type == BuiltinOptions_OneHotOptions ?
+ reinterpret_cast<const OneHotOptionsT *>(value) : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -5452,6 +5473,60 @@ inline flatbuffers::Offset<LogicalOrOptions> CreateLogicalOrOptions(
flatbuffers::Offset<LogicalOrOptions> CreateLogicalOrOptions(flatbuffers::FlatBufferBuilder &_fbb, const LogicalOrOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct OneHotOptionsT : public flatbuffers::NativeTable {
+ typedef OneHotOptions TableType;
+ int32_t axis;
+ OneHotOptionsT()
+ : axis(0) {
+ }
+};
+
+struct OneHotOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef OneHotOptionsT NativeTableType;
+ enum {
+ VT_AXIS = 4
+ };
+ int32_t axis() const {
+ return GetField<int32_t>(VT_AXIS, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_AXIS) &&
+ verifier.EndTable();
+ }
+ OneHotOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(OneHotOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<OneHotOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct OneHotOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_axis(int32_t axis) {
+ fbb_.AddElement<int32_t>(OneHotOptions::VT_AXIS, axis, 0);
+ }
+ explicit OneHotOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ OneHotOptionsBuilder &operator=(const OneHotOptionsBuilder &);
+ flatbuffers::Offset<OneHotOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<OneHotOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<OneHotOptions> CreateOneHotOptions(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t axis = 0) {
+ OneHotOptionsBuilder builder_(_fbb);
+ builder_.add_axis(axis);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<OneHotOptions> CreateOneHotOptions(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@@ -5765,6 +5840,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const LogicalOrOptions *builtin_options_as_LogicalOrOptions() const {
return builtin_options_type() == BuiltinOptions_LogicalOrOptions ? static_cast<const LogicalOrOptions *>(builtin_options()) : nullptr;
}
+ const OneHotOptions *builtin_options_as_OneHotOptions() const {
+ return builtin_options_type() == BuiltinOptions_OneHotOptions ? static_cast<const OneHotOptions *>(builtin_options()) : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -6036,6 +6114,10 @@ template<> inline const LogicalOrOptions *Operator::builtin_options_as<LogicalOr
return builtin_options_as_LogicalOrOptions();
}
+template<> inline const OneHotOptions *Operator::builtin_options_as<OneHotOptions>() const {
+ return builtin_options_as_OneHotOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -8151,6 +8233,32 @@ inline flatbuffers::Offset<LogicalOrOptions> CreateLogicalOrOptions(flatbuffers:
_fbb);
}
+inline OneHotOptionsT *OneHotOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new OneHotOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void OneHotOptions::UnPackTo(OneHotOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = axis(); _o->axis = _e; };
+}
+
+inline flatbuffers::Offset<OneHotOptions> OneHotOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateOneHotOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<OneHotOptions> CreateOneHotOptions(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const OneHotOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _axis = _o->axis;
+ return tflite::CreateOneHotOptions(
+ _fbb,
+ _axis);
+}
+
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@@ -8580,6 +8688,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const LogicalOrOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_OneHotOptions: {
+ auto ptr = reinterpret_cast<const OneHotOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default: return false;
}
}
@@ -8838,6 +8950,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const LogicalOrOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_OneHotOptions: {
+ auto ptr = reinterpret_cast<const OneHotOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default: return nullptr;
}
}
@@ -9084,6 +9200,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const LogicalOrOptionsT *>(value);
return CreateLogicalOrOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_OneHotOptions: {
+ auto ptr = reinterpret_cast<const OneHotOptionsT *>(value);
+ return CreateOneHotOptions(_fbb, ptr, _rehasher).Union();
+ }
default: return 0;
}
}
@@ -9330,6 +9450,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new LogicalOrOptionsT(*reinterpret_cast<LogicalOrOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_OneHotOptions: {
+ value = new OneHotOptionsT(*reinterpret_cast<OneHotOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
@@ -9637,6 +9761,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
+ case BuiltinOptions_OneHotOptions: {
+ auto ptr = reinterpret_cast<OneHotOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default: break;
}
value = nullptr;
diff --git a/tensorflow/contrib/lite/simple_memory_arena.cc b/tensorflow/contrib/lite/simple_memory_arena.cc
index 24593d2a67..cd0f1f7c17 100644
--- a/tensorflow/contrib/lite/simple_memory_arena.cc
+++ b/tensorflow/contrib/lite/simple_memory_arena.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/simple_memory_arena.h"
+#include <algorithm>
#include <cstring>
#include <limits>
#include <vector>
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 41ece94237..6d03c0fd9e 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -104,6 +104,8 @@ KNOWN_BUGS = {
r"div.*int32": "72051395",
# No support for SplitV
r"split.*num_or_size_splits=\[2,2\]": "73377559",
+ # Scalar constants don't work.
+ r"constant.*shape=\[\]": "109811500",
}
@@ -229,6 +231,7 @@ _TF_TYPE_INFO = {
tf.int32: (np.int32, "INT32"),
tf.uint8: (np.uint8, "QUANTIZED_UINT8"),
tf.int64: (np.int64, "INT64"),
+ tf.bool: (np.bool, "BOOL"),
}
@@ -242,7 +245,10 @@ def create_tensor_data(dtype, shape, min_value=-100, max_value=100):
value = (max_value-min_value)*np.random.random_sample(shape)+min_value
elif dtype in (tf.int32, tf.uint8, tf.int64):
value = np.random.randint(min_value, max_value+1, shape)
- return value.astype(dtype)
+ elif dtype == tf.bool:
+ value = np.random.choice([True, False], size=shape)
+ return np.dtype(dtype).type(value) if np.isscalar(value) else value.astype(
+ dtype)
def create_scalar_data(dtype, min_value=-100, max_value=100):
@@ -479,7 +485,7 @@ def make_zip_of_tests(zip_path,
else report_lib.FAILED)
report["toco_log"] = toco_log
- if FLAGS.save_graphdefs:
+ if True or FLAGS.save_graphdefs:
archive.writestr(label + ".pbtxt",
text_format.MessageToString(graph_def),
zipfile.ZIP_DEFLATED)
@@ -734,21 +740,22 @@ def make_constant_tests(zip_path):
test_parameters = [{
"dtype": [tf.float32, tf.int32],
- "input_shape": [[1], [2], [1, 1, 1, 1], [2, 2, 2, 2]],
+ "input_shape": [[], [1], [2], [1, 1, 1, 1], [2, 2, 2, 2]],
}]
def build_graph(parameters):
- # Since Toco & Tflite can't have a single constant op in the entire graph,
- # this test adds a zero tensor with a constant op tensor.
- input1 = tf.placeholder(dtype=parameters["dtype"], name="input1",
- shape=parameters["input_shape"])
- out = tf.ones(parameters["input_shape"], dtype=parameters["dtype"]) + input1
- return [input1], [out]
+ dummy_input = tf.placeholder(
+ dtype=parameters["dtype"],
+ name="input1",
+ shape=parameters["input_shape"])
+ out = tf.constant(
+ create_tensor_data(parameters["dtype"], parameters["input_shape"]))
+ return [dummy_input], [out]
def build_inputs(parameters, sess, inputs, outputs):
- input1 = np.zeros(parameters["input_shape"],
- dtype=_TF_TYPE_INFO[parameters["dtype"]][0])
- return [input1], sess.run(outputs, feed_dict={inputs[0]: input1})
+ dummy_input = np.zeros(
+ parameters["input_shape"], dtype=_TF_TYPE_INFO[parameters["dtype"]][0])
+ return [dummy_input], sess.run(outputs, feed_dict={inputs[0]: dummy_input})
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
@@ -1608,6 +1615,11 @@ def make_reshape_tests(zip_path):
"input_shape": [[3, 4, 5, 7], [4, 105], [21, 5, 2, 2], [420]],
"output_shape": [[15, 28], [420], [1, -1, 5, 7], [-1]],
"constant_shape": [True, False],
+ }, {
+ "dtype": [tf.float32],
+ "input_shape": [[1]],
+ "output_shape": [[]],
+ "constant_shape": [True, False],
}]
def build_graph(parameters):
@@ -1665,6 +1677,65 @@ def make_shape_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_one_hot_tests(zip_path):
+ """Make a set of tests to do one_hot."""
+
+ test_parameters = [{
+ "indices_type": [tf.int32, tf.int64],
+ "indices_shape": [[3], [4, 4], [1, 5], [5, 1]],
+ "axis": [0, 1],
+ "dtype": [tf.int32, tf.int64, tf.float32],
+ "provide_optional_inputs": [True, False],
+ }]
+
+ def build_graph(parameters):
+ indices = tf.placeholder(
+ dtype=parameters["indices_type"],
+ name="indices",
+ shape=parameters["indices_shape"])
+ depth = tf.placeholder(dtype=tf.int32, name="depth", shape=())
+
+ if not parameters["provide_optional_inputs"]:
+ out = tf.one_hot(indices=indices, depth=depth)
+ return [indices, depth], [out]
+
+ on_value = tf.placeholder(
+ dtype=parameters["dtype"], name="on_value", shape=())
+ off_value = tf.placeholder(
+ dtype=parameters["dtype"], name="off_value", shape=())
+ out = tf.one_hot(
+ indices=indices,
+ depth=depth,
+ on_value=on_value,
+ off_value=off_value,
+ axis=parameters["axis"],
+ dtype=parameters["dtype"])
+ return [indices, depth, on_value, off_value], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = [
+ create_tensor_data(
+ parameters["indices_type"],
+ shape=parameters["indices_shape"],
+ min_value=-1,
+ max_value=10),
+ create_tensor_data(tf.int32, shape=None, min_value=1, max_value=10),
+ ]
+
+ if parameters["provide_optional_inputs"]:
+ input_values.append(
+ create_tensor_data(
+ parameters["dtype"], shape=None, min_value=1, max_value=10))
+ input_values.append(
+ create_tensor_data(
+ parameters["dtype"], shape=None, min_value=-1, max_value=0))
+
+ return input_values, sess.run(
+ outputs, feed_dict=dict(zip(inputs, input_values)))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
def make_resize_bilinear_tests(zip_path):
"""Make a set of tests to do resize_bilinear."""
@@ -2918,6 +2989,35 @@ def make_pack_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_logical_or_tests(zip_path):
+ """Make a set of tests to do logical_or."""
+
+ test_parameters = [{
+ "input_shape_pair": [([], []), ([1, 1, 1, 3], [1, 1, 1, 3]),
+ ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]),
+ ([5, 5], [1]), ([10], [2, 4, 10])],
+ }]
+
+ def build_graph(parameters):
+ """Build the logical_or op testing graph."""
+ input_value1 = tf.placeholder(
+ dtype=tf.bool, name="input1", shape=parameters["input_shape_pair"][0])
+ input_value2 = tf.placeholder(
+ dtype=tf.bool, name="input2", shape=parameters["input_shape_pair"][1])
+ out = tf.logical_or(input_value1, input_value2)
+ return [input_value1, input_value2], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_value1 = create_tensor_data(tf.bool,
+ parameters["input_shape_pair"][0])
+ input_value2 = create_tensor_data(tf.bool,
+ parameters["input_shape_pair"][1])
+ return [input_value1, input_value2], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
# Toco binary path provided by the generate rule.
bin_path = None
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
index 770092e12c..106cbc1b8e 100644
--- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -226,7 +226,8 @@ TEST_P(OpsTest, RunZipTests) {
string message = test_driver.GetErrorMessage();
if (bug_number.empty()) {
if (FLAGS_use_nnapi && FLAGS_ignore_unsupported_nnapi && !result) {
- EXPECT_EQ(message, string("Failed to invoke interpreter")) << message;
+ EXPECT_EQ(message, string("Failed to invoke NNAPI interpreter"))
+ << message;
} else {
EXPECT_TRUE(result) << message;
}
diff --git a/tensorflow/contrib/lite/testing/tf_driver.cc b/tensorflow/contrib/lite/testing/tf_driver.cc
index d6a6ff8f56..ec435ca60d 100644
--- a/tensorflow/contrib/lite/testing/tf_driver.cc
+++ b/tensorflow/contrib/lite/testing/tf_driver.cc
@@ -179,7 +179,7 @@ void TfDriver::Invoke() {
auto status = session_->Run({input_tensors_.begin(), input_tensors_.end()},
output_names_, {}, &output_tensors_);
if (!status.ok()) {
- Invalidate("Failed to invoke interpreter");
+ Invalidate("Failed to run input data on graph");
}
}
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index b79bb300f0..378212cb74 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -1316,6 +1316,20 @@ void ConvertResizeBilinearOperator(const Model& model,
(*resize_op->mutable_attr())["align_corners"].set_b(src_op.align_corners);
}
+void ConvertOneHotOperator(const Model& model, const OneHotOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ tensorflow::NodeDef* onehot_op = tensorflow_graph->add_node();
+ onehot_op->set_op("OneHot");
+ onehot_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 4);
+ for (const auto& input : src_op.inputs) {
+ *onehot_op->add_input() = input;
+ }
+ (*onehot_op->mutable_attr())["T"].set_type(
+ GetTensorFlowDataType(model, src_op.outputs[0]));
+ (*onehot_op->mutable_attr())["axis"].set_i(src_op.axis);
+}
+
namespace {
// TODO(aselle): Remove when available in absl
absl::string_view FindLongestCommonPrefix(absl::string_view a,
@@ -1911,6 +1925,21 @@ void ConvertLogicalNotOperator(const Model& model,
*logical_op->add_input() = src_op.inputs[0];
}
+void ConvertLogicalOrOperator(const Model& model,
+ const LogicalOrOperator& src_op,
+ const char* op_name, GraphDef* tensorflow_graph) {
+ tensorflow::NodeDef* logical_or_op = tensorflow_graph->add_node();
+ logical_or_op->set_op(op_name);
+ logical_or_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ for (int i = 0; i < 2; ++i) {
+ *logical_or_op->add_input() = src_op.inputs[i];
+ }
+ const tensorflow::DataType data_type =
+ GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*logical_or_op->mutable_attr())["T"].set_type(data_type);
+}
+
void ConvertOperator(const Model& model, const Operator& src_op,
GraphDef* tensorflow_graph) {
if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) {
@@ -2158,6 +2187,13 @@ void ConvertOperator(const Model& model, const Operator& src_op,
ConvertLogicalNotOperator(model,
static_cast<const LogicalNotOperator&>(src_op),
tensorflow_graph);
+ } else if (src_op.type == OperatorType::kOneHot) {
+ ConvertOneHotOperator(model, static_cast<const OneHotOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kLogicalOr) {
+ ConvertLogicalOrOperator(model,
+ static_cast<const LogicalOrOperator&>(src_op),
+ "LogicalOr", tensorflow_graph);
} else {
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc b/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
index 75642bbc37..c13fc0de75 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
@@ -181,7 +181,7 @@ bool EnsureUint8WeightsSafeForFastInt8Kernels::Run(Model* model,
// future without worrying.
static constexpr int kMinDistanceBetweenBadValues = 16;
if (distance < kMinDistanceBetweenBadValues) {
- if (allow_nudging_weights()) {
+ if (allow_nudging_weights() || has_default_ranges_flag()) {
buffer_data[i] = 1;
changed = true;
continue;
@@ -200,6 +200,15 @@ bool EnsureUint8WeightsSafeForFastInt8Kernels::Run(Model* model,
}
if (changed) {
+ if (has_default_ranges_flag()) {
+ std::cerr
+ << "Since the specified values of --default_ranges_min and "
+ "--default_ranges_max result in values incompatible with TFLite's "
+ "fast int8 kernels, "
+ "--allow_nudging_weights_to_use_fast_gemm_kernel "
+ "has been enabled. This may affect the accuracy of the model."
+ << std::endl;
+ }
AddMessageF("Tweaked weights values for %s", LogName(op));
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index b7634e28c6..8d9a4c4700 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -262,8 +262,12 @@ class EnsureUint8WeightsSafeForFastInt8Kernels : public GraphTransformation {
bool allow_nudging_weights() const { return allow_nudging_weights_; }
void set_allow_nudging_weights(bool val) { allow_nudging_weights_ = val; }
+ bool has_default_ranges_flag() const { return has_default_ranges_flag_; }
+ void set_has_default_ranges_flag(bool val) { has_default_ranges_flag_ = val; }
+
private:
bool allow_nudging_weights_ = false;
+ bool has_default_ranges_flag_ = false;
};
#undef DECLARE_GRAPH_TRANSFORMATION
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index 9c22497d5e..f033ee013e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -65,6 +65,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
case OperatorType::kAny:
case OperatorType::kLogicalAnd:
case OperatorType::kLogicalNot:
+ case OperatorType::kLogicalOr:
// These operators unconditionally produce bool outputs
SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool);
break;
@@ -141,7 +142,8 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
CHECK_EQ(op->inputs.size(), 2);
CHECK_EQ(op->outputs.size(), 2);
CHECK(model->GetArray(op->inputs[1]).data_type == ArrayDataType::kInt32);
- model->GetArray(op->outputs[0]).data_type = model->GetArray(op->inputs[0]).data_type;
+ model->GetArray(op->outputs[0]).data_type =
+ model->GetArray(op->inputs[0]).data_type;
model->GetArray(op->outputs[1]).data_type = ArrayDataType ::kInt32;
break;
}
@@ -201,6 +203,18 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
SetDataTypeForAllOutputs(model, op, data_type);
break;
}
+ case OperatorType::kOneHot: {
+ CHECK_EQ(op->inputs.size(), 4);
+ CHECK_EQ(op->outputs.size(), 1);
+ const ArrayDataType on_value_type =
+ model->GetArray(op->inputs[OneHotOperator::ON_VALUE_INPUT]).data_type;
+ const ArrayDataType off_value_type =
+ model->GetArray(op->inputs[OneHotOperator::OFF_VALUE_INPUT])
+ .data_type;
+ CHECK(on_value_type == off_value_type);
+ model->GetArray(op->outputs[0]).data_type = on_value_type;
+ break;
+ }
default: {
// These operators produce outputs with the same type as their 1st input
CHECK_GT(op->inputs.size(), 0);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index a03b589bae..3c9379fd87 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -1578,6 +1578,61 @@ void ProcessAnyOperator(Model* model, AnyOperator* op) {
}
}
+void ProcessOneHotOperator(Model* model, OneHotOperator* op) {
+ CHECK_EQ(op->inputs.size(), 4);
+ CHECK_EQ(op->outputs.size(), 1);
+ auto& output_array = model->GetArray(op->outputs[0]);
+ if (output_array.has_shape()) {
+ // Shape already propagated
+ return;
+ }
+
+ // Yield until indices dims have been resolved.
+ const auto& indices_array =
+ model->GetArray(op->inputs[OneHotOperator::INDICES_INPUT]);
+ if (!indices_array.has_shape()) {
+ return;
+ }
+
+ // Yield until depth is constant and dims have been resolved.
+ if (!IsConstantParameterArray(*model,
+ op->inputs[OneHotOperator::DEPTH_INPUT])) {
+ return;
+ }
+ const auto& depth_array =
+ model->GetArray(op->inputs[OneHotOperator::DEPTH_INPUT]);
+ if (!depth_array.has_shape()) {
+ return;
+ }
+
+ CHECK(depth_array.data_type == ArrayDataType::kInt32)
+ << "Depth array must be int32.";
+ CHECK_EQ(RequiredBufferSizeForShape(depth_array.shape()), 1)
+ << "Depth array must be scalar.";
+
+ const int depth = depth_array.GetBuffer<ArrayDataType::kInt32>().data[0];
+ CHECK_GE(depth, 0) << "Depth must be non-negative.";
+
+ const int indices_dims = indices_array.shape().dimensions_count();
+ const int output_dims = indices_dims + 1;
+ const int axis = op->axis == -1 ? indices_dims : op->axis;
+ CHECK_GE(axis, 0) << "Resolved axis must be non-negative.";
+
+ auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
+ mutable_dims->resize(output_dims);
+ for (int i = 0; i < output_dims; ++i) {
+ int dim = 0;
+ if (i < axis) {
+ dim = indices_array.shape().dims(i);
+ } else if (i == axis) {
+ dim = depth;
+ } else {
+ dim = indices_array.shape().dims(i - 1);
+ }
+ (*mutable_dims)[i] = dim;
+ }
+}
+
} // namespace
bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
@@ -1618,6 +1673,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kSin:
case OperatorType::kLogicalAnd:
case OperatorType::kLogicalNot:
+ case OperatorType::kLogicalOr:
ProcessSimpleOperator(model, op, 0);
break;
case OperatorType::kGather:
@@ -1825,6 +1881,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kAny:
ProcessAnyOperator(model, static_cast<AnyOperator*>(op));
break;
+ case OperatorType::kOneHot:
+ ProcessOneHotOperator(model, static_cast<OneHotOperator*>(op));
+ break;
default:
// Unimplemented, another graph transformation should drop it.
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index f36f720857..9a3db5c888 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -215,7 +215,7 @@ tensorflow::Status ImportFloatArray(const TensorProto& input_tensor,
Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_FLOAT);
const auto& input_shape = input_tensor.tensor_shape();
- CHECK_LE(input_shape.dim_size(), 4);
+ CHECK_LE(input_shape.dim_size(), 6);
int input_flat_size;
auto status = ImportShape(input_shape.dim(), &input_flat_size,
output_array->mutable_shape());
@@ -253,7 +253,7 @@ tensorflow::Status ImportQuint8Array(const TensorProto& input_tensor,
Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_QUINT8);
const auto& input_shape = input_tensor.tensor_shape();
- CHECK_LE(input_shape.dim_size(), 4);
+ CHECK_LE(input_shape.dim_size(), 6);
int input_flat_size;
auto status = ImportShape(input_shape.dim(), &input_flat_size,
output_array->mutable_shape());
@@ -290,7 +290,7 @@ tensorflow::Status ImportInt32Array(const TensorProto& input_tensor,
Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_INT32);
const auto& input_shape = input_tensor.tensor_shape();
- CHECK_LE(input_shape.dim_size(), 4);
+ CHECK_LE(input_shape.dim_size(), 6);
int input_flat_size;
auto status = ImportShape(input_shape.dim(), &input_flat_size,
output_array->mutable_shape());
@@ -326,7 +326,7 @@ tensorflow::Status ImportInt64Array(const TensorProto& input_tensor,
Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_INT64);
const auto& input_shape = input_tensor.tensor_shape();
- CHECK_LE(input_shape.dim_size(), 4);
+ CHECK_LE(input_shape.dim_size(), 6);
int input_flat_size;
auto status = ImportShape(input_shape.dim(), &input_flat_size,
output_array->mutable_shape());
@@ -363,7 +363,7 @@ tensorflow::Status ImportBoolArray(const TensorProto& input_tensor,
Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_BOOL);
const auto& input_shape = input_tensor.tensor_shape();
- CHECK_LE(input_shape.dim_size(), 4);
+ CHECK_LE(input_shape.dim_size(), 6);
int input_flat_size;
auto status = ImportShape(input_shape.dim(), &input_flat_size,
output_array->mutable_shape());
@@ -409,7 +409,7 @@ tensorflow::Status ImportStringArray(const TensorProto& input_tensor,
Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_STRING);
const auto& input_shape = input_tensor.tensor_shape();
- CHECK_LE(input_shape.dim_size(), 4);
+ CHECK_LE(input_shape.dim_size(), 6);
int input_flat_size;
auto status = ImportShape(input_shape.dim(), &input_flat_size,
output_array->mutable_shape());
@@ -1833,6 +1833,27 @@ tensorflow::Status ConvertSparseToDenseOperator(
return tensorflow::Status::OK();
}
+tensorflow::Status ConvertOneHotOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ CHECK_EQ(node.op(), "OneHot");
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4));
+
+ const auto dtype = GetDataTypeAttr(node, "T");
+ // TODO(b/111744875): Support DT_UINT8 and quantization.
+ CHECK(dtype == DT_INT32 || dtype == DT_INT64 || dtype == DT_FLOAT ||
+ dtype == DT_BOOL);
+
+ auto op = absl::make_unique<OneHotOperator>();
+ op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : -1;
+ for (const string& input : node.input()) {
+ op->inputs.push_back(input);
+ }
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op.release());
+ return tensorflow::Status::OK();
+}
+
} // namespace
namespace internal {
@@ -1893,9 +1914,10 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"Less", ConvertSimpleOperator<TensorFlowLessOperator, 2>},
{"LessEqual", ConvertSimpleOperator<TensorFlowLessEqualOperator, 2>},
{"Log", ConvertSimpleOperator<LogOperator, 1>},
- {"LogSoftmax", ConvertSimpleOperator<LogSoftmaxOperator, 1>},
{"LogicalAnd", ConvertSimpleOperator<LogicalAndOperator, 2>},
+ {"LogicalOr", ConvertSimpleOperator<LogicalOrOperator, 2>},
{"LogicalNot", ConvertSimpleOperator<LogicalNotOperator, 1>},
+ {"LogSoftmax", ConvertSimpleOperator<LogSoftmaxOperator, 1>},
{"MatMul", ConvertMatMulOperator},
{"Max", ConvertReduceOperator<TensorFlowMaxOperator>},
{"MaxPool", ConvertMaxPoolOperator},
@@ -1909,6 +1931,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"NextIteration", ConvertOperatorSpecialCasedAsRNNBackEdge},
{"NoOp", ConvertNoOpOperator},
{"NotEqual", ConvertSimpleOperator<TensorFlowNotEqualOperator, 2>},
+ {"OneHot", ConvertOneHotOperator},
{"Pack", ConvertPackOperator},
{"Pad", ConvertSimpleOperator<PadOperator, 2>},
{"PadV2", ConvertSimpleOperator<PadV2Operator, 3>},
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 6459dccf64..7d0dbfcc05 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -64,6 +64,7 @@ enum class OperatorType : uint8 {
kMaxPool,
kFakeQuant,
kMul,
+ kOneHot,
kRandomUniform,
kRange,
kRank,
@@ -146,6 +147,7 @@ enum class OperatorType : uint8 {
kAny,
kLogicalAnd,
kLogicalNot,
+ kLogicalOr,
};
// Helper to deal with TensorFlow arrays using a different ordering of
@@ -1768,6 +1770,38 @@ struct LogicalNotOperator : Operator {
LogicalNotOperator() : Operator(OperatorType::kLogicalNot) {}
};
+// OneHot operator:
+//
+// Inputs:
+// Inputs[0]: required: indices.
+// Inputs[1]: required: depth.
+// Inputs[2]: required: on_value.
+// Inputs[3]: required: off_value.
+//
+// TensorFlow equivalent: OneHot.
+struct OneHotOperator : Operator {
+ enum Inputs {
+ INDICES_INPUT = 0,
+ DEPTH_INPUT = 1,
+ ON_VALUE_INPUT = 2,
+ OFF_VALUE_INPUT = 3,
+ };
+
+ OneHotOperator() : Operator(OperatorType::kOneHot) {}
+ int axis = -1;
+};
+
+// LogicalOr operator:
+//
+// Inputs:
+// Inputs[0]: required: A Bool tensor.
+// Inputs[1]: required: A Bool tensor.
+//
+// TensorFlow equivalent: LogicalOr.
+struct LogicalOrOperator : Operator {
+ LogicalOrOperator() : Operator(OperatorType::kLogicalOr) {}
+};
+
// Alloc's are used for transient arrays only. An Alloc specifies which interval
// of the "transient_data" workspace buffer passed to inference functions, is to
// be used for the transient array at hand. The 'start' and 'end' values are
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 4b2ef756cc..9380168f30 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -1053,6 +1053,23 @@ class Shape
int GetVersion(const Operator& op) const override { return 1; }
};
+class OneHot : public BuiltinOperator<OneHotOperator, ::tflite::OneHotOptions,
+ ::tflite::BuiltinOptions_OneHotOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateOneHotOptions(*builder, op.axis);
+ }
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->axis = options.axis();
+ }
+
+ int GetVersion(const Operator& op) const override { return 1; }
+};
+
class TensorFlowUnsupported : public BaseOperator {
public:
using BaseOperator::BaseOperator;
@@ -1278,6 +1295,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
OperatorType::kFakeQuant));
ops.emplace_back(
new Pack(::tflite::BuiltinOperator_PACK, OperatorType::kPack));
+ ops.emplace_back(
+ new OneHot(::tflite::BuiltinOperator_ONE_HOT, OperatorType::kOneHot));
// Custom Operators.
ops.emplace_back(
@@ -1331,6 +1350,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
ops.emplace_back(
new SimpleOperator<SliceOperator>("SLICE", OperatorType::kSlice));
ops.emplace_back(new SimpleOperator<PowOperator>("POW", OperatorType::kPow));
+ ops.emplace_back(new SimpleOperator<LogicalOrOperator>(
+ "LOGICAL_OR", OperatorType::kLogicalOr));
// Element-wise operator
ops.emplace_back(new SimpleOperator<SinOperator>("SIN", OperatorType::kSin));
ops.emplace_back(new SimpleOperator<LogOperator>("LOG", OperatorType::kLog));
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
index 44de6fbf64..384f7c118d 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -127,6 +127,8 @@ TEST_F(OperatorTest, SimpleOperators) {
CheckSimpleOperator<TensorFlowSqrtOperator>("SQRT", OperatorType::kSqrt);
CheckSimpleOperator<TensorFlowRsqrtOperator>("RSQRT", OperatorType::kRsqrt);
CheckSimpleOperator<PowOperator>("POW", OperatorType::kPow);
+ CheckSimpleOperator<LogicalOrOperator>("LOGICAL_OR",
+ OperatorType::kLogicalOr);
}
TEST_F(OperatorTest, BuiltinAdd) {
@@ -462,6 +464,14 @@ TEST_F(OperatorTest, BuiltinPack) {
EXPECT_EQ(op.axis, output_toco_op->axis);
}
+TEST_F(OperatorTest, BuiltinOneHot) {
+ OneHotOperator op;
+ op.axis = 2;
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("ONE_HOT", OperatorType::kOneHot), op);
+ EXPECT_EQ(op.axis, output_toco_op->axis);
+}
+
TEST_F(OperatorTest, TensorFlowUnsupported) {
TensorFlowUnsupportedOperator op;
op.tensorflow_op = "MyCustomUnsupportedOp";
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index aa7f6996eb..fcd3cbab07 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -309,8 +309,9 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
// HardcodeMinMax to move changes through the graph as we make changes.
auto propagate_default_min_max =
absl::make_unique<PropagateDefaultMinMax>();
- if (toco_flags.has_default_ranges_min() &&
- toco_flags.has_default_ranges_max()) {
+ bool has_default_ranges_flag = (toco_flags.has_default_ranges_min() &&
+ toco_flags.has_default_ranges_max());
+ if (has_default_ranges_flag) {
propagate_default_min_max->DefineTypeRange(
ArrayDataType::kUint8, toco_flags.default_ranges_min(),
toco_flags.default_ranges_max());
@@ -335,6 +336,8 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
new EnsureUint8WeightsSafeForFastInt8Kernels;
ensure_safe_for_int8_kernels->set_allow_nudging_weights(
toco_flags.allow_nudging_weights_to_use_fast_gemm_kernel());
+ ensure_safe_for_int8_kernels->set_has_default_ranges_flag(
+ has_default_ranges_flag);
RunGraphTransformations(model, "quantization graph transformations",
{
new RemoveTrivialQuantizedActivationFunc,
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 98e416b76e..68155c7329 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -356,6 +356,7 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(ReduceMin) // Reduction Min
HANDLE_OPERATORTYPENAME_CASE(Minimum) // Element-wise Minimum
HANDLE_OPERATORTYPENAME_CASE(Neg)
+ HANDLE_OPERATORTYPENAME_CASE(OneHot)
HANDLE_OPERATORTYPENAME_CASE(Pack)
HANDLE_OPERATORTYPENAME_CASE(Pad)
HANDLE_OPERATORTYPENAME_CASE(PadV2)
@@ -402,6 +403,7 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(Any)
HANDLE_OPERATORTYPENAME_CASE(LogicalAnd)
HANDLE_OPERATORTYPENAME_CASE(LogicalNot)
+ HANDLE_OPERATORTYPENAME_CASE(LogicalOr)
default:
LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE
@@ -1617,11 +1619,12 @@ void CheckIsReadyForQuantization(const Model& model) {
<< "Array " << input << ", which is an input to the "
<< HelpfulOperatorTypeName(*op) << " operator producing the output "
<< "array " << op->outputs[0] << ", is lacking min/max data, "
- << "which is necessary for quantization. Either target a "
- << "non-quantized output format, or change the input graph to "
- << "contain min/max information, or pass --default_ranges_min= and "
- << "--default_ranges_max= if you do not care about the accuracy of "
- << "results.";
+ << "which is necessary for quantization. If accuracy matters, either "
+ << "target a non-quantized output format, or run quantized training "
+ << "with your model from a floating point checkpoint to change the "
+ << "input graph to contain min/max information. If you don't care "
+ << "about accuracy, you can pass --default_ranges_min= and "
+ << "--default_ranges_max= for easy experimentation.";
}
}
}
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py b/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py
index ec033c4a01..a44bfd1bfd 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py
@@ -38,12 +38,8 @@ class OptimizerTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testBasic(self):
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
- # Note that we name the variables uniquely here since the variables don't
- # seem to be getting deleted at the end of the loop.
- var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype,
- name='a_%d' % i)
- var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype,
- name='b_%d' % i)
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
def loss():
return 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop
# Note that for eager execution, minimize expects a function instead of a
@@ -131,12 +127,8 @@ class OptimizerTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testNoGradients(self):
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
- # Note that we name the variables uniquely here since the variables don't
- # seem to be getting deleted at the end of the loop.
- var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype,
- name='a%d' % i)
- var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype,
- name='b%d' % i)
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
# pylint: disable=cell-var-from-loop
def loss():
return 5 * var0
@@ -149,12 +141,8 @@ class OptimizerTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testNoGradientsForAnyVariables_Minimize(self):
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
- # Note that we name the variables uniquely here since the variables don't
- # seem to be getting deleted at the end of the loop.
- var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype,
- name='a_%d' % i)
- var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype,
- name='b_%d' % i)
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
def loss():
return constant_op.constant(5.0)
sgd_op = gradient_descent.GradientDescentOptimizer(3.0)
@@ -165,12 +153,8 @@ class OptimizerTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testNoGradientsForAnyVariables_ApplyGradients(self):
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
- # Note that we name the variables uniquely here since the variables don't
- # seem to be getting deleted at the end of the loop.
- var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype,
- name='a_%d' % i)
- var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype,
- name='b_%d' % i)
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
sgd_op = gradient_descent.GradientDescentOptimizer(3.0)
with self.assertRaisesRegexp(ValueError,
'No gradients provided for any variable'):
@@ -179,12 +163,8 @@ class OptimizerTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testGradientsAsVariables(self):
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
- # Note that we name the variables uniquely here since the variables don't
- # seem to be getting deleted at the end of the loop.
- var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype,
- name='a%d' % i)
- var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype,
- name='b%d' % i)
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
def loss():
return 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop
sgd_op = gradient_descent.GradientDescentOptimizer(3.0)
diff --git a/tensorflow/contrib/recurrent/python/ops/recurrent.py b/tensorflow/contrib/recurrent/python/ops/recurrent.py
index fa16b82ab6..4f289e0c85 100644
--- a/tensorflow/contrib/recurrent/python/ops/recurrent.py
+++ b/tensorflow/contrib/recurrent/python/ops/recurrent.py
@@ -79,7 +79,7 @@ def _Index(struct, index):
"""
index = ops.convert_to_tensor(index)
index.get_shape().assert_has_rank(0)
- return nest.map_structure(lambda x: x[index], struct)
+ return nest.map_structure(lambda x: array_ops.gather(x, index), struct)
def _Update(struct_acc, struct_x, t):
diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD
index 26fd4e2023..fbb50befdf 100644
--- a/tensorflow/contrib/saved_model/BUILD
+++ b/tensorflow/contrib/saved_model/BUILD
@@ -93,3 +93,32 @@ py_test(
"//tensorflow/python/saved_model:utils",
],
)
+
+py_library(
+ name = "keras_saved_model",
+ srcs = ["python/saved_model/keras_saved_model.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_windows"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/python:lib",
+ "//tensorflow/python:util",
+ "//tensorflow/python/keras:engine",
+ "//tensorflow/python/saved_model:constants",
+ ],
+)
+
+py_test(
+ name = "keras_saved_model_test",
+ size = "small",
+ srcs = ["python/saved_model/keras_saved_model_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_windows"],
+ deps = [
+ ":saved_model_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:training",
+ "//tensorflow/python/keras",
+ "//third_party/py/numpy",
+ ],
+)
diff --git a/tensorflow/contrib/saved_model/__init__.py b/tensorflow/contrib/saved_model/__init__.py
index b4f27a055d..95e1a8967b 100644
--- a/tensorflow/contrib/saved_model/__init__.py
+++ b/tensorflow/contrib/saved_model/__init__.py
@@ -24,11 +24,12 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,wildcard-import,line-too-long
+from tensorflow.contrib.saved_model.python.saved_model.keras_saved_model import *
from tensorflow.contrib.saved_model.python.saved_model.signature_def_utils import *
# pylint: enable=unused-import,widcard-import,line-too-long
from tensorflow.python.util.all_util import remove_undocumented
-_allowed_symbols = ["get_signature_def_by_key"]
+_allowed_symbols = ["get_signature_def_by_key", "load_model", "save_model"]
remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/saved_model/python/saved_model/__init__.py b/tensorflow/contrib/saved_model/python/saved_model/__init__.py
index 7b91622b61..e3b76bb6f3 100644
--- a/tensorflow/contrib/saved_model/python/saved_model/__init__.py
+++ b/tensorflow/contrib/saved_model/python/saved_model/__init__.py
@@ -24,5 +24,6 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=wildcard-import
+from tensorflow.contrib.saved_model.python.saved_model import keras_saved_model
from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
# pylint: enable=wildcard-import
diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py
new file mode 100644
index 0000000000..e2a969f053
--- /dev/null
+++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py
@@ -0,0 +1,108 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=protected-access
+"""Utility functions to save/load keras Model to/from SavedModel."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.python.keras.models import model_from_json
+from tensorflow.python.lib.io import file_io
+from tensorflow.python.saved_model import constants
+from tensorflow.python.util import compat
+
+
+def save_model(model, saved_model_path):
+ """Save a `tf.keras.Model` into Tensorflow SavedModel format.
+
+ `save_model` generates such files/folders under the `saved_model_path` folder:
+ 1) an asset folder containing the json string of the model's
+ configuration(topology).
+ 2) a checkpoint containing the model weights.
+
+ Note that subclassed models can not be saved via this function, unless you
+ provide an implementation for get_config() and from_config().
+ Also note that `tf.keras.optimizers.Optimizer` instances can not currently be
+ saved to checkpoints. Use optimizers from `tf.train`.
+
+ Args:
+ model: A `tf.keras.Model` to be saved.
+ saved_model_path: a string specifying the path to the SavedModel directory.
+
+ Raises:
+ NotImplementedError: If the passed in model is a subclassed model.
+ """
+ if not model._is_graph_network:
+ raise NotImplementedError
+
+ # save model configuration as a json string under assets folder.
+ model_json = model.to_json()
+ assets_destination_dir = os.path.join(
+ compat.as_bytes(saved_model_path),
+ compat.as_bytes(constants.ASSETS_DIRECTORY))
+
+ if not file_io.file_exists(assets_destination_dir):
+ file_io.recursive_create_dir(assets_destination_dir)
+
+ model_json_filepath = os.path.join(
+ compat.as_bytes(assets_destination_dir),
+ compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON))
+ file_io.write_string_to_file(model_json_filepath, model_json)
+
+ # save model weights in checkpoint format.
+ checkpoint_destination_dir = os.path.join(
+ compat.as_bytes(saved_model_path),
+ compat.as_bytes(constants.VARIABLES_DIRECTORY))
+
+ if not file_io.file_exists(checkpoint_destination_dir):
+ file_io.recursive_create_dir(checkpoint_destination_dir)
+
+ checkpoint_prefix = os.path.join(
+ compat.as_text(checkpoint_destination_dir),
+ compat.as_text(constants.VARIABLES_FILENAME))
+ model.save_weights(checkpoint_prefix, save_format='tf', overwrite=True)
+
+
+def load_model(saved_model_path):
+ """Load a keras.Model from SavedModel.
+
+ load_model reinstantiates model state by:
+ 1) loading model topology from json (this will eventually come
+ from metagraph).
+ 2) loading model weights from checkpoint.
+
+ Args:
+ saved_model_path: a string specifying the path to an existing SavedModel.
+
+ Returns:
+ a keras.Model instance.
+ """
+ # restore model topology from json string
+ model_json_filepath = os.path.join(
+ compat.as_bytes(saved_model_path),
+ compat.as_bytes(constants.ASSETS_DIRECTORY),
+ compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON))
+ model_json = file_io.read_file_to_string(model_json_filepath)
+ model = model_from_json(model_json)
+
+ # restore model weights
+ checkpoint_prefix = os.path.join(
+ compat.as_text(saved_model_path),
+ compat.as_text(constants.VARIABLES_DIRECTORY),
+ compat.as_text(constants.VARIABLES_FILENAME))
+ model.load_weights(checkpoint_prefix)
+ return model
diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py
new file mode 100644
index 0000000000..107ae1b07b
--- /dev/null
+++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py
@@ -0,0 +1,201 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=protected-access
+"""Tests for saving/loading function for keras Model."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import shutil
+import numpy as np
+
+from tensorflow.contrib.saved_model.python.saved_model import keras_saved_model
+from tensorflow.python import keras
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import training
+from tensorflow.python.platform import test
+from tensorflow.python.training import training as training_module
+
+
+class TestModelSavingandLoading(test.TestCase):
+
+ def test_saving_sequential_model(self):
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.RepeatVector(3))
+ model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
+ model.compile(
+ loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(lr=0.0001),
+ metrics=[keras.metrics.categorical_accuracy],
+ sample_weight_mode='temporal')
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3, 3))
+ model.train_on_batch(x, y)
+
+ ref_y = model.predict(x)
+ temp_dir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, temp_dir)
+
+ temp_saved_model = os.path.join(temp_dir, 'saved_model')
+ keras_saved_model.save_model(model, temp_saved_model)
+
+ loaded_model = keras_saved_model.load_model(temp_saved_model)
+ y = loaded_model.predict(x)
+ self.assertAllClose(ref_y, y, atol=1e-05)
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_saving_sequential_model_without_compile(self):
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.RepeatVector(3))
+ model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
+
+ x = np.random.random((1, 3))
+ ref_y = model.predict(x)
+
+ temp_dir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, temp_dir)
+
+ temp_saved_model = os.path.join(temp_dir, 'saved_model')
+ keras_saved_model.save_model(model, temp_saved_model)
+ loaded_model = keras_saved_model.load_model(temp_saved_model)
+
+ y = loaded_model.predict(x)
+ self.assertAllClose(ref_y, y, atol=1e-05)
+
+ def test_saving_functional_model(self):
+ with self.test_session():
+ inputs = keras.layers.Input(shape=(3,))
+ x = keras.layers.Dense(2)(inputs)
+ output = keras.layers.Dense(3)(x)
+
+ model = keras.models.Model(inputs, output)
+ model.compile(
+ loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(lr=0.0001),
+ metrics=[keras.metrics.categorical_accuracy])
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3))
+ model.train_on_batch(x, y)
+
+ ref_y = model.predict(x)
+ temp_dir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, temp_dir)
+
+ temp_saved_model = os.path.join(temp_dir, 'saved_model')
+ keras_saved_model.save_model(model, temp_saved_model)
+ loaded_model = keras_saved_model.load_model(temp_saved_model)
+
+ y = loaded_model.predict(x)
+ self.assertAllClose(ref_y, y, atol=1e-05)
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_saving_functional_model_without_compile(self):
+ with self.test_session():
+ inputs = keras.layers.Input(shape=(3,))
+ x = keras.layers.Dense(2)(inputs)
+ output = keras.layers.Dense(3)(x)
+
+ model = keras.models.Model(inputs, output)
+
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3))
+
+ ref_y = model.predict(x)
+ temp_dir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, temp_dir)
+
+ temp_saved_model = os.path.join(temp_dir, 'saved_model')
+ keras_saved_model.save_model(model, temp_saved_model)
+ loaded_model = keras_saved_model.load_model(temp_saved_model)
+
+ y = loaded_model.predict(x)
+ self.assertAllClose(ref_y, y, atol=1e-05)
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_saving_with_tf_optimizer(self):
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.Dense(3))
+ model.compile(
+ loss='mse',
+ optimizer=training_module.RMSPropOptimizer(0.1),
+ metrics=['acc'])
+
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3))
+ model.train_on_batch(x, y)
+
+ ref_y = model.predict(x)
+ temp_dir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, temp_dir)
+
+ temp_saved_model = os.path.join(temp_dir, 'saved_model')
+ keras_saved_model.save_model(model, temp_saved_model)
+ loaded_model = keras_saved_model.load_model(temp_saved_model)
+ loaded_model.compile(
+ loss='mse',
+ optimizer=training_module.RMSPropOptimizer(0.1),
+ metrics=['acc'])
+ y = loaded_model.predict(x)
+ self.assertAllClose(ref_y, y, atol=1e-05)
+
+ # test that new updates are the same with both models
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3))
+
+ ref_loss = model.train_on_batch(x, y)
+ loss = loaded_model.train_on_batch(x, y)
+ self.assertAllClose(ref_loss, loss, atol=1e-05)
+
+ ref_y = model.predict(x)
+ y = loaded_model.predict(x)
+ self.assertAllClose(ref_y, y, atol=1e-05)
+
+ # test saving/loading again
+ keras_saved_model.save_model(loaded_model, temp_saved_model)
+ loaded_model = keras_saved_model.load_model(temp_saved_model)
+ y = loaded_model.predict(x)
+ self.assertAllClose(ref_y, y, atol=1e-05)
+
+ def test_saving_subclassed_model_raise_error(self):
+ # For now, saving subclassed model should raise an error. It should be
+ # avoided later with loading from SavedModel.pb.
+
+ class SubclassedModel(training.Model):
+
+ def __init__(self):
+ super(SubclassedModel, self).__init__()
+ self.layer1 = keras.layers.Dense(3)
+ self.layer2 = keras.layers.Dense(1)
+
+ def call(self, inp):
+ return self.layer2(self.layer1(inp))
+
+ model = SubclassedModel()
+ temp_dir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, temp_dir)
+ temp_saved_model = os.path.join(temp_dir, 'saved_model')
+ with self.assertRaises(NotImplementedError):
+ keras_saved_model.save_model(model, temp_saved_model)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index 5889fd5aaf..46f3c36e3d 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -122,7 +122,6 @@ tf_cuda_library(
tf_gen_op_wrapper_py(
name = "trt_engine_op",
- gen_locally = True,
deps = [
":trt_engine_op_op_lib",
":trt_logging",
@@ -394,6 +393,7 @@ cuda_py_tests(
# "test/unary_test.py", # Blocked by trt4 installation
# "test/vgg_block_nchw_test.py",
# "test/vgg_block_test.py",
+ "test/memory_alignment_test.py",
],
additional_deps = [
":tf_trt_integration_test_base",
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
index 646d62483f..6699b71d28 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -45,11 +45,11 @@ using ::tensorflow::strings::StrCat;
// Helps simultaneous execution of native and TRT engines.
class AsyncHelper : public tensorflow::core::RefCounted {
public:
- AsyncHelper(tensorflow::AsyncOpKernel::DoneCallback done) { done_ = done; }
+ AsyncHelper(AsyncOpKernel::DoneCallback done) { done_ = done; }
~AsyncHelper() override { done_(); }
private:
- tensorflow::AsyncOpKernel::DoneCallback done_;
+ AsyncOpKernel::DoneCallback done_;
};
#define TYPECASE(dt, X, Y) \
@@ -152,7 +152,7 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
}
}
-void TRTEngineOp::ExecuteNativeSegment(tensorflow::OpKernelContext* ctx,
+void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
AsyncHelper* helper) {
if (!calibration_mode_) {
VLOG(1) << "Executing native engine";
@@ -193,7 +193,7 @@ void TRTEngineOp::ExecuteNativeSegment(tensorflow::OpKernelContext* ctx,
});
}
-void TRTEngineOp::ExecuteCalibration(tensorflow::OpKernelContext* ctx,
+void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx,
AsyncHelper* helper) {
helper->Ref();
tensorflow::core::ScopedUnref sc(helper);
@@ -238,7 +238,7 @@ void TRTEngineOp::ExecuteCalibration(tensorflow::OpKernelContext* ctx,
ExecuteNativeSegment(ctx, helper);
}
-int TRTEngineOp::GetEngineBatch(tensorflow::OpKernelContext* ctx) {
+int TRTEngineOp::GetEngineBatch(OpKernelContext* ctx) {
int num_batch = ctx->input(0).shape().dim_size(0);
int smallest_engine = 0;
for (const auto i : cached_engine_batches_) {
@@ -254,21 +254,20 @@ int TRTEngineOp::GetEngineBatch(tensorflow::OpKernelContext* ctx) {
cached_engine_batches_.push_back(num_batch);
VLOG(1) << "Running with batch size " << num_batch;
} else {
- string s("Engine buffer is full. buffer limit= ");
- StrAppend(&s, max_cached_engines_, ", current entries= ");
- for (auto i : cached_engine_batches_) StrAppend(&s, i, ", ");
- StrAppend(&s, "Requested batch= ", num_batch);
- LOG(ERROR) << s;
- ctx->SetStatus(tensorflow::errors::ResourceExhausted(
- "Requested batch size is not available and engine cache is full"));
+ string msg =
+ StrCat("Engine buffer is full. buffer limit=", max_cached_engines_,
+ ", current entries=");
+ for (auto i : cached_engine_batches_) StrAppend(&msg, i, ",");
+ StrAppend(&msg, "Requested batch=", num_batch);
+ LOG(WARNING) << msg;
return -1;
}
}
return smallest_engine;
}
-void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
- tensorflow::AsyncOpKernel::DoneCallback done) {
+void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
+ AsyncOpKernel::DoneCallback done) {
auto helper = new AsyncHelper(done);
tensorflow::core::ScopedUnref sc(helper);
if (calibration_mode_) {
@@ -276,32 +275,52 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
return;
}
const int smallest_engine = GetEngineBatch(ctx);
- if (smallest_engine < 0) return; // GetEngineBatch already set the status.
+ if (smallest_engine < 0) {
+ LOG(WARNING) << "Failed to get engine batch, running native segment";
+ ExecuteNativeSegment(ctx, helper);
+ return;
+ }
const int num_batch = ctx->input(0).shape().dim_size(0);
auto& engine_ctx_pair = GetEngine(smallest_engine, ctx);
auto& trt_engine_ptr = engine_ctx_pair.first;
if (!trt_engine_ptr) {
LOG(WARNING) << "Engine retrieval for batch size " << num_batch
- << " failed Running native segment";
+ << " failed. Running native segment";
+ ExecuteNativeSegment(ctx, helper);
+ return;
+ }
+ const bool retry = ExecuteTrtEngine(ctx, num_batch, trt_engine_ptr.get(),
+ engine_ctx_pair.second.get());
+ if (retry) {
+ LOG(WARNING) << "Failed to execute engine, retrying with native segment";
ExecuteNativeSegment(ctx, helper);
return;
}
+}
+bool TRTEngineOp::ExecuteTrtEngine(
+ OpKernelContext* ctx, const int num_batch,
+ nvinfer1::ICudaEngine* trt_engine_ptr,
+ nvinfer1::IExecutionContext* trt_execution_context_ptr) {
+ const bool kRetry = true;
const int num_binding = ctx->num_inputs() + ctx->num_outputs();
std::vector<void*> buffers(num_binding);
for (int i = 0; i < ctx->num_inputs(); i++) {
- const string inp_name = StrCat(kInputPHName, i);
+ const string input_name = StrCat(kInputPHName, i);
const size_t binding_index =
- trt_engine_ptr->getBindingIndex(inp_name.c_str());
+ trt_engine_ptr->getBindingIndex(input_name.c_str());
+ if (binding_index == -1) {
+ LOG(ERROR) << "Input node not found, at " << input_name;
+ return kRetry;
+ }
const Tensor& input_tensor = ctx->input(i);
const TensorShape& input_shape = input_tensor.shape();
if (num_batch != input_shape.dim_size(0)) {
- LOG(ERROR) << "input data inconsistent batch size";
- ctx->SetStatus(tensorflow::errors::FailedPrecondition(
- "Different batch sizes between input tensors"));
- return;
+ LOG(ERROR) << "Input data has inconsistent batch size: " << num_batch
+ << " vs " << input_shape.dim_size(0);
+ return kRetry;
}
auto dtype = trt_engine_ptr->getBindingDataType(binding_index);
switch (dtype) {
@@ -310,14 +329,10 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
break;
case nvinfer1::DataType::kHALF:
LOG(ERROR) << "FP16 inputs are not supported yet!";
- ctx->SetStatus(tensorflow::errors::InvalidArgument(
- "FP16 inputs are not supported!"));
- return;
+ return kRetry;
case nvinfer1::DataType::kINT8:
LOG(ERROR) << "INT8 inputs are not supported yet!";
- ctx->SetStatus(tensorflow::errors::InvalidArgument(
- "INT8 inputs are not supported!"));
- return;
+ return kRetry;
#if NV_TENSORRT_MAJOR > 3
case nvinfer1::DataType::kINT32:
buffers[binding_index] = (void*)(input_tensor.flat<int32>().data());
@@ -325,9 +340,7 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
#endif
default:
LOG(ERROR) << "Unknown TRT data type: " << int(dtype);
- ctx->SetStatus(tensorflow::errors::InvalidArgument(
- "Unknown output TRT data type! ", static_cast<int>(dtype)));
- return;
+ return kRetry;
}
}
@@ -344,20 +357,23 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
std::vector<int> trt_shape(dims.nbDims + 1);
trt_shape[0] = num_batch;
for (int j = 0; j < dims.nbDims; j++) trt_shape[j + 1] = dims.d[j];
- OP_REQUIRES_OK(
- ctx, TensorShapeUtils::MakeShape(trt_shape.data(), trt_shape.size(),
- &output_shape));
+ auto status = TensorShapeUtils::MakeShape(
+ trt_shape.data(), trt_shape.size(), &output_shape);
+ if (!status.ok()) {
+ LOG(ERROR) << "Failed to get output shape: " << status;
+ return kRetry;
+ }
} else {
- LOG(ERROR) << "output node not found, at " << output_name;
- ctx->SetStatus(tensorflow::errors::Internal("output ", output_name,
- " couldn't be found!"));
- return;
+ LOG(ERROR) << "Output node not found, at " << output_name;
+ return kRetry;
}
auto status = ctx->allocate_output(i, output_shape, &output_tensor);
if (!status.ok()) {
LOG(ERROR) << "Allocating output failed with " << status;
ctx->SetStatus(status);
- return;
+ // Do not retry since we cannot allocate the same output twice.
+ // TODO(aaroey): ideally we should retry, fix this.
+ return !kRetry;
}
auto dtype = trt_engine_ptr->getBindingDataType(binding_index);
switch (dtype) {
@@ -366,15 +382,11 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
reinterpret_cast<void*>(output_tensor->flat<float>().data());
break;
case nvinfer1::DataType::kHALF:
- LOG(ERROR) << "half size is not supported yet!";
- ctx->SetStatus(tensorflow::errors::InvalidArgument(
- "Half outputs are not supported!"));
- return;
+ LOG(WARNING) << "half size is not supported yet!";
+ return kRetry;
case nvinfer1::DataType::kINT8:
- LOG(ERROR) << "int8 is not supported yet!";
- ctx->SetStatus(tensorflow::errors::InvalidArgument(
- "INT8 outputs are not supported!"));
- return;
+ LOG(WARNING) << "int8 is not supported yet!";
+ return kRetry;
#if NV_TENSORRT_MAJOR > 3
case nvinfer1::DataType::kINT32:
buffers[binding_index] =
@@ -382,13 +394,11 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
break;
#endif
default:
- LOG(ERROR) << "Unknown TRT data type: " << static_cast<int>(dtype);
- ctx->SetStatus(tensorflow::errors::InvalidArgument(
- "Unsupported output data type! ", static_cast<int>(dtype)));
- return;
+ LOG(WARNING) << "Unknown TRT data type: " << static_cast<int>(dtype);
+ return kRetry;
}
}
- // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files
+ // Copied from cuda_kernel_helper since it seems only valid in *.cu.cc files
const cudaStream_t* stream = CHECK_NOTNULL(
reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
->stream()
@@ -396,15 +406,14 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
->GpuStreamMemberHack()));
// TODO(jie): trt enqueue does not return error
- auto& trt_execution_context_ptr = engine_ctx_pair.second;
auto ret = trt_execution_context_ptr->enqueue(num_batch, &buffers[0], *stream,
nullptr);
if (!ret) {
- LOG(ERROR) << "Failed to enqueue batch for TRT engine: " << name();
- ctx->SetStatus(tensorflow::errors::Internal(
- "Failed to enqueue batch for TRT engine: ", name()));
+ LOG(WARNING) << "Failed to enqueue batch for TRT engine: " << name();
+ return kRetry;
}
- // sync should be done by TF.
+ // Synchronization will be done by TF.
+ return !kRetry;
}
TRTEngineOp::~TRTEngineOp() {
@@ -424,8 +433,6 @@ nvinfer1::IGpuAllocator* TRTEngineOp::GetAllocator(OpKernelContext* ctx) {
if (!alloc) {
LOG(ERROR) << "Can't find device allocator for gpu device "
<< device->name();
- ctx->SetStatus(tensorflow::errors::Internal(
- "Can't get device allocator for device ", device->name()));
return nullptr;
}
allocator_.reset(new TRTDeviceAllocator(alloc));
@@ -452,7 +459,6 @@ TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size,
#if NV_TENSORRT_MAJOR > 3
auto allocator = GetAllocator(ctx);
if (allocator == nullptr) {
- // GetAllocator already set the Status.
return null_pair;
}
infer->setGpuAllocator(allocator);
@@ -469,7 +475,9 @@ TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size,
raw_static_engine->createExecutionContext())};
// Runtime is safe to delete after engine creation
serialized_segment_.clear();
- if (max_batch_size < batch_size) return null_pair;
+ if (max_batch_size < batch_size) {
+ return null_pair;
+ }
return engine_map_.at(max_batch_size);
} // static_engine_
@@ -481,7 +489,6 @@ TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size,
#if NV_TENSORRT_MAJOR > 3
allocator = GetAllocator(ctx);
if (allocator == nullptr) {
- // GetAllocator already set the Status.
return null_pair;
}
#endif
@@ -505,9 +512,8 @@ TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size,
// retry in the future.
engine_map_[batch_size] = {nullptr, nullptr};
}
- LOG(ERROR) << "Engine creation for batch size " << batch_size
- << " failed " << status;
- ctx->SetStatus(tensorflow::errors::Internal("Engine creation failed!"));
+ LOG(WARNING) << "Engine creation for batch size " << batch_size
+ << " failed " << status;
return null_pair;
}
VLOG(1) << "Conversion is done";
@@ -519,7 +525,7 @@ TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size,
}
tensorflow::Status TRTEngineOp::AllocateCalibrationResources(
- tensorflow::OpKernelContext* ctx, TRTCalibrationResource** cr) {
+ OpKernelContext* ctx, TRTCalibrationResource** cr) {
auto cres = new TRTCalibrationResource();
*cr = cres;
// Get the allocator.
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
index 9265250605..59b744e6d3 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
@@ -60,6 +60,12 @@ class TRTEngineOp : public AsyncOpKernel {
// Execute replaced native segment as function Op.
void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper);
+ // Execute the tensorrt engine. Returns whether we need to retry by running
+ // the native segment.
+ bool ExecuteTrtEngine(OpKernelContext* ctx, const int num_batch,
+ nvinfer1::ICudaEngine* trt_engine_ptr,
+ nvinfer1::IExecutionContext* trt_execution_context_ptr);
+
// Allocate necessary resources for calibration
Status AllocateCalibrationResources(OpKernelContext* ctx,
TRTCalibrationResource** cr);
diff --git a/tensorflow/contrib/tensorrt/test/memory_alignment_test.py b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py
new file mode 100644
index 0000000000..3dd95c6f62
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py
@@ -0,0 +1,72 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Model script to test TF-TensorRT integration."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.platform import test
+
+
+class MemoryAlignmentTest(trt_test.TfTrtIntegrationTestBase):
+
+ def GetParams(self):
+ """Testing conversion of BatchMatMul in TF-TRT conversion."""
+ dtype = dtypes.float32
+ input_name = "input"
+ input_dims = [2, 15, 15, 3]
+ g = ops.Graph()
+ with g.as_default():
+ inp = array_ops.placeholder(
+ dtype=dtype, shape=[None] + input_dims[1:], name=input_name)
+ with g.device("/GPU:0"):
+ e1 = constant_op.constant(
+ np.random.randn(1, 1, 3, 5), name="kernel_1", dtype=dtype)
+ e2 = constant_op.constant(
+ np.random.randn(1, 1, 5, 10), name="kernel_2", dtype=dtype)
+ conv = nn.conv2d(
+ input=inp,
+ filter=e1,
+ strides=[1, 1, 1, 1],
+ padding="VALID",
+ name="conv")
+ out = nn.conv2d(
+ input=conv,
+ filter=e2,
+ strides=[1, 1, 1, 1],
+ padding="VALID",
+ name="conv_2")
+ array_ops.squeeze(out, name=self.output_name)
+ return trt_test.TfTrtIntegrationTestParams(
+ gdef=g.as_graph_def(),
+ input_names=[input_name],
+ input_dims=[input_dims],
+ num_expected_engines=1,
+ expected_output_dims=(2, 15, 15, 10),
+ allclose_atol=1.e-02,
+ allclose_rtol=1.e-02)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/timeseries/examples/multivariate.py b/tensorflow/contrib/timeseries/examples/multivariate.py
index ed799542fd..e81cb18ad7 100644
--- a/tensorflow/contrib/timeseries/examples/multivariate.py
+++ b/tensorflow/contrib/timeseries/examples/multivariate.py
@@ -80,8 +80,8 @@ def multivariate_train_and_sample(
session=session, steps=1))
next_sample = numpy.random.multivariate_normal(
# Squeeze out the batch and series length dimensions (both 1).
- mean=numpy.squeeze(current_prediction["mean"], axis=[0, 1]),
- cov=numpy.squeeze(current_prediction["covariance"], axis=[0, 1]))
+ mean=numpy.squeeze(current_prediction["mean"], axis=(0, 1)),
+ cov=numpy.squeeze(current_prediction["covariance"], axis=(0, 1)))
# Update model state so that future predictions are conditional on the
# value we just sampled.
filtering_features = {
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index 643a7cc13a..5a7825f29a 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -15,6 +15,7 @@ package(
default_visibility = [
"//cloud/vmm/testing/tests/tpu:__subpackages__",
"//learning/brain:__subpackages__",
+ "//learning/deepmind:__subpackages__",
"//tensorflow:__subpackages__",
],
)
diff --git a/tensorflow/contrib/tpu/profiler/tpu_profiler.proto b/tensorflow/contrib/tpu/profiler/tpu_profiler.proto
index f0fca63db0..da4a95e045 100644
--- a/tensorflow/contrib/tpu/profiler/tpu_profiler.proto
+++ b/tensorflow/contrib/tpu/profiler/tpu_profiler.proto
@@ -11,6 +11,9 @@ service TPUProfiler {
// Starts a profiling session, blocks until it completes, and returns data.
rpc Profile(ProfileRequest) returns (ProfileResponse) {
}
+ // Collects profiling data and returns user-friendly metrics.
+ rpc Monitor(MonitorRequest) returns (MonitorResponse) {
+ }
}
message ProfileOptions {
@@ -104,3 +107,26 @@ message ProfileResponse {
// next-field: 8
}
+
+message MonitorRequest {
+ // Duration for which to profile between each update.
+ uint64 duration_ms = 1;
+
+ // Indicates the level at which we want to monitor. Currently, two levels are
+ // supported:
+ // Level 1: An ultra lightweight mode that captures only some utilization
+ // metrics.
+ // Level 2: More verbose than level 1. Collects utilization metrics, device
+ // information, step time information, etc. Do not use this option if the TPU
+ // host is being very heavily used.
+ int32 monitoring_level = 2;
+
+ // next-field: 3
+}
+
+message MonitorResponse {
+ // Properly formatted string data that can be directly returned back to user.
+ string data = 1;
+
+ // next-field: 2
+}
diff --git a/tensorflow/contrib/tpu/proto/optimization_parameters.proto b/tensorflow/contrib/tpu/proto/optimization_parameters.proto
index 9150606f5e..2cc17d6d92 100644
--- a/tensorflow/contrib/tpu/proto/optimization_parameters.proto
+++ b/tensorflow/contrib/tpu/proto/optimization_parameters.proto
@@ -1,10 +1,12 @@
-syntax = "proto2";
+syntax = "proto3";
package tensorflow.tpu;
+import "google/protobuf/wrappers.proto";
+
message ClippingLimits {
- optional float lower = 1 [default = -inf];
- optional float upper = 2 [default = inf];
+ google.protobuf.FloatValue lower = 1; // -inf if not set
+ google.protobuf.FloatValue upper = 2; // +inf if not set
}
// Get the learning rate from a <yet to be determined> source that can change
@@ -21,18 +23,18 @@ message LearningRate {
}
message AdagradParameters {
- optional float initial_accumulator = 1 [default = 0.];
+ float initial_accumulator = 1;
}
message StochasticGradientDescentParameters {
}
message FtrlParameters {
- optional float l1 = 1 [default = 0.];
- optional float l2 = 2 [default = 0.];
- optional float lr_power = 3 [default = 0.];
- optional float initial_accum = 4 [default = 0.];
- optional float initial_linear = 5 [default = 0.];
+ float l1 = 1;
+ float l2 = 2;
+ float lr_power = 3;
+ float initial_accum = 4;
+ float initial_linear = 5;
}
// The Adam optimizer does not implement hyper-parameter update; use the dynamic
@@ -41,84 +43,84 @@ message FtrlParameters {
// Here, t is the current timestep.
// https://github.com/tensorflow/tensorflow/blob/ab51450c817674c8ff08a7ae4f8ac50cdc4bed8b/tensorflow/python/training/adam.py#L54
message AdamParameters {
- optional float beta1 = 3 [default = 0.];
- optional float beta2 = 4 [default = 0.];
- optional float epsilon = 5 [default = 0.];
- optional float initial_m = 6 [default = 0.];
- optional float initial_v = 7 [default = 0.];
+ float beta1 = 3;
+ float beta2 = 4;
+ float epsilon = 5;
+ float initial_m = 6;
+ float initial_v = 7;
}
message MomentumParameters {
- optional float momentum = 1 [default = 0.];
- optional bool use_nesterov = 2 [default = false];
- optional float initial_accum = 3 [default = 0.];
+ float momentum = 1;
+ bool use_nesterov = 2;
+ float initial_accum = 3;
}
message RmsPropParameters {
- optional float rho = 1 [default = 0.];
- optional float momentum = 2 [default = 0.];
- optional float epsilon = 3 [default = 0.];
- optional float initial_ms = 4 [default = 0.];
- optional float initial_mom = 5 [default = 0.];
+ float rho = 1;
+ float momentum = 2;
+ float epsilon = 3;
+ float initial_ms = 4;
+ float initial_mom = 5;
}
message CenteredRmsPropParameters {
- optional float rho = 1 [default = 0.];
- optional float momentum = 2 [default = 0.];
- optional float epsilon = 3 [default = 0.];
- optional float initial_ms = 4 [default = 0.];
- optional float initial_mom = 5 [default = 0.];
- optional float initial_mg = 6 [default = 0.];
+ float rho = 1;
+ float momentum = 2;
+ float epsilon = 3;
+ float initial_ms = 4;
+ float initial_mom = 5;
+ float initial_mg = 6;
}
message MdlAdagradLightParameters {
- optional float l2 = 1;
- optional float lr_power = 2;
- optional float min_servable_mdl_benefit = 3;
- optional float mdl_mix_in_margin = 4;
- optional float mdl_benefit_rampup_coeff = 5;
- optional float mdl_min_weight = 6;
- optional float benefit_revisit_scale = 7;
- optional float max_event_benefit = 8;
- optional float max_total_benefit = 9;
- optional float mdl_hard_limit = 10;
- optional bool hard_limit_min_benefit = 11;
- optional bool mdl_regularize = 12;
- optional float initial_accumulator = 13;
- optional float initial_weight = 14;
- optional float initial_benefit = 15;
+ float l2 = 1;
+ float lr_power = 2;
+ float min_servable_mdl_benefit = 3;
+ float mdl_mix_in_margin = 4;
+ float mdl_benefit_rampup_coeff = 5;
+ float mdl_min_weight = 6;
+ float benefit_revisit_scale = 7;
+ float max_event_benefit = 8;
+ float max_total_benefit = 9;
+ float mdl_hard_limit = 10;
+ bool hard_limit_min_benefit = 11;
+ bool mdl_regularize = 12;
+ float initial_accumulator = 13;
+ float initial_weight = 14;
+ float initial_benefit = 15;
}
message AdadeltaParameters {
- optional float rho = 1;
- optional float epsilon = 2;
- optional float initial_accumulator = 3 [default = 0.];
- optional float initial_update = 4 [default = 0.];
+ float rho = 1;
+ float epsilon = 2;
+ float initial_accumulator = 3;
+ float initial_update = 4;
}
message ProximalAdagradParameters {
- optional float l1 = 1;
- optional float l2 = 2;
- optional float initial_accumulator = 3;
+ float l1 = 1;
+ float l2 = 2;
+ float initial_accumulator = 3;
}
message OptimizationParameters {
// Learning rate used for updating the embedding layer parameters.
- optional LearningRate learning_rate = 13;
+ LearningRate learning_rate = 13;
reserved 1; // Old learning rate tag.
// Limits to which to clip the weight values after the backward pass; not
// present means no limits are applied.
- optional ClippingLimits clipping_limits = 2;
+ ClippingLimits clipping_limits = 2;
// Limits to which to clip the backward pass gradient before using it for
// updates; not present means no limits are applied.
- optional ClippingLimits gradient_clipping_limits = 7;
+ ClippingLimits gradient_clipping_limits = 7;
// Whether to use gradient accumulation (do two passes over the input
// gradients: one to accumulate them into a temporary array and another to
// apply them using the actual optimization algorithm).
- optional bool use_gradient_accumulation = 15 [default = false];
+ bool use_gradient_accumulation = 15;
// Optimization algorithm parameters; which field is selected determines which
// algorithm to use.
@@ -140,7 +142,7 @@ message OptimizationParameters {
// value vector and any extra accumulators, etc.).
message StateVariableSpecification {
// Parameter name for the state variable.
- optional string name = 1;
+ string name = 1;
// A normal state variable that should be saved and restored in checkpoints
// and used as an input or output to non-debug TensorFlow ops.
@@ -151,7 +153,7 @@ message StateVariableSpecification {
// from users (used for intermediate gradients being accumulated, for
// example).
message FillWithConstant {
- optional double initial_value = 1;
+ double initial_value = 1;
}
// Usage type of this state variable.
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index 06885bbc25..92c1eaba71 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -314,7 +314,9 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
# Capture the device function stack at the time of first entry
# since that is the stack that will be used outside_compilation.
graph = ops.get_default_graph()
- self._outer_device_function_stack = list(graph._device_function_stack) # pylint: disable=protected-access
+ # pylint: disable=protected-access
+ self._outer_device_function_stack = graph._device_function_stack.copy()
+ # pylint: enable=protected-access
super(TPUReplicateContext, self).Enter()
def HostComputeCore(self):
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 7c7c97638e..ee9ad525ee 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -258,7 +258,10 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote
eval_metrics=None,
export_outputs=None,
scaffold_fn=None,
- host_call=None):
+ host_call=None,
+ training_hooks=None,
+ evaluation_hooks=None,
+ prediction_hooks=None):
"""Creates a validated `TPUEstimatorSpec` instance."""
host_calls = {}
if eval_metrics is not None:
@@ -266,6 +269,17 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote
if host_call is not None:
host_calls['host_call'] = host_call
_OutfeedHostCall.validate(host_calls)
+
+ training_hooks = list(training_hooks or [])
+ evaluation_hooks = list(evaluation_hooks or [])
+ prediction_hooks = list(prediction_hooks or [])
+
+ for hook in training_hooks + evaluation_hooks + prediction_hooks:
+ if not isinstance(hook, session_run_hook.SessionRunHook):
+ raise TypeError(
+ 'All hooks must be SessionRunHook instances, given: {}'.format(
+ hook))
+
return super(TPUEstimatorSpec, cls).__new__(
cls,
mode=mode,
@@ -275,7 +289,10 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote
eval_metrics=eval_metrics,
export_outputs=export_outputs,
scaffold_fn=scaffold_fn,
- host_call=host_call)
+ host_call=host_call,
+ training_hooks=training_hooks,
+ evaluation_hooks=evaluation_hooks,
+ prediction_hooks=prediction_hooks)
def as_estimator_spec(self):
"""Creates an equivalent `EstimatorSpec` used by CPU train/eval."""
@@ -291,6 +308,7 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote
hooks = None
if self.host_call is not None:
hooks = [_OutfeedHostCallHook(host_call_ret['host_call'])]
+ hooks = list(hooks or [])
scaffold = self.scaffold_fn() if self.scaffold_fn else None
return model_fn_lib.EstimatorSpec(
mode=self.mode,
@@ -300,9 +318,9 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote
eval_metric_ops=eval_metric_ops,
export_outputs=self.export_outputs,
scaffold=scaffold,
- training_hooks=hooks,
- evaluation_hooks=hooks,
- prediction_hooks=hooks)
+ training_hooks=self.training_hooks + hooks,
+ evaluation_hooks=self.evaluation_hooks + hooks,
+ prediction_hooks=self.prediction_hooks + hooks)
class _OpQueueContext(object):
@@ -791,7 +809,15 @@ def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder,
is_dataset = inputs.is_dataset
if ctx.mode == model_fn_lib.ModeKeys.PREDICT:
- raise TypeError('Mode PREDICT not yet supported in BROADCAST mode.')
+ if not is_dataset:
+ raise TypeError(
+ 'For mode PREDICT, `input_fn` must return `Dataset` instead of '
+ '`features` and `labels`.')
+
+ inputs = _InputsWithStoppingSignals(
+ dataset=inputs.dataset,
+ batch_size=ctx.batch_size_for_input_fn,
+ add_padding=True)
if is_dataset:
hooks.append(inputs.dataset_initializer_hook())
@@ -810,6 +836,7 @@ def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder,
"""Generates enqueue ops for all the hosts."""
broadcasted_inputs = []
flattened_inputs = None # Cache result from input_fn.
+ signals = None
for host_id in xrange(num_hosts):
with ops.device(ctx.tpu_host_placement_function(host_id=host_id)):
for _ in xrange(ctx.num_of_replicas_per_host):
@@ -819,11 +846,13 @@ def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder,
# hosts).
if flattened_inputs is None:
features, labels = inputs.features_and_labels() # Calls get_next()
+ signals = inputs.signals()
+
inputs_structure_recorder.validate_and_record_structure(
- features, labels)
+ features, labels, signals)
flattened_inputs = (
inputs_structure_recorder.flatten_features_and_labels(
- features, labels))
+ features, labels, signals))
broadcasted_inputs.append(flattened_inputs)
infeed_queue = tpu_feed.InfeedQueue(
@@ -833,7 +862,14 @@ def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder,
broadcasted_inputs,
tpu_ordinal_function=tpu_ordinal_function_impl,
placement_function=device_function_impl)
- return enqueue_ops
+
+ if signals is None:
+ return enqueue_ops
+ else:
+ return {
+ 'ops': enqueue_ops,
+ 'signals': signals,
+ }
return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset
@@ -1080,9 +1116,11 @@ class _InputPipeline(object):
all_hooks.extend(hooks)
if is_dataset:
run_infeed_loop_on_coordinator = False
- enqueue_ops.append(
- _wrap_computation_in_while_loop(
- device=host_device, op_fn=enqueue_ops_fn))
+ wrap_fn = (
+ _wrap_computation_in_while_loop
+ if self._ctx.mode != model_fn_lib.ModeKeys.PREDICT else
+ _wrap_computation_in_while_loop_with_stopping_signals)
+ enqueue_ops.append(wrap_fn(device=host_device, op_fn=enqueue_ops_fn))
else:
enqueue_ops.append(enqueue_ops_fn())
infeed_queues.append(captured_infeed_queue.get())
@@ -1200,6 +1238,7 @@ class _ModelFnWrapper(object):
host_call = _OutfeedHostCall(self._ctx)
captured_scaffold_fn = _CapturedObject()
+ captured_training_hooks = _CapturedObject()
def train_step(loss):
"""Training step function for use inside a while loop."""
@@ -1216,6 +1255,8 @@ class _ModelFnWrapper(object):
else:
captured_scaffold_fn.capture(None)
+ captured_training_hooks.capture(estimator_spec.training_hooks)
+
# We must run train_op to update the variables prior to running the
# outfeed.
with ops.control_dependencies([train_op]):
@@ -1227,7 +1268,8 @@ class _ModelFnWrapper(object):
with ops.control_dependencies(host_call_outfeed_ops):
return array_ops.identity(loss)
- return train_step, host_call, captured_scaffold_fn
+ return (train_step, host_call, captured_scaffold_fn,
+ captured_training_hooks)
def convert_to_single_tpu_eval_step(self, dequeue_fn):
"""Converts user provided model_fn` as a single eval step on TPU.
@@ -1257,6 +1299,7 @@ class _ModelFnWrapper(object):
"""
host_calls = _OutfeedHostCall(self._ctx)
captured_scaffold_fn = _CapturedObject()
+ captured_eval_hooks = _CapturedObject()
def eval_step(total_loss):
"""Evaluation step function for use inside a while loop."""
@@ -1271,6 +1314,8 @@ class _ModelFnWrapper(object):
loss = tpu_estimator_spec.loss
captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn)
+ captured_eval_hooks.capture(tpu_estimator_spec.evaluation_hooks)
+
to_record = {}
if tpu_estimator_spec.eval_metrics:
to_record['eval_metrics'] = tpu_estimator_spec.eval_metrics
@@ -1283,7 +1328,7 @@ class _ModelFnWrapper(object):
with ops.control_dependencies(host_calls.create_enqueue_op()):
return math_ops.add(total_loss, loss)
- return eval_step, host_calls, captured_scaffold_fn
+ return eval_step, host_calls, captured_scaffold_fn, captured_eval_hooks
def convert_to_single_tpu_predict_step(self, dequeue_fn):
"""Converts user provided model_fn` as a single predict step on TPU.
@@ -1298,6 +1343,7 @@ class _ModelFnWrapper(object):
"""
host_calls = _OutfeedHostCall(self._ctx)
captured_scaffold_fn = _CapturedObject()
+ captured_predict_hooks = _CapturedObject()
def predict_step(unused_scalar_stopping_signal):
"""Evaluation step function for use inside a while loop."""
@@ -1318,6 +1364,7 @@ class _ModelFnWrapper(object):
self._verify_tpu_spec_predictions(tpu_estimator_spec.predictions)
captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn)
+ captured_predict_hooks.capture(tpu_estimator_spec.prediction_hooks)
to_record = {}
identity_fn = lambda **kwargs: kwargs
to_record['predictions'] = [identity_fn, tpu_estimator_spec.predictions]
@@ -1329,7 +1376,8 @@ class _ModelFnWrapper(object):
with ops.control_dependencies(host_calls.create_enqueue_op()):
return _StopSignals.as_scalar_stopping_signal(stopping_signals)
- return predict_step, host_calls, captured_scaffold_fn
+ return (predict_step, host_calls, captured_scaffold_fn,
+ captured_predict_hooks)
def _verify_tpu_spec_predictions(self, predictions):
"""Validates TPUEstimatorSpec.predictions dict."""
@@ -1451,11 +1499,9 @@ class _ModelFnWrapper(object):
err_msg = '{} returned by EstimatorSpec is not supported in TPUEstimator.'
if estimator_spec.training_chief_hooks:
- raise ValueError(err_msg.format('training_chief_hooks'))
- if estimator_spec.training_hooks:
- raise ValueError(err_msg.format('training_hooks'))
- if estimator_spec.evaluation_hooks:
- raise ValueError(err_msg.format('evaluation_hooks'))
+ raise ValueError(
+ err_msg.format('training_chief_hooks') + 'If you want' +
+ ' to pass training hooks, please pass via training_hooks.')
if estimator_spec.scaffold:
logging.warning('EstimatorSpec.Scaffold is ignored by TPU train/eval. '
@@ -1937,10 +1983,9 @@ class TPUEstimator(estimator_lib.Estimator):
"""Constructs an `TPUEstimator` instance.
Args:
- model_fn: Model function as required by `Estimator`. For training, the
- returned `EstimatorSpec` cannot have hooks as it is not supported in
- `TPUEstimator`. Instead, the user can pass the training hooks as
- an argument to `TPUEstimator.train()`.
+ model_fn: Model function as required by `Estimator` which returns
+ EstimatorSpec or TPUEstimatorSpec. `training_hooks`, 'evaluation_hooks',
+ and `prediction_hooks` must not capure any TPU Tensor inside the model_fn.
model_dir: Directory to save model parameters, graph and etc. This can
also be used to load checkpoints from the directory into a estimator to
continue training a previously saved model. If `None`, the model_dir in
@@ -2409,7 +2454,7 @@ class TPUEstimator(estimator_lib.Estimator):
graph.add_to_collection(_TPU_ENQUEUE_OPS, enqueue_op)
if mode == model_fn_lib.ModeKeys.TRAIN:
- loss, host_call, scaffold = (
+ loss, host_call, scaffold, training_hooks = (
_train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn))
host_ops = host_call.create_tpu_hostcall()
if host_ops is None:
@@ -2464,6 +2509,9 @@ class TPUEstimator(estimator_lib.Estimator):
self._config.tpu_config.iterations_per_loop)
hooks.append(examples_hook)
+ if training_hooks:
+ hooks.extend(training_hooks)
+
chief_hooks = []
if (self._config.save_checkpoints_secs or
self._config.save_checkpoints_steps):
@@ -2475,6 +2523,7 @@ class TPUEstimator(estimator_lib.Estimator):
checkpoint_hook._set_steps_per_run( # pylint: disable=protected-access
self._config.tpu_config.iterations_per_loop)
chief_hooks.append(checkpoint_hook)
+
summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss)
with ops.control_dependencies([loss]):
update_ops = _sync_variables_ops()
@@ -2494,7 +2543,7 @@ class TPUEstimator(estimator_lib.Estimator):
scaffold=scaffold)
if mode == model_fn_lib.ModeKeys.EVAL:
- total_loss, host_calls, scaffold = _eval_on_tpu_system(
+ total_loss, host_calls, scaffold, eval_hooks = _eval_on_tpu_system(
ctx, model_fn_wrapper, dequeue_fn)
iterations_per_loop_var = _create_or_get_iterations_per_loop()
mean_loss = math_ops.div(total_loss,
@@ -2538,6 +2587,9 @@ class TPUEstimator(estimator_lib.Estimator):
rendezvous=self._rendezvous[mode]),
] + input_hooks
+ if eval_hooks:
+ hooks.extend(eval_hooks)
+
return model_fn_lib.EstimatorSpec(
mode,
loss=mean_loss,
@@ -2548,8 +2600,9 @@ class TPUEstimator(estimator_lib.Estimator):
# Predict
assert mode == model_fn_lib.ModeKeys.PREDICT
- dummy_predict_op, host_calls, scaffold = _predict_on_tpu_system(
- ctx, model_fn_wrapper, dequeue_fn)
+ (dummy_predict_op, host_calls,
+ scaffold, prediction_hooks) = _predict_on_tpu_system(
+ ctx, model_fn_wrapper, dequeue_fn)
with ops.control_dependencies([dummy_predict_op]):
internal_ops_to_run = _sync_variables_ops()
with ops.control_dependencies(internal_ops_to_run):
@@ -2605,6 +2658,9 @@ class TPUEstimator(estimator_lib.Estimator):
ctx, enqueue_ops, host_ops, rendezvous=self._rendezvous[mode]),
] + input_hooks
+ if prediction_hooks:
+ hooks.extend(prediction_hooks)
+
return model_fn_lib.EstimatorSpec(
mode,
prediction_hooks=hooks,
@@ -2688,8 +2744,8 @@ def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
"""Executes `model_fn_wrapper` multiple times on all TPU shards."""
iterations_per_loop_var = _create_or_get_iterations_per_loop()
- single_tpu_eval_step, host_calls, captured_scaffold_fn = (
- model_fn_wrapper.convert_to_single_tpu_eval_step(dequeue_fn))
+ (single_tpu_eval_step, host_calls, captured_scaffold_fn, captured_eval_hooks
+ ) = model_fn_wrapper.convert_to_single_tpu_eval_step(dequeue_fn)
def multi_tpu_eval_steps_on_single_shard():
return training_loop.repeat(
@@ -2704,15 +2760,16 @@ def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
device_assignment=ctx.device_assignment)
scaffold = _get_scaffold(captured_scaffold_fn)
- return loss, host_calls, scaffold
+ return loss, host_calls, scaffold, captured_eval_hooks.get()
def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
"""Executes `model_fn_wrapper` multiple times on all TPU shards."""
iterations_per_loop_var = _create_or_get_iterations_per_loop()
- single_tpu_train_step, host_call, captured_scaffold_fn = (
- model_fn_wrapper.convert_to_single_tpu_train_step(dequeue_fn))
+ (single_tpu_train_step, host_call, captured_scaffold_fn,
+ captured_training_hooks) = (
+ model_fn_wrapper.convert_to_single_tpu_train_step(dequeue_fn))
def multi_tpu_train_steps_on_single_shard():
return training_loop.repeat(
@@ -2727,15 +2784,16 @@ def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
device_assignment=ctx.device_assignment)
scaffold = _get_scaffold(captured_scaffold_fn)
- return loss, host_call, scaffold
+ return loss, host_call, scaffold, captured_training_hooks.get()
def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
"""Executes `model_fn_wrapper` multiple times on all TPU shards."""
num_cores = ctx.num_cores
- single_tpu_predict_step, host_calls, captured_scaffold_fn = (
- model_fn_wrapper.convert_to_single_tpu_predict_step(dequeue_fn))
+ (single_tpu_predict_step, host_calls, captured_scaffold_fn,
+ captured_predict_hooks
+ ) = model_fn_wrapper.convert_to_single_tpu_predict_step(dequeue_fn)
def multi_tpu_predict_steps_on_single_shard():
@@ -2755,7 +2813,7 @@ def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
outputs_from_all_shards=False)
scaffold = _get_scaffold(captured_scaffold_fn)
- return dummy_predict_op, host_calls, scaffold
+ return dummy_predict_op, host_calls, scaffold, captured_predict_hooks.get()
def _wrap_computation_in_while_loop(device, op_fn):
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 84555b60da..35a112e834 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -2925,6 +2925,14 @@ tf_cuda_library(
)
cc_library(
+ name = "session_ref",
+ srcs = ["common_runtime/session_ref.cc"],
+ hdrs = ["common_runtime/session_ref.h"],
+ copts = tf_copts(),
+ deps = [":core_cpu_base"],
+)
+
+cc_library(
name = "gpu_id",
hdrs = [
"common_runtime/gpu/gpu_id.h",
diff --git a/tensorflow/core/api_def/base_api/api_def_NonMaxSuppressionV4.pbtxt b/tensorflow/core/api_def/base_api/api_def_NonMaxSuppressionV4.pbtxt
new file mode 100644
index 0000000000..75df90f570
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_NonMaxSuppressionV4.pbtxt
@@ -0,0 +1,78 @@
+op {
+ graph_op_name: "NonMaxSuppressionV4"
+ in_arg {
+ name: "boxes"
+ description: <<END
+A 2-D float tensor of shape `[num_boxes, 4]`.
+END
+ }
+ in_arg {
+ name: "scores"
+ description: <<END
+A 1-D float tensor of shape `[num_boxes]` representing a single
+score corresponding to each box (each row of boxes).
+END
+ }
+ in_arg {
+ name: "max_output_size"
+ description: <<END
+A scalar integer tensor representing the maximum number of
+boxes to be selected by non max suppression.
+END
+ }
+ in_arg {
+ name: "iou_threshold"
+ description: <<END
+A 0-D float tensor representing the threshold for deciding whether
+boxes overlap too much with respect to IOU.
+END
+ }
+ in_arg {
+ name: "score_threshold"
+ description: <<END
+A 0-D float tensor representing the threshold for deciding when to remove
+boxes based on score.
+END
+ }
+ attr {
+ name: "pad_to_max_output_size"
+ description: <<END
+If true, the output `selected_indices` is padded to be of length
+`max_output_size`. Defaults to false.
+END
+ }
+ out_arg {
+ name: "selected_indices"
+ description: <<END
+A 1-D integer tensor of shape `[M]` representing the selected
+indices from the boxes tensor, where `M <= max_output_size`.
+END
+ }
+ out_arg {
+ name: "valid_outputs"
+ description: <<END
+A 0-D integer tensor representing the number of valid elements in
+`selected_indices`, with the valid elements appearing first.
+END
+ }
+ summary: "Greedily selects a subset of bounding boxes in descending order of score,"
+ description: <<END
+pruning away boxes that have high intersection-over-union (IOU) overlap
+with previously selected boxes. Bounding boxes with score less than
+`score_threshold` are removed. Bounding boxes are supplied as
+[y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any
+diagonal pair of box corners and the coordinates can be provided as normalized
+(i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm
+is agnostic to where the origin is in the coordinate system and more
+generally is invariant to orthogonal transformations and translations
+of the coordinate system; thus translating or reflections of the coordinate
+system result in the same boxes being selected by the algorithm.
+The output of this operation is a set of integers indexing into the input
+collection of bounding boxes representing the selected boxes. The bounding
+box coordinates corresponding to the selected indices can then be obtained
+using the `tf.gather operation`. For example:
+ selected_indices = tf.image.non_max_suppression_v2(
+ boxes, scores, max_output_size, iou_threshold, score_threshold)
+ selected_boxes = tf.gather(boxes, selected_indices)
+END
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_NonMaxSuppressionV4.pbtxt b/tensorflow/core/api_def/python_api/api_def_NonMaxSuppressionV4.pbtxt
new file mode 100644
index 0000000000..be6caacd00
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_NonMaxSuppressionV4.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "NonMaxSuppressionV4"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/common_runtime/broadcaster.cc b/tensorflow/core/common_runtime/broadcaster.cc
index 46142d5923..e1c6b21939 100644
--- a/tensorflow/core/common_runtime/broadcaster.cc
+++ b/tensorflow/core/common_runtime/broadcaster.cc
@@ -27,13 +27,14 @@ namespace tensorflow {
namespace {
// Key to be used for BufRendezvous by Broadcaster.
-string BroadcastBufKey(const string& exec_key, int src_rank, int dst_rank) {
+string BroadcastBufKey(const string& exec_key, int subdiv, int src_rank,
+ int dst_rank) {
if (READABLE_KEYS) {
- return strings::StrCat("broadcast(", exec_key, "):src(", src_rank, "):dst(",
- dst_rank, ")");
+ return strings::StrCat("broadcast(", exec_key, "):subdiv(", subdiv,
+ "):src(", src_rank, "):dst(", dst_rank, ")");
} else {
// TODO(tucker): Try a denser format, e.g. a 64 or 128 bit hash.
- return strings::StrCat(exec_key, ":", src_rank, ":", dst_rank);
+ return strings::StrCat(exec_key, ":", subdiv, ":", src_rank, ":", dst_rank);
}
}
} // namespace
@@ -85,11 +86,15 @@ void Broadcaster::Run(StatusCallback done) {
// device, no send to it is necessary.
/* static*/
-int Broadcaster::TreeRecvFrom(const CollectiveParams& cp) {
- DCHECK_EQ(1, cp.subdiv_rank.size());
- if (cp.is_source) return -1;
- int source_rank = cp.instance.impl_details.subdiv_source_rank[0];
- int my_rank = cp.subdiv_rank[0];
+int Broadcaster::TreeRecvFrom(const CollectiveParams& cp, int subdiv) {
+ DCHECK_LT(subdiv, static_cast<int>(cp.subdiv_rank.size()));
+ int my_rank = cp.subdiv_rank[subdiv];
+ if (-1 == my_rank) return -1;
+
+ const auto& impl = cp.instance.impl_details;
+ DCHECK_LT(subdiv, static_cast<int>(impl.subdiv_source_rank.size()));
+ int source_rank = impl.subdiv_source_rank[subdiv];
+ if (my_rank == source_rank) return -1;
if (source_rank == 0) {
return (my_rank - 1) / 2;
} else {
@@ -99,13 +104,24 @@ int Broadcaster::TreeRecvFrom(const CollectiveParams& cp) {
}
/* static */
-void Broadcaster::TreeSendTo(const CollectiveParams& cp,
+void Broadcaster::TreeSendTo(const CollectiveParams& cp, int subdiv,
std::vector<int>* targets) {
- DCHECK_EQ(1, cp.subdiv_rank.size());
+ DCHECK_LT(subdiv, static_cast<int>(cp.subdiv_rank.size()));
+ int my_rank = cp.subdiv_rank[subdiv];
+ if (-1 == my_rank) return;
+
+ const auto& impl = cp.instance.impl_details;
+ DCHECK_LT(subdiv, static_cast<int>(impl.subdiv_source_rank.size()));
+ int source_rank = impl.subdiv_source_rank[subdiv];
+
+ int group_size = 0;
+ for (int i = 0; i < impl.subdiv_permutations[subdiv].size(); i++) {
+ if (impl.subdiv_permutations[subdiv][i] >= 0) {
+ group_size++;
+ }
+ }
+
targets->clear();
- int my_rank = cp.subdiv_rank[0];
- DCHECK_EQ(1, cp.instance.impl_details.subdiv_source_rank.size());
- int source_rank = cp.instance.impl_details.subdiv_source_rank[0];
int successor_rank = 0;
if (source_rank == 0) {
successor_rank = (2 * my_rank) + 1;
@@ -116,108 +132,147 @@ void Broadcaster::TreeSendTo(const CollectiveParams& cp,
if (cp.is_source && source_rank != 0) {
// The source sends to rank 0,1 in addition to its positional
// descendants.
- if (cp.group.group_size > 1) {
+ if (group_size > 1) {
targets->push_back(0);
}
- if (cp.group.group_size > 2 && source_rank != 1) {
+ if (group_size > 2 && source_rank != 1) {
targets->push_back(1);
}
}
for (int i = 0; i < 2; ++i) {
- if (successor_rank < cp.group.group_size && successor_rank != source_rank) {
+ if (successor_rank < group_size && successor_rank != source_rank) {
targets->push_back(successor_rank);
}
++successor_rank;
}
}
-// Execute a tree broadcast, i.e. each non-source device receives from
-// one other and sends to up-to two others.
+// Executes a hierarchical tree broadcast.
+// Each subdiv is a broadcast between a subset of the devices.
+// If there is only one task, there is one subdiv comprising a broadcast between
+// all devices belonging to the task.
+// If there are n tasks, n>1, then there are n+1 subdivs. In the first (global)
+// subdiv, one device from each task participates in a binary tree broadcast.
+// Each task receives a copy of the tensor on one device via this broadcast.
+// Subsequent subdivs correspond to intra-task broadcasts. Subdiv i+1
+// corresponds to broadcast between all devices on task i. Thus, each task
+// participates in at most 2 subdivs.
void Broadcaster::RunTree() {
- mutex mu; // also guards status_ while callbacks are pending
- int pending_count = 0; // GUARDED_BY(mu)
- condition_variable all_done;
- std::vector<int> send_to_ranks;
- TreeSendTo(col_params_, &send_to_ranks);
-
- if (!is_source_) {
- // Begin by receiving the value.
- int recv_from_rank = TreeRecvFrom(col_params_);
- Notification note;
- DispatchRecv(recv_from_rank, output_,
- [this, recv_from_rank, &mu, &note](const Status& s) {
- mutex_lock l(mu);
- status_.Update(s);
- note.Notify();
- });
- note.WaitForNotification();
- }
+ int num_subdivs = static_cast<int>(col_params_.subdiv_rank.size());
+ // TODO(ayushd): this is easily improved when a node participates in both
+ // first and second subdivision. It would first send to its descendents in
+ // the first subdiv, then wait until all pending ops are finished before
+ // sending to descendents in second subdiv. A better implementation would
+ // collapse the two send blocks.
+ for (int si = 0; si < num_subdivs; si++) {
+ int my_rank = col_params_.subdiv_rank[si];
+ // If rank is -1, this device does not participate in this subdiv.
+ if (-1 == my_rank) continue;
+ int source_rank = col_params_.instance.impl_details.subdiv_source_rank[si];
+ if (VLOG_IS_ON(1)) {
+ string subdiv_buf;
+ for (int r : col_params_.instance.impl_details.subdiv_permutations[si]) {
+ strings::StrAppend(&subdiv_buf, r, ",");
+ }
+ VLOG(1) << "Running Broadcast tree device=" << device_->name()
+ << " subdiv=" << si << " perm=" << subdiv_buf
+ << " my_rank=" << my_rank << " source_rank=" << source_rank;
+ }
+
+ mutex mu; // also guards status_ while callbacks are pending
+ int pending_count = 0; // GUARDED_BY(mu)
+ condition_variable all_done;
- // Then forward value to all descendent devices.
- if (status_.ok()) {
- for (int i = 0; i < send_to_ranks.size(); ++i) {
- int target_rank = send_to_ranks[i];
- {
- mutex_lock l(mu);
- ++pending_count;
+ if (my_rank >= 0 && my_rank != source_rank) {
+ // Begin by receiving the value.
+ int recv_from_rank = TreeRecvFrom(col_params_, si);
+ Notification note;
+ DispatchRecv(si, recv_from_rank, my_rank, output_,
+ [this, &mu, &note](const Status& s) {
+ mutex_lock l(mu);
+ status_.Update(s);
+ note.Notify();
+ });
+ note.WaitForNotification();
+ }
+
+ // Then forward value to all descendent devices.
+ if (my_rank >= 0 && status_.ok()) {
+ std::vector<int> send_to_ranks;
+ TreeSendTo(col_params_, si, &send_to_ranks);
+ for (int i = 0; i < send_to_ranks.size(); ++i) {
+ int target_rank = send_to_ranks[i];
+ {
+ mutex_lock l(mu);
+ ++pending_count;
+ }
+ DispatchSend(si, target_rank, my_rank,
+ (is_source_ ? &ctx_->input(0) : output_),
+ [this, &mu, &pending_count, &all_done](const Status& s) {
+ mutex_lock l(mu);
+ status_.Update(s);
+ --pending_count;
+ if (pending_count == 0) {
+ all_done.notify_all();
+ }
+ });
}
- DispatchSend(
- target_rank, (is_source_ ? &ctx_->input(0) : output_),
- [this, target_rank, &mu, &pending_count, &all_done](const Status& s) {
- mutex_lock l(mu);
- status_.Update(s);
- --pending_count;
- if (pending_count == 0) {
- all_done.notify_all();
- }
- });
}
- }
- if (status_.ok() && is_source_) {
- // Meanwhile, copy input to output if we weren't lucky enough to
- // be able to reuse input as output.
- const Tensor* input = &ctx_->input(0);
- if (input != output_ &&
- (DMAHelper::base(input) != DMAHelper::base(output_))) {
- {
- mutex_lock l(mu);
- ++pending_count;
+ // For the original source device, we copy input to output if they are
+ // different.
+ // If there is only 1 subdiv, we do this in that subdiv. If there is more
+ // than 1 subdiv, then the original source device will participate in 2
+ // subdivs - the global inter-task broadcast and one local intra-task
+ // broadcast. In this case, we perform the copy in the second subdiv for
+ // this device.
+ if (status_.ok() && is_source_ && (1 == num_subdivs || 0 != si)) {
+ VLOG(2) << "copying input to output for device=" << device_->name()
+ << " subdiv=" << si;
+ const Tensor* input = &ctx_->input(0);
+ if (input != output_ &&
+ (DMAHelper::base(input) != DMAHelper::base(output_))) {
+ {
+ mutex_lock l(mu);
+ ++pending_count;
+ }
+ DeviceContext* op_dev_ctx = ctx_->op_device_context();
+ CollectiveRemoteAccessLocal::MemCpyAsync(
+ op_dev_ctx, op_dev_ctx, device_, device_, ctx_->input_alloc_attr(0),
+ ctx_->output_alloc_attr(0), input, output_, 0, /*stream_index*/
+ [this, &mu, &pending_count, &all_done](const Status& s) {
+ mutex_lock l(mu);
+ status_.Update(s);
+ --pending_count;
+ if (0 == pending_count) {
+ all_done.notify_all();
+ }
+ });
}
- DeviceContext* op_dev_ctx = ctx_->op_device_context();
- CollectiveRemoteAccessLocal::MemCpyAsync(
- op_dev_ctx, op_dev_ctx, device_, device_, ctx_->input_alloc_attr(0),
- ctx_->output_alloc_attr(0), input, output_, 0 /*steam_index*/,
- [this, &mu, &pending_count, &all_done](const Status& s) {
- mutex_lock l(mu);
- status_.Update(s);
- --pending_count;
- if (0 == pending_count) {
- all_done.notify_all();
- }
- });
}
- }
- // Then wait for all pending actions to complete.
- {
- mutex_lock l(mu);
- if (pending_count > 0) {
- all_done.wait(l);
+ // Then wait for all pending actions to complete.
+ {
+ mutex_lock l(mu);
+ if (pending_count > 0) {
+ all_done.wait(l);
+ }
}
}
-
- VLOG(2) << "return status " << status_;
+ VLOG(2) << "device=" << device_->name() << " return status " << status_;
done_(status_);
}
-void Broadcaster::DispatchSend(int dst_rank, const Tensor* src_tensor,
+void Broadcaster::DispatchSend(int subdiv, int dst_rank, int src_rank,
+ const Tensor* src_tensor,
const StatusCallback& done) {
- string send_buf_key = BroadcastBufKey(exec_key_, rank_, dst_rank);
- VLOG(1) << "DispatchSend " << send_buf_key << " from_device "
- << device_->name();
+ string send_buf_key = BroadcastBufKey(exec_key_, subdiv, src_rank, dst_rank);
int dst_idx =
- col_params_.instance.impl_details.subdiv_permutations[0][dst_rank];
+ col_params_.instance.impl_details.subdiv_permutations[subdiv][dst_rank];
+ VLOG(1) << "DispatchSend " << send_buf_key << " from_device "
+ << device_->name() << " to_device "
+ << col_params_.instance.device_names[dst_idx] << " subdiv=" << subdiv
+ << " dst_rank=" << dst_rank << " dst_idx=" << dst_idx;
col_exec_->PostToPeer(col_params_.instance.device_names[dst_idx],
col_params_.instance.task_names[dst_idx], send_buf_key,
device_, ctx_->op_device_context(),
@@ -225,15 +280,15 @@ void Broadcaster::DispatchSend(int dst_rank, const Tensor* src_tensor,
device_locality_, done);
}
-void Broadcaster::DispatchRecv(int src_rank, Tensor* dst_tensor,
- const StatusCallback& done) {
- string recv_buf_key = BroadcastBufKey(exec_key_, src_rank, rank_);
+void Broadcaster::DispatchRecv(int subdiv, int src_rank, int dst_rank,
+ Tensor* dst_tensor, const StatusCallback& done) {
+ string recv_buf_key = BroadcastBufKey(exec_key_, subdiv, src_rank, dst_rank);
int src_idx =
- col_params_.instance.impl_details.subdiv_permutations[0][src_rank];
+ col_params_.instance.impl_details.subdiv_permutations[subdiv][src_rank];
VLOG(1) << "DispatchRecv " << recv_buf_key << " from_device "
- << col_params_.instance.device_names[src_idx];
- int dst_idx = col_params_.instance.impl_details.subdiv_permutations[0][rank_];
- CHECK_EQ(col_params_.instance.device_names[dst_idx], device_->name());
+ << col_params_.instance.device_names[src_idx] << " to_device "
+ << device_->name() << " subdiv=" << subdiv << " src_rank=" << src_rank
+ << " src_idx=" << src_idx;
col_exec_->RecvFromPeer(col_params_.instance.device_names[src_idx],
col_params_.instance.task_names[src_idx],
col_params_.task.is_local[src_idx], recv_buf_key,
diff --git a/tensorflow/core/common_runtime/broadcaster.h b/tensorflow/core/common_runtime/broadcaster.h
index bdf68f19ab..799228b161 100644
--- a/tensorflow/core/common_runtime/broadcaster.h
+++ b/tensorflow/core/common_runtime/broadcaster.h
@@ -34,17 +34,24 @@ class Broadcaster {
// Returns the rank of the device from which this device should receive
// its value, -1 if no value should be received.
- static int TreeRecvFrom(const CollectiveParams& cp);
+ static int TreeRecvFrom(const CollectiveParams& cp, int subdiv);
// Populates targets with the ranks of the devices to which this device
// should forward the value.
- static void TreeSendTo(const CollectiveParams& cp, std::vector<int>* targets);
+ static void TreeSendTo(const CollectiveParams& cp, int subdiv,
+ std::vector<int>* targets);
private:
- void DispatchSend(int dst_rank, const Tensor* src_tensor,
- const StatusCallback& done);
- void DispatchRecv(int src_rank, Tensor* dst_tensor,
+ // Sends `src_tensor` asynchronously from this device to device at `dst_rank`
+ // in `subdiv`. Calls `done` upon completion.
+ void DispatchSend(int subdiv, int dst_rank, int src_rank,
+ const Tensor* src_tensor, const StatusCallback& done);
+ // Receives a tensor into the memory buffer owned by `dst_tensor` at this
+ // device from device at `src_rank` in `subdiv`. Calls `done` upon
+ // completion.
+ void DispatchRecv(int subdiv, int src_rank, int dst_rank, Tensor* dst_tensor,
const StatusCallback& done);
+ // Executes the hierarchical broadcast defined by this op.
void RunTree();
Status status_;
diff --git a/tensorflow/core/common_runtime/broadcaster_test.cc b/tensorflow/core/common_runtime/broadcaster_test.cc
index 6a163a0db0..3960fc6c97 100644
--- a/tensorflow/core/common_runtime/broadcaster_test.cc
+++ b/tensorflow/core/common_runtime/broadcaster_test.cc
@@ -38,7 +38,6 @@ namespace tensorflow {
namespace {
static int64 kStepId = 123;
-static int32 kNumSubdivs = 1; // Subdiv not yet meaningful for broadcast
// The test harness won't allow a mixture of fixture and non-fixture
// tests in one file, so this is a trival fixture for tests that don't
@@ -59,12 +58,14 @@ class TrivialTest : public ::testing::Test {
CollectiveParams cp; \
cp.group.group_size = D; \
cp.instance.impl_details.subdiv_source_rank = {S}; \
+ cp.instance.impl_details.subdiv_permutations.push_back( \
+ std::vector<int>(D, 0)); \
cp.subdiv_rank = {R}; \
cp.is_source = (S == R); \
- EXPECT_EQ(RF, Broadcaster::TreeRecvFrom(cp)); \
+ EXPECT_EQ(RF, Broadcaster::TreeRecvFrom(cp, 0)); \
std::vector<int> expected = ST; \
std::vector<int> send_to; \
- Broadcaster::TreeSendTo(cp, &send_to); \
+ Broadcaster::TreeSendTo(cp, 0, &send_to); \
ASSERT_EQ(expected.size(), send_to.size()); \
for (int i = 0; i < expected.size(); ++i) { \
EXPECT_EQ(expected[i], send_to[i]); \
@@ -209,8 +210,11 @@ class BroadcasterTest : public ::testing::Test {
#endif
}
- void Init(int num_workers, int num_devices, DataType dtype,
+ void Init(int num_workers, int num_devices_per_worker, DataType dtype,
const DeviceType& device_type, int fail_after) {
+ VLOG(2) << "num_workers=" << num_workers
+ << " num_devices_per_worker=" << num_devices_per_worker;
+ int total_num_devices = num_workers * num_devices_per_worker;
device_type_ = device_type;
std::vector<Device*> local_devices;
SessionOptions sess_opts;
@@ -218,14 +222,14 @@ class BroadcasterTest : public ::testing::Test {
Bytes mem_limit(4 << 20);
DeviceLocality dev_locality;
for (int wi = 0; wi < num_workers; ++wi) {
- for (int di = 0; di < num_devices; ++di) {
+ for (int di = 0; di < num_devices_per_worker; ++di) {
if (device_type == DEVICE_CPU) {
string dev_name = strings::StrCat("/job:worker/replica:0/task:", wi,
"/device:CPU:", di);
local_devices.push_back(new ThreadPoolDevice(
sess_opts, dev_name, mem_limit, dev_locality, cpu_allocator()));
} else if (device_type == DEVICE_GPU && !gpu_devices_.empty()) {
- int dev_idx = (wi * num_devices) + di;
+ int dev_idx = (wi * num_devices_per_worker) + di;
if (dev_idx >= static_cast<int>(gpu_devices_.size())) {
LOG(INFO) << "dev_mgr has access to limited GPUs, reusing for more "
"than one ring node.";
@@ -247,67 +251,86 @@ class BroadcasterTest : public ::testing::Test {
dev_mgr_.get());
col_params_.name = "test_collective";
col_params_.instance.data_type = dtype;
- static const int kGroupKey = 5;
+ static const int kGroupKey = 6;
col_params_.group.group_key = kGroupKey;
- static const int kInstanceKey = 17;
+ static const int kInstanceKey = 18;
col_params_.instance.instance_key = kInstanceKey;
col_params_.group.device_type = device_type;
- col_params_.group.group_size = num_workers * num_devices;
+ col_params_.group.group_size = num_workers * num_devices_per_worker;
col_params_.instance.impl_details.subdiv_offsets.clear();
col_params_.instance.type = BROADCAST_COLLECTIVE;
- col_params_.instance.impl_details.subdiv_permutations.resize(kNumSubdivs);
- col_params_.subdiv_rank.resize(kNumSubdivs);
- int subdiv_stride = num_devices / kNumSubdivs;
- for (int sdi = 0; sdi < kNumSubdivs; ++sdi) {
- col_params_.instance.impl_details.subdiv_offsets.push_back(sdi *
- subdiv_stride);
- col_params_.subdiv_rank[sdi] = sdi * subdiv_stride;
- }
- // Set up a local device ring order that's not just 0,1,2...
- std::vector<int> local_ring_order;
- for (int di = 0; di < num_devices; ++di) {
- local_ring_order.push_back(di);
+ int num_subdivs = num_workers + (num_workers > 1 ? 1 : 0);
+ VLOG(2) << "#subdiv=" << num_subdivs;
+ col_params_.instance.impl_details.subdiv_permutations.resize(num_subdivs);
+ col_params_.subdiv_rank.resize(num_subdivs);
+
+ // Inter-machine broadcast.
+ int subdiv_i = 0;
+ if (num_workers > 1) {
+ col_params_.instance.impl_details.subdiv_permutations[subdiv_i].resize(
+ total_num_devices, -1);
+ for (int i = 0, rank = 0; i < total_num_devices; i++) {
+ if (i % num_devices_per_worker == 0) {
+ col_params_.instance.impl_details
+ .subdiv_permutations[subdiv_i][rank] = i;
+ rank++;
+ }
+ }
+ if (VLOG_IS_ON(2)) {
+ string sp_buf;
+ for (int p :
+ col_params_.instance.impl_details.subdiv_permutations[subdiv_i])
+ strings::StrAppend(&sp_buf, p, ", ");
+ VLOG(2) << "subdiv_i=" << subdiv_i << " perm=" << sp_buf;
+ }
+ subdiv_i++;
}
- for (int di = 0; di < num_devices; ++di) {
- bool is_odd = ((di % 2) == 1);
- int other = (di + (is_odd ? 7 : 3)) % num_devices;
- if (di == other) continue;
- iter_swap(local_ring_order.begin() + di,
- local_ring_order.begin() + other);
+ // Intra-machine broadcast.
+ for (int i = 0; subdiv_i < num_subdivs; i++, subdiv_i++) {
+ col_params_.instance.impl_details.subdiv_permutations[subdiv_i].resize(
+ total_num_devices, -1);
+ int perm_i_base = i * num_devices_per_worker;
+ VLOG(2) << "subdiv_i=" << subdiv_i << " i=" << i
+ << " perm_i_base=" << perm_i_base << " subdiv_perms.size="
+ << col_params_.instance.impl_details.subdiv_permutations.size();
+ // subdiv for worker i.
+ for (int j = perm_i_base, rank = 0;
+ j < perm_i_base + num_devices_per_worker; j++, rank++) {
+ col_params_.instance.impl_details.subdiv_permutations[subdiv_i][rank] =
+ j;
+ }
+ if (VLOG_IS_ON(2)) {
+ string sp_buf;
+ for (int p :
+ col_params_.instance.impl_details.subdiv_permutations[subdiv_i])
+ strings::StrAppend(&sp_buf, p, ", ");
+ VLOG(2) << "subdiv_i=" << subdiv_i << " perm=" << sp_buf;
+ }
}
- broadcast_dev_id_ = local_ring_order[0];
- string lro_buf;
- for (auto d : local_ring_order) strings::StrAppend(&lro_buf, d, ", ");
- VLOG(1) << "local_ring_order " << lro_buf;
- // Set up all of the fake device contexts.
- for (int wi = 0; wi < num_workers; ++wi) {
- for (int di = 0; di < num_devices; ++di) {
+ // Set up all the fake device contexts.
+ for (int wi = 0; wi < num_workers; wi++) {
+ for (int di = 0; di < num_devices_per_worker; di++) {
string task_name = strings::StrCat("/job:worker/replica:0/task:", wi);
- string dev_name = strings::StrCat(task_name, "/device:CPU:", di);
+ string dev_name;
if (device_type == DEVICE_GPU) {
dev_name = strings::StrCat(task_name, "/device:GPU:0");
+ } else {
+ dev_name = strings::StrCat(task_name, "/device:CPU:", di);
}
+ VLOG(2) << "dev=" << dev_name;
col_params_.instance.device_names.push_back(dev_name);
col_params_.instance.task_names.push_back(task_name);
- // Normally each device would set is_local to its own perspective but
- // this test runs in a single process so is_local is always true.
col_params_.task.is_local.push_back(true);
- for (int sdi = 0; sdi < kNumSubdivs; ++sdi) {
- int rotated_di =
- (di + col_params_.instance.impl_details.subdiv_offsets[sdi]) %
- num_devices;
- col_params_.instance.impl_details.subdiv_permutations[sdi].push_back(
- wi * num_devices + local_ring_order[rotated_di]);
- }
}
}
- for (int wi = 0; wi < num_workers; ++wi) {
- for (int di = 0; di < num_devices; ++di) {
- int rank = wi * num_devices + di;
+ for (int wi = 0; wi < num_workers; wi++) {
+ for (int di = 0; di < num_devices_per_worker; di++) {
+ int default_rank = wi * num_devices_per_worker + di;
instances_.push_back(new DeviceInstance(
- rank, col_params_.instance.device_names[rank], device_type_, this));
+ default_rank, col_params_.instance.device_names[default_rank],
+ device_type, this));
}
}
}
@@ -315,6 +338,7 @@ class BroadcasterTest : public ::testing::Test {
typedef std::function<void(Tensor*)> InitFunc;
void Broadcast(bool forward_input) {
+ VLOG(2) << "#instances=" << instances_.size();
std::atomic<int> done(0);
for (auto di : instances_) {
SchedClosure([di, forward_input, &done] {
@@ -516,39 +540,29 @@ class BroadcasterTest : public ::testing::Test {
CHECK_EQ(group_size, col_params_.instance.device_names.size());
// Default rank is order in device_names.
col_params_.default_rank = rank;
- // perm_rank is order in subdiv[0]:
- int perm_rank = -1;
- for (int i = 0;
- i < col_params_.instance.impl_details.subdiv_permutations[0].size();
- ++i) {
- if (rank ==
- col_params_.instance.impl_details.subdiv_permutations[0][i]) {
- perm_rank = i;
- break;
- }
- }
- CHECK_GE(perm_rank, 0);
- col_params_.instance.impl_details.subdiv_source_rank.resize(1, 0);
- col_params_.is_source =
- (perm_rank ==
- col_params_.instance.impl_details.subdiv_source_rank[0]);
- // Set rank in all subdivs by finding that default_rank.
- for (int sdi = 0; sdi < kNumSubdivs; ++sdi) {
- for (int r = 0;
- r <
- col_params_.instance.impl_details.subdiv_permutations[sdi].size();
- ++r) {
- if (col_params_.default_rank ==
- col_params_.instance.impl_details.subdiv_permutations[sdi][r]) {
- col_params_.subdiv_rank[sdi] = r;
- CHECK_EQ(0, sdi);
- CHECK_EQ(perm_rank, col_params_.subdiv_rank[sdi]);
+
+ auto& impl = col_params_.instance.impl_details;
+ size_t num_subdivs = impl.subdiv_permutations.size();
+ impl.subdiv_source_rank.resize(num_subdivs, 0);
+ col_params_.subdiv_rank.resize(num_subdivs);
+ for (size_t si = 0; si < num_subdivs; si++) {
+ int perm_rank = -1;
+ for (int i = 0; i < group_size; i++) {
+ if (rank == impl.subdiv_permutations[si][i]) {
+ perm_rank = i;
break;
}
}
+ col_params_.subdiv_rank[si] = perm_rank;
+ }
+ string rank_buf;
+ for (int r : col_params_.subdiv_rank) {
+ strings::StrAppend(&rank_buf, r, ", ");
}
- CHECK_EQ(group_size, col_params_.task.is_local.size());
- CHECK_EQ(group_size, col_params_.instance.task_names.size());
+ VLOG(1) << "default=" << rank << " subdiv_ranks=" << rank_buf;
+
+ col_params_.is_source =
+ col_params_.subdiv_rank[0] == impl.subdiv_source_rank[0];
}
void InitTensor(DataType dtype, const TensorShape& shape,
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
index 236f999228..2a14493a67 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
@@ -319,6 +319,97 @@ void SortDevicesAndTasks(CollectiveParams* cp) {
}
} // namespace
+int GetDeviceTask(int device_rank, const std::vector<int>& dev_per_task) {
+ int num_tasks = static_cast<int>(dev_per_task.size());
+ int task_lo = 0;
+ int task_hi;
+ for (int ti = 0; ti < num_tasks; ti++) {
+ task_hi = task_lo + dev_per_task[ti];
+ if (task_lo <= device_rank && device_rank < task_hi) return ti;
+ task_lo += dev_per_task[ti];
+ }
+ LOG(FATAL) << "Unexpected device rank " << device_rank << " for " << task_hi
+ << " devices";
+ return -1;
+}
+
+void CollectiveParamResolverLocal::GenerateBcastSubdivPerms(
+ const string& device, int source_rank, const std::vector<int>& dev_per_task,
+ CollectiveParams* cp) {
+ if (VLOG_IS_ON(1)) {
+ string dpt_buf;
+ for (int dpt : dev_per_task) strings::StrAppend(&dpt_buf, dpt, ";");
+ VLOG(1) << "GenerateBcastSubdivPerms device=" << device
+ << " source_rank=" << source_rank << " dev_per_task=" << dpt_buf;
+ }
+ int num_tasks = cp->group.num_tasks;
+ // If there is just 1 task, then execute binary tree broadcast over all
+ // devices. Otherwise, the first subdiv is inter-task broadcast, and then
+ // there are N more subdivs, where N is #task.
+ int num_subdivs = num_tasks + (num_tasks > 1 ? 1 : 0);
+ int total_num_devices = 0;
+ for (int num_dev : dev_per_task) total_num_devices += num_dev;
+
+ cp->instance.impl_details.subdiv_permutations.resize(num_subdivs);
+ cp->subdiv_rank.reserve(num_subdivs);
+ cp->instance.impl_details.subdiv_source_rank.reserve(num_subdivs);
+
+ // Inter-task subdiv. Pick one device from each task - this is the source
+ // device if it belongs to that task, or device 0 for that task. If a device
+ // does not participate in the subdiv, set subdiv_rank to -1.
+ if (num_tasks > 1) {
+ const int sdi = 0;
+ std::vector<int>& perm = cp->instance.impl_details.subdiv_permutations[sdi];
+ CHECK_EQ(perm.size(), 0);
+ int device_count = 0;
+ int source_task = GetDeviceTask(source_rank, dev_per_task);
+ for (int ti = 0; ti < cp->group.num_tasks; ti++) {
+ bool participate = false;
+ if (source_task == ti) {
+ // Source device belongs to this task.
+ perm.push_back(source_rank);
+ participate = cp->instance.device_names[source_rank] == device;
+ } else {
+ // Source does not belong to this task, choose dev 0.
+ perm.push_back(device_count);
+ participate = cp->instance.device_names[device_count] == device;
+ }
+ if (participate) cp->subdiv_rank.push_back(ti);
+ device_count += dev_per_task[ti];
+ }
+ if (cp->subdiv_rank.empty()) cp->subdiv_rank.push_back(-1);
+ cp->instance.impl_details.subdiv_source_rank.push_back(source_task);
+ }
+
+ // Intra-task subdivs. Pick all devices in task ti for subdiv sdi. Set
+ // source to dev 0 for that task if it does not contain original source, else
+ // set to rank of original source. If a device does not participate in the
+ // subdiv, set subdiv_rank to -1;
+ int abs_di = 0;
+ for (int ti = 0; ti < cp->group.num_tasks; ti++) {
+ const int sdi = ti + (num_tasks > 1 ? 1 : 0);
+ std::vector<int>& perm = cp->instance.impl_details.subdiv_permutations[sdi];
+ CHECK_EQ(perm.size(), 0);
+ bool participate = false;
+ int subdiv_source = 0;
+ for (int di = 0; di < dev_per_task[ti]; di++) {
+ perm.push_back(abs_di);
+ if (cp->instance.device_names[abs_di] == device) {
+ participate = true;
+ cp->subdiv_rank.push_back(di);
+ }
+ if (abs_di == source_rank) subdiv_source = di;
+ abs_di++;
+ }
+ if (!participate) cp->subdiv_rank.push_back(-1);
+ cp->instance.impl_details.subdiv_source_rank.push_back(subdiv_source);
+ }
+
+ for (int sri = 0; sri < num_subdivs; sri++) {
+ CHECK_GE(cp->instance.impl_details.subdiv_source_rank[sri], 0);
+ }
+}
+
// Establish the requested number of subdivision permutations based on the
// ring order implicit in the device order.
/*static*/
@@ -351,61 +442,51 @@ void CollectiveParamResolverLocal::GenerateSubdivPerms(const string& device,
dev_per_task.push_back(dev_count);
CHECK_EQ(cp->group.num_tasks, dev_per_task.size());
- // Generate a ring permutation for each requested offset.
- CHECK_GT(cp->instance.impl_details.subdiv_offsets.size(), 0);
- VLOG(2) << "Setting up perms for cp " << cp << " subdiv_permutations "
- << &cp->instance.impl_details.subdiv_permutations;
- cp->instance.impl_details.subdiv_permutations.resize(
- cp->instance.impl_details.subdiv_offsets.size());
- cp->subdiv_rank.resize(cp->instance.impl_details.subdiv_offsets.size(), -1);
- for (int sdi = 0; sdi < cp->instance.impl_details.subdiv_offsets.size();
- ++sdi) {
- std::vector<int>& perm = cp->instance.impl_details.subdiv_permutations[sdi];
- CHECK_EQ(perm.size(), 0);
- int offset = cp->instance.impl_details.subdiv_offsets[sdi];
- // A negative subdivision offset is interpreted as follows:
- // 1. Reverse the local device ordering.
- // 2. Begin the subdivision at abs(offset) in the reversed ordering.
- bool reverse = false;
- if (offset < 0) {
- offset = abs(offset);
- reverse = true;
- }
- int prior_dev_count = 0; // sum over prior worker device counts
- for (int ti = 0; ti < cp->group.num_tasks; ++ti) {
- for (int di = 0; di < dev_per_task[ti]; ++di) {
- int di_offset = (di + offset) % dev_per_task[ti];
- int offset_di =
- reverse ? (dev_per_task[ti] - (di_offset + 1)) : di_offset;
- // Device index in global subdivision permutation.
- int permuted_di = prior_dev_count + offset_di;
- int rank = static_cast<int>(perm.size());
- perm.push_back(permuted_di);
- if (cp->instance.device_names[permuted_di] == device) {
- CHECK_EQ(permuted_di, cp->default_rank);
- cp->subdiv_rank[sdi] = rank;
- }
- }
- prior_dev_count += dev_per_task[ti];
- }
- CHECK_EQ(cp->group.group_size, perm.size());
- }
-
- if (cp->instance.type == BROADCAST_COLLECTIVE) {
- CHECK_GE(source_rank, 0);
- cp->instance.impl_details.subdiv_source_rank.resize(
- cp->instance.impl_details.subdiv_offsets.size(), -1);
- for (int sdi = 0; sdi < cp->instance.impl_details.subdiv_source_rank.size();
+ CHECK(cp->instance.type == REDUCTION_COLLECTIVE ||
+ cp->instance.type == BROADCAST_COLLECTIVE);
+ if (cp->instance.type == REDUCTION_COLLECTIVE) {
+ // Generate a ring permutation for each requested offset.
+ CHECK_GT(cp->instance.impl_details.subdiv_offsets.size(), 0);
+ VLOG(2) << "Setting up perms for cp " << cp << " subdiv_permutations "
+ << &cp->instance.impl_details.subdiv_permutations;
+ cp->instance.impl_details.subdiv_permutations.resize(
+ cp->instance.impl_details.subdiv_offsets.size());
+ cp->subdiv_rank.resize(cp->instance.impl_details.subdiv_offsets.size(), -1);
+ for (int sdi = 0; sdi < cp->instance.impl_details.subdiv_offsets.size();
++sdi) {
- for (int j = 0; j < cp->group.group_size; ++j) {
- if (cp->instance.impl_details.subdiv_permutations[sdi][j] ==
- source_rank) {
- cp->instance.impl_details.subdiv_source_rank[sdi] = j;
- break;
+ std::vector<int>& perm =
+ cp->instance.impl_details.subdiv_permutations[sdi];
+ CHECK_EQ(perm.size(), 0);
+ int offset = cp->instance.impl_details.subdiv_offsets[sdi];
+ // A negative subdivision offset is interpreted as follows:
+ // 1. Reverse the local device ordering.
+ // 2. Begin the subdivision at abs(offset) in the reversed ordering.
+ bool reverse = false;
+ if (offset < 0) {
+ offset = abs(offset);
+ reverse = true;
+ }
+ int prior_dev_count = 0; // sum over prior worker device counts
+ for (int ti = 0; ti < cp->group.num_tasks; ++ti) {
+ for (int di = 0; di < dev_per_task[ti]; ++di) {
+ int di_offset = (di + offset) % dev_per_task[ti];
+ int offset_di =
+ reverse ? (dev_per_task[ti] - (di_offset + 1)) : di_offset;
+ // Device index in global subdivision permutation.
+ int permuted_di = prior_dev_count + offset_di;
+ int rank = static_cast<int>(perm.size());
+ perm.push_back(permuted_di);
+ if (cp->instance.device_names[permuted_di] == device) {
+ CHECK_EQ(permuted_di, cp->default_rank);
+ cp->subdiv_rank[sdi] = rank;
+ }
}
+ prior_dev_count += dev_per_task[ti];
}
- CHECK_GE(cp->instance.impl_details.subdiv_source_rank[sdi], 0);
+ CHECK_EQ(cp->group.group_size, perm.size());
}
+ } else if (cp->instance.type == BROADCAST_COLLECTIVE) {
+ GenerateBcastSubdivPerms(device, source_rank, dev_per_task, cp);
}
if (VLOG_IS_ON(1)) {
@@ -418,13 +499,21 @@ void CollectiveParamResolverLocal::GenerateSubdivPerms(const string& device,
di < cp->instance.impl_details.subdiv_permutations[sdi].size();
++di) {
int idx = cp->instance.impl_details.subdiv_permutations[sdi][di];
- strings::StrAppend(&buf, cp->instance.device_names[idx], "\n");
+ if (idx >= 0) {
+ CHECK_GT(cp->instance.device_names.size(), idx);
+ strings::StrAppend(&buf, cp->instance.device_names[idx], "\n");
+ }
}
strings::StrAppend(&buf, " subdiv_offsets: ");
for (auto o : cp->instance.impl_details.subdiv_offsets)
strings::StrAppend(&buf, o, " ");
strings::StrAppend(&buf, " SubdivRank: ");
for (auto d : cp->subdiv_rank) strings::StrAppend(&buf, d, " ");
+ if (cp->instance.type == BROADCAST_COLLECTIVE) {
+ strings::StrAppend(&buf, " subdiv_source_rank: ");
+ for (auto src : cp->instance.impl_details.subdiv_source_rank)
+ strings::StrAppend(&buf, src, " ");
+ }
VLOG(1) << buf;
}
}
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.h b/tensorflow/core/common_runtime/collective_param_resolver_local.h
index 01bdeca7d1..2e2aa801d9 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.h
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.h
@@ -213,8 +213,16 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
LOCKS_EXCLUDED(irec->out_mu);
friend class CollectiveParamResolverLocalTest;
+ // Establishes the requested number of subdivision permutations based on the
+ // ring order implicit in the device order.
static void GenerateSubdivPerms(const string& device, int source_rank,
CollectiveParams* cp);
+ // Establishes the subdivisions for broadcast op. The first subdiv executes
+ // binary tree bcast with one device per task. Each subsequent subdiv
+ // executes intra-task binary tree broadcast.
+ static void GenerateBcastSubdivPerms(const string& device, int source_rank,
+ const std::vector<int>& dev_per_task,
+ CollectiveParams* cp);
const DeviceMgr* dev_mgr_;
DeviceResolverInterface* dev_resolver_; // Not owned.
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
index d5be8f927e..9ea23b72d2 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
@@ -49,6 +49,26 @@ class CollectiveParamResolverLocalTest : public ::testing::Test {
CollectiveParamResolverLocal::GenerateSubdivPerms(device, source_rank, cp);
}
+ // Calls GenerateBcastSubdivPerms for device at `device_rank`. Checks if the
+ // generated subdiv perms, ranks, and source ranks match the expected values.
+ void BcastSubdivPerms(
+ CollectiveParams* cp, const std::vector<int>& dev_per_task,
+ int device_rank, int source_rank,
+ const std::vector<std::vector<int>>& expected_subdiv_perms,
+ const std::vector<int>& expected_subdiv_rank,
+ const std::vector<int>& expected_subdiv_source_rank) {
+ cp->subdiv_rank.clear();
+ cp->instance.impl_details.subdiv_permutations.clear();
+ cp->instance.impl_details.subdiv_source_rank.clear();
+ CollectiveParamResolverLocal::GenerateBcastSubdivPerms(
+ cp->instance.device_names[device_rank], source_rank, dev_per_task, cp);
+ EXPECT_EQ(expected_subdiv_perms,
+ cp->instance.impl_details.subdiv_permutations);
+ EXPECT_EQ(expected_subdiv_rank, cp->subdiv_rank);
+ EXPECT_EQ(expected_subdiv_source_rank,
+ cp->instance.impl_details.subdiv_source_rank);
+ }
+
std::vector<Device*> devices_;
std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<DeviceResolverLocal> drl_;
@@ -216,4 +236,113 @@ TEST_F(CollectiveParamResolverLocalTest, GenerateSubdivPerms) {
EXPECT_EQ(1, cp.subdiv_rank[1]);
}
+TEST_F(CollectiveParamResolverLocalTest, GenerateBcastSubdivPerms1Task8GPU) {
+ CollectiveParams cp;
+ cp.group.device_type = DeviceType("GPU");
+ cp.group.num_tasks = 1;
+ cp.instance.type = BROADCAST_COLLECTIVE;
+ for (int i = 0; i < 8; i++) {
+ string dev_name =
+ strings::StrCat("/job:worker/replica:0/task:0/device:GPU:", i);
+ cp.instance.device_names.push_back(dev_name);
+ }
+ std::vector<int> dev_per_task = {8};
+
+ // source 0 device 0
+ BcastSubdivPerms(&cp, dev_per_task, 0, 0, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0},
+ {0});
+
+ // source 2 device 2
+ BcastSubdivPerms(&cp, dev_per_task, 2, 2, {{0, 1, 2, 3, 4, 5, 6, 7}}, {2},
+ {2});
+
+ // source 2 device 0
+ BcastSubdivPerms(&cp, dev_per_task, 0, 2, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0},
+ {2});
+}
+
+TEST_F(CollectiveParamResolverLocalTest, GenerateBcastSubdivPerms4Tasks8GPU) {
+ CollectiveParams cp;
+ cp.group.device_type = DeviceType("GPU");
+ cp.group.num_tasks = 4;
+ cp.instance.type = BROADCAST_COLLECTIVE;
+ for (int ti = 0; ti < cp.group.num_tasks; ti++) {
+ for (int di = 0; di < 8; di++) {
+ string dev_name = strings::StrCat("/job:worker/replica:0/task:", ti,
+ "/device:GPU:", di);
+ cp.instance.device_names.push_back(dev_name);
+ }
+ }
+ std::vector<int> dev_per_task = {8, 8, 8, 8};
+
+ // source 0 device 0
+ BcastSubdivPerms(&cp, dev_per_task, 0, 0,
+ {{0, 8, 16, 24},
+ {0, 1, 2, 3, 4, 5, 6, 7},
+ {8, 9, 10, 11, 12, 13, 14, 15},
+ {16, 17, 18, 19, 20, 21, 22, 23},
+ {24, 25, 26, 27, 28, 29, 30, 31}},
+ {0, 0, -1, -1, -1}, {0, 0, 0, 0, 0});
+
+ // source 2 device 0
+ BcastSubdivPerms(&cp, dev_per_task, 0, 2,
+ {{2, 8, 16, 24},
+ {0, 1, 2, 3, 4, 5, 6, 7},
+ {8, 9, 10, 11, 12, 13, 14, 15},
+ {16, 17, 18, 19, 20, 21, 22, 23},
+ {24, 25, 26, 27, 28, 29, 30, 31}},
+ {-1, 0, -1, -1, -1}, {0, 2, 0, 0, 0});
+
+ // source 9 device 9
+ BcastSubdivPerms(&cp, dev_per_task, 9, 9,
+ {{0, 9, 16, 24},
+ {0, 1, 2, 3, 4, 5, 6, 7},
+ {8, 9, 10, 11, 12, 13, 14, 15},
+ {16, 17, 18, 19, 20, 21, 22, 23},
+ {24, 25, 26, 27, 28, 29, 30, 31}},
+ {1, -1, 1, -1, -1}, {1, 0, 1, 0, 0});
+}
+
+TEST_F(CollectiveParamResolverLocalTest,
+ GenerateBcastSubdivPerms4TasksVariableGPU) {
+ CollectiveParams cp;
+ cp.group.device_type = DeviceType("GPU");
+ cp.group.num_tasks = 4;
+ std::vector<int> dev_per_task = {4, 4, 6, 8};
+ for (int ti = 0; ti < cp.group.num_tasks; ti++) {
+ for (int di = 0; di < dev_per_task[ti]; di++) {
+ string dev_name = strings::StrCat("/job:worker/replica:0/task:", ti,
+ "/device:GPU:", di);
+ cp.instance.device_names.push_back(dev_name);
+ }
+ }
+
+ // source 0 device 0
+ BcastSubdivPerms(&cp, dev_per_task, 0, 0,
+ {{0, 4, 8, 14},
+ {0, 1, 2, 3},
+ {4, 5, 6, 7},
+ {8, 9, 10, 11, 12, 13},
+ {14, 15, 16, 17, 18, 19, 20, 21}},
+ {0, 0, -1, -1, -1}, {0, 0, 0, 0, 0});
+
+ // source 2 device 0
+ BcastSubdivPerms(&cp, dev_per_task, 0, 2,
+ {{2, 4, 8, 14},
+ {0, 1, 2, 3},
+ {4, 5, 6, 7},
+ {8, 9, 10, 11, 12, 13},
+ {14, 15, 16, 17, 18, 19, 20, 21}},
+ {-1, 0, -1, -1, -1}, {0, 2, 0, 0, 0});
+
+ // source 9 device 5
+ BcastSubdivPerms(&cp, dev_per_task, 5, 9,
+ {{0, 4, 9, 14},
+ {0, 1, 2, 3},
+ {4, 5, 6, 7},
+ {8, 9, 10, 11, 12, 13},
+ {14, 15, 16, 17, 18, 19, 20, 21}},
+ {-1, -1, 1, -1, -1}, {2, 0, 0, 1, 0});
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index d1fd930d25..0695278c0d 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_resolver_local.h"
#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/common_runtime/executor_factory.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/memory_types.h"
@@ -1223,10 +1224,9 @@ Status DirectSession::CreateExecutors(
item->graph = partition_graph.get();
item->executor = nullptr;
item->device = device;
- Executor* executor;
- TF_RETURN_IF_ERROR(
- NewLocalExecutor(params, std::move(partition_graph), &executor));
- item->executor.reset(executor);
+ auto executor_type = options_.config.experimental().executor_type();
+ TF_RETURN_IF_ERROR(NewExecutor(
+ executor_type, params, std::move(partition_graph), &item->executor));
}
// Cache the mapping from input/output names to graph elements to
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 0c0fbc729c..f97fa4fadc 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -448,6 +448,14 @@ bool IsLocal(EagerContext* ctx, tensorflow::Device* d) {
return ctx->local_device_mgr()->LookupDevice(d->name(), &tmp).ok();
}
+bool OnSameTask(EagerContext* ctx, Device* first, Device* second) {
+ if (first == nullptr) first = ctx->HostCPU();
+ if (second == nullptr) second = ctx->HostCPU();
+ return first->parsed_name().job == second->parsed_name().job &&
+ first->parsed_name().replica == second->parsed_name().replica &&
+ first->parsed_name().task == second->parsed_name().task;
+}
+
Status EagerLocalExecute(EagerOperation* op,
gtl::InlinedVector<TensorHandle*, 2>* retvals,
int* num_retvals) {
@@ -689,7 +697,11 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
for (int i = 0; i < op->Inputs().size(); i++) {
tensorflow::Device* input_device;
TF_RETURN_IF_ERROR(op->Inputs()[i]->Device(&input_device));
- if (op->Device() != input_device) {
+ if (op->Device() != input_device &&
+ // If the expected and actual devices are on the same task, don't
+ // explicitly copy, and instead depend on the copy to happen locally
+ // when the op is executed on the device.
+ !OnSameTask(ctx, op->Device(), input_device)) {
// TODO(b/110044833): It's possible the same tensor gets copied to the
// remote device repeatedly.
TF_RETURN_IF_ERROR(MaybeCopyInputToExpectedDevice(
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 1e837e9a7e..120f480198 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -1019,8 +1019,9 @@ TEST_F(FunctionLibraryRuntimeTest, Error_BadControlFlow) {
DCHECK_EQ(x.dtype(), DT_INT32);
Tensor y;
HasError(InstantiateAndRun(flr0_, "InvalidControlFlow", {}, {x}, {&y}),
- "The node 'add' has inputs from different frames. The input 'enter' "
- "is in frame 'while'. The input 'i' is in frame ''.");
+ "{{node add}} has inputs from different frames. The input"
+ " {{node enter}} is in frame 'while'. The input {{node i}} is in"
+ " frame ''.");
}
TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) {
diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
index 29f702699f..94e10dbfa2 100644
--- a/tensorflow/core/common_runtime/mkl_cpu_allocator.h
+++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
@@ -22,7 +22,6 @@ limitations under the License.
#ifdef INTEL_MKL
#include <cstdlib>
-#include <string>
#include "tensorflow/core/common_runtime/bfc_allocator.h"
#include "tensorflow/core/common_runtime/visitable_allocator.h"
#include "tensorflow/core/lib/strings/numbers.h"
diff --git a/tensorflow/core/common_runtime/placer.cc b/tensorflow/core/common_runtime/placer.cc
index 6781c87f6c..613470365d 100644
--- a/tensorflow/core/common_runtime/placer.cc
+++ b/tensorflow/core/common_runtime/placer.cc
@@ -41,10 +41,8 @@ namespace {
const StringPiece kColocationAttrNameStringPiece(kColocationAttrName);
const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix);
-// Returns a list of devices sorted by preferred type and then name
-// from 'devices' whose type is in 'supported_device_types'. This
-// function searches the device types in 'supported_device_types' and
-// returns the subset of devices that match.
+// Returns a list of devices having type in supported_device_types. The
+// returned list is sorted by preferred type (higher numeric type is preferred).
std::vector<Device*> FilterSupportedDevices(
const std::vector<Device*>& devices,
const DeviceTypeVector& supported_device_types) {
@@ -81,12 +79,12 @@ std::vector<Device*> FilterSupportedDevices(
// DeviceSet device_set = ...;
// ColocationGraph colocation_graph(graph, device_set);
//
-// // Add all the nodes of graph to colocation_graph.
+// // Add all the nodes of the `graph` to the `colocation_graph`.
// for (Node* node : graph.nodes()) {
// TF_RETURN_IF_ERROR(colocation_graph.AddNode(*node));
// }
//
-// // Add one or more colocation constraint.
+// // Add one or more colocation constraints.
// Node node_1 = *graph.FindNodeId(...);
// Node node_2 = *graph.FindNodeId(...);
// TF_RETURN_IF_ERROR(colocation_graph.ColocateNodes(node_1, node_2));
@@ -96,9 +94,9 @@ std::vector<Device*> FilterSupportedDevices(
// TF_RETURN_IF_ERROR(colocation_graph.AssignDevice(node));
// }
//
-// The implementation uses the union-find algorithm to maintain the
-// connected components efficiently and incrementally as edges
-// (implied by ColocationGraph::ColocateNodes() invocations) are added.
+// This implementation uses the Union-Find algorithm to efficiently maintain the
+// connected components and incrementally adds edges via
+// ColocationGraph::ColocateNodes() invocations.
class ColocationGraph {
public:
ColocationGraph(Graph* graph, const DeviceSet* device_set,
@@ -134,13 +132,9 @@ class ColocationGraph {
std::unordered_map<StringPiece, const Node*, StringPieceHasher>
colocation_group_root;
- for (Node* node : graph_->nodes()) {
- if (!node->IsOp()) {
- continue;
- }
-
- // When adding the node, identify whether it is part of a
- // colocation group.
+ for (Node* node : graph_->op_nodes()) {
+ // When adding the node, identify whether it is part of a colocation
+ // group.
// This code is effectively the equivalent of GetNodeAttr() for a string
// array, but it avoids all internal allocations (the allocation of the
@@ -219,11 +213,10 @@ class ColocationGraph {
Member& x_root_member = members_[x_root];
Member& y_root_member = members_[y_root];
- // Merge the sets by swinging the parent pointer of the smaller
- // tree to point to the root of the larger tree. Together with
- // path compression in ColocationGraph::FindRoot, this ensures
- // that we do not experience pathological performance on graphs
- // such as chains.
+ // Merge the sets by setting the parent pointer of the smaller tree's root
+ // node to point to the root of the larger tree. Together with path
+ // compression in ColocationGraph::FindRoot, this ensures that we do not
+ // experience pathological performance on graphs such as chains.
int new_root, old_root;
if (x_root_member.rank < y_root_member.rank) {
// The tree rooted at x_root is shallower, so connect it to
@@ -611,22 +604,16 @@ class ColocationGraph {
// given id is connected.
int FindRoot(int node_id) {
Member& member = members_[node_id];
-
- int parent = member.parent;
- DCHECK_GE(parent, 0);
-
- if (parent != node_id) {
- // NOTE: Compress paths from node_id to its root, so that future
- // calls to FindRoot and ColocateNodes are more efficient.
- int root = FindRoot(parent);
- if (parent != root) {
- parent = root;
- member.parent = root;
- }
+ DCHECK_GE(member.parent, 0);
+ if (member.parent == node_id) {
+ // member.parent is the root of this disjoint tree. Do nothing.
+ } else {
+ member.parent = FindRoot(member.parent);
}
-
- DCHECK_GE(parent, 0);
- return parent;
+ // Now it is guaranteed that member.parent is the root of this disjoint
+ // tree.
+ DCHECK_GE(member.parent, 0);
+ return member.parent;
}
// Ensures that the devices of 'dst's resource and reference match the device
diff --git a/tensorflow/core/common_runtime/session_ref.cc b/tensorflow/core/common_runtime/session_ref.cc
new file mode 100644
index 0000000000..b931ef4229
--- /dev/null
+++ b/tensorflow/core/common_runtime/session_ref.cc
@@ -0,0 +1,170 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/session_ref.h"
+
+#include <utility>
+
+namespace tensorflow {
+
+namespace {
+
+// Scope helper to track active calls and manage session lifetime.
+struct RunCounter {
+ std::shared_ptr<Session> session;
+ uint64* value;
+ mutex* m;
+ condition_variable* cv;
+
+ explicit RunCounter(std::shared_ptr<Session> s, uint64* v, mutex* m,
+ condition_variable* cv)
+ : session(std::move(s)), value(v), m(m), cv(cv) {
+ mutex_lock l(*m);
+ ++*value;
+ }
+
+ ~RunCounter() {
+ mutex_lock l(*m);
+ if (--*value == 0) {
+ cv->notify_all();
+ }
+ }
+};
+
+} // namespace
+
+Status SessionRef::CheckNotClosed() {
+ mutex_lock l(run_lock_);
+ if (session_ == nullptr) return errors::Cancelled("Session has been closed.");
+ return ::tensorflow::Status::OK();
+}
+
+Status SessionRef::Run(const RunOptions& run_options,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
+ std::vector<Tensor>* outputs,
+ RunMetadata* run_metadata) {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
+ return rc.session->Run(run_options, inputs, output_tensor_names,
+ target_node_names, outputs, run_metadata);
+}
+
+Status SessionRef::Create(const GraphDef& graph) {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
+ return rc.session->Create(graph);
+}
+
+Status SessionRef::Create(const RunOptions& run_options,
+ const GraphDef& graph) {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
+ return rc.session->Create(run_options, graph);
+}
+
+Status SessionRef::Extend(const RunOptions& run_options,
+ const GraphDef& graph) {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
+ return rc.session->Extend(run_options, graph);
+}
+
+Status SessionRef::Extend(const GraphDef& graph) {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
+ return rc.session->Extend(graph);
+}
+
+Status SessionRef::Close(const RunOptions& run_options) {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ mutex_lock l(run_lock_);
+ Status status = session_->Close(run_options);
+ session_.reset();
+ while (run_count_ > 0) {
+ run_finished_.wait(l);
+ }
+ return status;
+}
+
+Status SessionRef::Close() {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ mutex_lock l(run_lock_);
+ Status status = session_->Close();
+ session_.reset();
+ while (run_count_ > 0) {
+ run_finished_.wait(l);
+ }
+ return status;
+}
+
+Status SessionRef::Run(const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
+ std::vector<Tensor>* outputs) {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
+ return rc.session->Run(inputs, output_tensor_names, target_node_names,
+ outputs);
+}
+
+Status SessionRef::ListDevices(std::vector<DeviceAttributes>* response) {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
+ return rc.session->ListDevices(response);
+}
+
+Status SessionRef::PRunSetup(const std::vector<string>& input_names,
+ const std::vector<string>& output_names,
+ const std::vector<string>& target_nodes,
+ string* handle) {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
+ return rc.session->PRunSetup(input_names, output_names, target_nodes, handle);
+}
+
+Status SessionRef::PRun(const string& handle,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_names,
+ std::vector<Tensor>* outputs) {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
+ return rc.session->PRun(handle, inputs, output_names, outputs);
+}
+
+Status SessionRef::MakeCallable(const CallableOptions& callable_options,
+ CallableHandle* out_handle) {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
+ return rc.session->MakeCallable(callable_options, out_handle);
+}
+
+Status SessionRef::RunCallable(CallableHandle handle,
+ const std::vector<Tensor>& feed_tensors,
+ std::vector<Tensor>* fetch_tensors,
+ RunMetadata* run_metadata) {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
+ return rc.session->RunCallable(handle, feed_tensors, fetch_tensors,
+ run_metadata);
+}
+
+Status SessionRef::ReleaseCallable(CallableHandle handle) {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
+ return rc.session->ReleaseCallable(handle);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/session_ref.h b/tensorflow/core/common_runtime/session_ref.h
new file mode 100644
index 0000000000..9459e7edbe
--- /dev/null
+++ b/tensorflow/core/common_runtime/session_ref.h
@@ -0,0 +1,86 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_
+
+#include <memory>
+
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+
+// A `SessionRef` manages the lifetime of a wrapped `Session` pointer.
+//
+// SessionRef blocks the return of Close() until all pending operations have
+// been completed or cancelled and underlying session has been freed. Any
+// subsequent operations on the SessionRef object will return errors::Cancelled.
+class SessionRef : public Session {
+ public:
+ SessionRef(Session* session) : session_(session) {}
+ virtual ~SessionRef() {}
+
+ Status Create(const GraphDef& graph) override;
+ Status Extend(const GraphDef& graph) override;
+ Status Create(const RunOptions& run_options, const GraphDef& graph) override;
+ Status Extend(const RunOptions& run_options, const GraphDef& graph) override;
+ Status Run(const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
+ std::vector<Tensor>* outputs) override;
+
+ Status ListDevices(std::vector<DeviceAttributes>* response) override;
+
+ Status Close() override;
+ Status Close(const RunOptions& run_options) override;
+
+ Status Run(const RunOptions& run_options,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
+ std::vector<Tensor>* outputs, RunMetadata* run_metadata) override;
+
+ Status PRunSetup(const std::vector<string>& input_names,
+ const std::vector<string>& output_names,
+ const std::vector<string>& target_nodes,
+ string* handle) override;
+
+ Status PRun(const string& handle,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_names,
+ std::vector<Tensor>* outputs) override;
+
+ Status MakeCallable(const CallableOptions& callable_options,
+ CallableHandle* out_handle) override;
+
+ Status RunCallable(CallableHandle handle,
+ const std::vector<Tensor>& feed_tensors,
+ std::vector<Tensor>* fetch_tensors,
+ RunMetadata* run_metadata) override;
+
+ Status ReleaseCallable(CallableHandle handle) override;
+
+ private:
+ mutex run_lock_;
+ condition_variable run_finished_;
+ uint64 run_count_ GUARDED_BY(run_lock_) = {0};
+ std::shared_ptr<Session> session_;
+
+ Status CheckNotClosed();
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_
diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc
index e8ea904ebd..0bd79366eb 100644
--- a/tensorflow/core/framework/node_def_util.cc
+++ b/tensorflow/core/framework/node_def_util.cc
@@ -86,7 +86,8 @@ string AttrSlice::SummarizeNode() const {
string SummarizeNode(const Node& node) { return SummarizeNodeDef(node.def()); }
string SummarizeNodeDef(const NodeDef& node_def) {
- string ret = strings::StrCat(node_def.name(), " = ", node_def.op(), "[");
+ string ret = strings::StrCat(FormatNodeDefForError(node_def), " = ",
+ node_def.op(), "[");
strings::StrAppend(&ret, SummarizeAttrsHelper(node_def, node_def.device()));
strings::StrAppend(&ret, "](");
@@ -101,6 +102,14 @@ string SummarizeNodeDef(const NodeDef& node_def) {
return ret;
}
+string FormatNodeForError(const Node& node) {
+ return FormatNodeDefForError(node.def());
+}
+
+string FormatNodeDefForError(const NodeDef& node_def) {
+ return errors::FormatNodeNameForError(node_def.name());
+}
+
const AttrValue* AttrSlice::Find(StringPiece attr_name) const {
// Currently, the collection used for NodeDef::attr() (google::protobuf::Map)
// requires that the keys used for lookups have type 'const string&'. Because
@@ -634,7 +643,7 @@ Status ValidateExternalNodeDefSyntax(const NodeDef& node_def) {
Status AttachDef(const Status& status, const NodeDef& node_def) {
Status ret = status;
errors::AppendToMessage(
- &ret, strings::StrCat(" [[Node: ", SummarizeNodeDef(node_def), "]]"));
+ &ret, strings::StrCat(" [[", SummarizeNodeDef(node_def), "]]"));
return ret;
}
diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h
index 64c8b386e8..c012b7c3d3 100644
--- a/tensorflow/core/framework/node_def_util.h
+++ b/tensorflow/core/framework/node_def_util.h
@@ -50,6 +50,12 @@ extern const char* const kColocationGroupPrefix;
string SummarizeNode(const Node& node);
string SummarizeNodeDef(const NodeDef& node_def);
+// Produces a formatted string pattern from the node which can uniquely identify
+// this node upstream to produce an informative error message. The pattern
+// followed is: {{node <node_name>}}
+string FormatNodeForError(const Node& node);
+string FormatNodeDefForError(const NodeDef& node_def);
+
typedef protobuf::Map<string, AttrValue> AttrValueMap;
// Adds an attr with name <name> and value <value> to *node_def.
diff --git a/tensorflow/core/framework/node_def_util_test.cc b/tensorflow/core/framework/node_def_util_test.cc
index 35b7b2272b..74cc594863 100644
--- a/tensorflow/core/framework/node_def_util_test.cc
+++ b/tensorflow/core/framework/node_def_util_test.cc
@@ -20,6 +20,8 @@ limitations under the License.
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -79,7 +81,7 @@ TEST(NodeDefUtilTest, In) {
)proto");
ExpectSuccess(node_def, op);
- EXPECT_EQ("n = In[T=DT_FLOAT](a)", SummarizeNodeDef(node_def));
+ EXPECT_EQ("{{node n}} = In[T=DT_FLOAT](a)", SummarizeNodeDef(node_def));
// Mismatching Op names.
NodeDef bad = node_def;
@@ -144,7 +146,7 @@ TEST(NodeDefUtilTest, Out) {
)proto");
ExpectSuccess(node_def, op);
- EXPECT_EQ("n = Out[T=DT_INT32]()", SummarizeNodeDef(node_def));
+ EXPECT_EQ("{{node n}} = Out[T=DT_INT32]()", SummarizeNodeDef(node_def));
// Non-number type.
NodeDef bad = node_def;
@@ -164,7 +166,7 @@ TEST(NodeDefUtilTest, Enum) {
)proto");
ExpectSuccess(node_def, op);
- EXPECT_EQ("n = Enum[e=\"apple\"]()", SummarizeNodeDef(node_def));
+ EXPECT_EQ("{{node n}} = Enum[e=\"apple\"]()", SummarizeNodeDef(node_def));
NodeDef good = node_def;
good.clear_attr();
@@ -191,7 +193,8 @@ TEST(NodeDefUtilTest, SameIn) {
)proto");
ExpectSuccess(node_def, op);
- EXPECT_EQ("n = SameIn[N=2, T=DT_DOUBLE](a, b)", SummarizeNodeDef(node_def));
+ EXPECT_EQ("{{node n}} = SameIn[N=2, T=DT_DOUBLE](a, b)",
+ SummarizeNodeDef(node_def));
// Illegal type
NodeDef bad = ToNodeDef(R"proto(
@@ -220,7 +223,7 @@ TEST(NodeDefUtilTest, AnyIn) {
)proto");
ExpectSuccess(node_def, op);
- EXPECT_EQ("n = AnyIn[T=[DT_INT32, DT_STRING]](a, b)",
+ EXPECT_EQ("{{node n}} = AnyIn[T=[DT_INT32, DT_STRING]](a, b)",
SummarizeNodeDef(node_def));
const NodeDef bad = ToNodeDef(R"proto(
@@ -243,13 +246,14 @@ TEST(NodeDefUtilTest, Device) {
const NodeDef node_def1 =
ToNodeDef(NodeDefBuilder("d", &op_def1).Device("/cpu:17"));
ExpectSuccess(node_def1, op_def1);
- EXPECT_EQ("d = None[_device=\"/cpu:17\"]()", SummarizeNodeDef(node_def1));
+ EXPECT_EQ("{{node d}} = None[_device=\"/cpu:17\"]()",
+ SummarizeNodeDef(node_def1));
const OpDef op_def2 = ToOpDef(OpDefBuilder("WithAttr").Attr("v: int"));
const NodeDef node_def2 =
ToNodeDef(NodeDefBuilder("d", &op_def2).Attr("v", 7).Device("/cpu:5"));
ExpectSuccess(node_def2, op_def2);
- EXPECT_EQ("d = WithAttr[v=7, _device=\"/cpu:5\"]()",
+ EXPECT_EQ("{{node d}} = WithAttr[v=7, _device=\"/cpu:5\"]()",
SummarizeNodeDef(node_def2));
}
@@ -284,7 +288,7 @@ TEST(NodeDefUtilTest, ValidSyntax) {
)proto");
ExpectValidSyntax(node_def_explicit_inputs);
- EXPECT_EQ("n = AnyIn[T=[DT_INT32, DT_STRING]](a:0, b:123)",
+ EXPECT_EQ("{{node n}} = AnyIn[T=[DT_INT32, DT_STRING]](a:0, b:123)",
SummarizeNodeDef(node_def_explicit_inputs));
const NodeDef node_def_partial_shape = ToNodeDef(R"proto(
@@ -379,7 +383,7 @@ TEST(NameRangesForNodeTest, Simple) {
EXPECT_EQ(NameRangeMap({{"a", {0, 1}}, {"b", {1, 2}}}), inputs);
EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 2}}}), outputs);
- EXPECT_EQ("simple = Simple[](a, b)", SummarizeNodeDef(node_def));
+ EXPECT_EQ("{{node simple}} = Simple[](a, b)", SummarizeNodeDef(node_def));
OpDef bad_op_def = op_def;
bad_op_def.mutable_input_arg(0)->clear_type();
@@ -399,7 +403,7 @@ TEST(NameRangesForNodeTest, Polymorphic) {
TF_EXPECT_OK(NameRangesForNode(node_def1, op_def, &inputs, &outputs));
EXPECT_EQ(NameRangeMap({{"a", {0, 1}}, {"b", {1, 2}}}), inputs);
EXPECT_EQ(NameRangeMap({{"c", {0, 1}}}), outputs);
- EXPECT_EQ("poly = Polymorphic[T=DT_INT32](a, b)",
+ EXPECT_EQ("{{node poly}} = Polymorphic[T=DT_INT32](a, b)",
SummarizeNodeDef(node_def1));
const NodeDef node_def2 = ToNodeDef(NodeDefBuilder("poly", &op_def)
@@ -408,7 +412,8 @@ TEST(NameRangesForNodeTest, Polymorphic) {
TF_EXPECT_OK(NameRangesForNode(node_def2, op_def, &inputs, &outputs));
EXPECT_EQ(NameRangeMap({{"a", {0, 1}}, {"b", {1, 2}}}), inputs);
EXPECT_EQ(NameRangeMap({{"c", {0, 1}}}), outputs);
- EXPECT_EQ("poly = Polymorphic[T=DT_BOOL](a, b)", SummarizeNodeDef(node_def2));
+ EXPECT_EQ("{{node poly}} = Polymorphic[T=DT_BOOL](a, b)",
+ SummarizeNodeDef(node_def2));
}
TEST(NameRangesForNodeTest, NRepeats) {
@@ -431,7 +436,8 @@ TEST(NameRangesForNodeTest, NRepeats) {
EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 5}}, {"e", {5, 8}}}),
outputs);
EXPECT_EQ(
- "nr = NRepeats[M=3, N=4, T=DT_FLOAT](a, a:1, a:2, a:3, b, b:1, b:2, b:3)",
+ "{{node nr}} = NRepeats[M=3, N=4, T=DT_FLOAT](a, a:1, a:2, a:3, b, b:1, "
+ "b:2, b:3)",
SummarizeNodeDef(node_def1));
const NodeDef node_def2 = ToNodeDef(NodeDefBuilder("nr", &op_def)
@@ -442,7 +448,7 @@ TEST(NameRangesForNodeTest, NRepeats) {
EXPECT_EQ(NameRangeMap({{"a", {0, 2}}, {"b", {2, 4}}}), inputs);
EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 3}}, {"e", {3, 10}}}),
outputs);
- EXPECT_EQ("nr = NRepeats[M=7, N=2, T=DT_DOUBLE](a, a:1, b, b:1)",
+ EXPECT_EQ("{{node nr}} = NRepeats[M=7, N=2, T=DT_DOUBLE](a, a:1, b, b:1)",
SummarizeNodeDef(node_def2));
NodeDef bad_node_def = node_def2;
@@ -471,7 +477,7 @@ TEST(NameRangesForNodeTest, TypeList) {
EXPECT_EQ(NameRangeMap({{"c", {0, 4}}, {"d", {4, 7}}, {"e", {7, 9}}}),
outputs);
EXPECT_EQ(
- "tl = TypeList[T1=[DT_BOOL, DT_FLOAT],"
+ "{{node tl}} = TypeList[T1=[DT_BOOL, DT_FLOAT],"
" T2=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT],"
" T3=[DT_INT32, DT_DOUBLE, DT_STRING]](a, a:1, b, b:1, b:2, b:3)",
SummarizeNodeDef(node_def1));
@@ -485,7 +491,8 @@ TEST(NameRangesForNodeTest, TypeList) {
EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 3}}, {"e", {3, 10}}}),
outputs);
EXPECT_EQ(
- "tl = TypeList[T1=[DT_INT32, DT_INT32, DT_INT32, DT_INT32, DT_INT32,"
+ "{{node tl}} = TypeList[T1=[DT_INT32, DT_INT32, DT_INT32, DT_INT32, "
+ "DT_INT32,"
" DT_INT32, DT_INT32], T2=[DT_DOUBLE], T3=[DT_DOUBLE, DT_STRING]]"
"(a, a:1, a:2, a:3, a:4, a:5, a:6, b)",
SummarizeNodeDef(node_def2));
@@ -509,5 +516,20 @@ TEST(AddPrefixAndSuffixToNode, Enter) {
EXPECT_EQ("prefix/test_frame/suffix", frame_name);
}
+TEST(FormatNodeForErrorTest, Node) {
+ Graph g(OpRegistry::Global());
+ Node* node;
+ TF_CHECK_OK(NodeBuilder("enter", "NoOp").Finalize(&g, &node));
+ EXPECT_EQ("{{node enter}}", FormatNodeForError(*node));
+}
+
+TEST(FormatNodeForErrorTest, NodeDef) {
+ NodeDef node_def;
+ node_def.set_name("enter");
+ node_def.set_op("Enter");
+ AddNodeAttr("frame_name", "test_frame", &node_def);
+ EXPECT_EQ("{{node enter}}", FormatNodeDefForError(node_def));
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/framework/op_compatibility_test.cc b/tensorflow/core/framework/op_compatibility_test.cc
index c782480f1f..140f201085 100644
--- a/tensorflow/core/framework/op_compatibility_test.cc
+++ b/tensorflow/core/framework/op_compatibility_test.cc
@@ -209,8 +209,8 @@ TEST_F(OpCompatibilityTest, Same) {
.Finalize(node_def()));
ExpectSuccess(*RegisteredOpDef());
EXPECT_EQ(
- "same = Same[N=3, T=DT_FLOAT, TList=[DT_BOOL, DT_BOOL]](a, b, c, c:1, "
- "c:2, d, d:1, d:2, e, e:1)",
+ "{{node same}} = Same[N=3, T=DT_FLOAT, TList=[DT_BOOL, DT_BOOL]](a, b, "
+ "c, c:1, c:2, d, d:1, d:2, e, e:1)",
Result());
}
@@ -224,7 +224,7 @@ TEST_F(OpCompatibilityTest, AddAttr) {
OpDefBuilder("AddAttr").Output("ndef: string").Finalize(&old_op));
TF_ASSERT_OK(NodeDefBuilder("add_attr", &old_op.op_def).Finalize(node_def()));
ExpectSuccess(old_op.op_def);
- EXPECT_EQ("add_attr = AddAttr[a=42]()", Result());
+ EXPECT_EQ("{{node add_attr}} = AddAttr[a=42]()", Result());
}
// Should be able to make an attr restriction less strict.
@@ -241,7 +241,7 @@ TEST_F(OpCompatibilityTest, LessStrict) {
.Attr("a", "B")
.Finalize(node_def()));
ExpectSuccess(old_op.op_def);
- EXPECT_EQ("less_strict = LessStrict[a=\"B\"]()", Result());
+ EXPECT_EQ("{{node less_strict}} = LessStrict[a=\"B\"]()", Result());
}
// Should be able to remove an attr restriction.
@@ -259,7 +259,8 @@ TEST_F(OpCompatibilityTest, RemoveRestriction) {
.Attr("a", DT_INT32)
.Finalize(node_def()));
ExpectSuccess(old_op.op_def);
- EXPECT_EQ("remove_restriction = RemoveRestriction[a=DT_INT32]()", Result());
+ EXPECT_EQ("{{node remove_restriction}} = RemoveRestriction[a=DT_INT32]()",
+ Result());
}
// Should be able to change the order of attrs.
@@ -278,7 +279,7 @@ TEST_F(OpCompatibilityTest, AttrOrder) {
.Attr("a", 7)
.Finalize(node_def()));
ExpectSuccess(old_op.op_def);
- EXPECT_EQ("attr_order = AttrOrder[a=7, b=true]()", Result());
+ EXPECT_EQ("{{node attr_order}} = AttrOrder[a=7, b=true]()", Result());
}
// Should be able to make an input/output polymorphic.
@@ -299,7 +300,8 @@ TEST_F(OpCompatibilityTest, TypePolymorphic) {
.Input(FakeInput())
.Finalize(node_def()));
ExpectSuccess(old_op.op_def);
- EXPECT_EQ("type_polymorphic = TypePolymorphic[T=DT_INT32](a)", Result());
+ EXPECT_EQ("{{node type_polymorphic}} = TypePolymorphic[T=DT_INT32](a)",
+ Result());
}
// Should be able to make a single input/output into a list.
@@ -320,7 +322,7 @@ TEST_F(OpCompatibilityTest, MakeList) {
.Input(FakeInput())
.Finalize(node_def()));
ExpectSuccess(old_op.op_def);
- EXPECT_EQ("make_list = MakeList[N=1](a)", Result());
+ EXPECT_EQ("{{node make_list}} = MakeList[N=1](a)", Result());
}
// Should be able to make a single input/output into a polymorphic list.
@@ -343,7 +345,8 @@ TEST_F(OpCompatibilityTest, MakePolyList) {
.Input(FakeInput())
.Finalize(node_def()));
ExpectSuccess(old_op.op_def);
- EXPECT_EQ("make_poly_list = MakePolyList[N=1, T=DT_INT32](a)", Result());
+ EXPECT_EQ("{{node make_poly_list}} = MakePolyList[N=1, T=DT_INT32](a)",
+ Result());
}
// Should be able to make a single input/output into an arbitrary list.
@@ -364,7 +367,7 @@ TEST_F(OpCompatibilityTest, MakeAnyList) {
.Input(FakeInput())
.Finalize(node_def()));
ExpectSuccess(old_op.op_def);
- EXPECT_EQ("make_any_list = MakeAnyList[T=[DT_INT32]](a)", Result());
+ EXPECT_EQ("{{node make_any_list}} = MakeAnyList[T=[DT_INT32]](a)", Result());
}
// Should be able to make a single polymorphic input/output into a list of
@@ -387,7 +390,8 @@ TEST_F(OpCompatibilityTest, PolyIntoList) {
.Input(FakeInput(DT_INT32))
.Finalize(node_def()));
ExpectSuccess(old_op.op_def);
- EXPECT_EQ("poly_into_list = PolyIntoList[N=1, T=DT_INT32](a)", Result());
+ EXPECT_EQ("{{node poly_into_list}} = PolyIntoList[N=1, T=DT_INT32](a)",
+ Result());
}
// Should be able to make a multiple inputs/outputs into a list with
@@ -413,7 +417,7 @@ TEST_F(OpCompatibilityTest, MakeMultipleSameList) {
.Input(FakeInput())
.Finalize(node_def()));
ExpectSuccess(old_op.op_def);
- EXPECT_EQ("make_list = MakeMultipleSameList[N=2](a, b)", Result());
+ EXPECT_EQ("{{node make_list}} = MakeMultipleSameList[N=2](a, b)", Result());
}
// Changing from int32, float -> T
@@ -437,8 +441,9 @@ TEST_F(OpCompatibilityTest, MakeMultipleAnyList) {
.Input(FakeInput())
.Finalize(node_def()));
ExpectSuccess(old_op.op_def);
- EXPECT_EQ("make_list = MakeMultipleAnyList[T=[DT_INT32, DT_FLOAT]](a, b)",
- Result());
+ EXPECT_EQ(
+ "{{node make_list}} = MakeMultipleAnyList[T=[DT_INT32, DT_FLOAT]](a, b)",
+ Result());
}
// Should be able to change the name of an input/output.
@@ -455,7 +460,7 @@ TEST_F(OpCompatibilityTest, ChangeName) {
.Input(FakeInput())
.Finalize(node_def()));
ExpectSuccess(old_op.op_def);
- EXPECT_EQ("change_name = ChangeName[](a)", Result());
+ EXPECT_EQ("{{node change_name}} = ChangeName[](a)", Result());
}
// Should be able to add an input/output of type
@@ -473,7 +478,7 @@ TEST_F(OpCompatibilityTest, AddNInts) {
TF_ASSERT_OK(
NodeDefBuilder("add_n_ints", &old_op.op_def).Finalize(node_def()));
ExpectSuccess(old_op.op_def);
- EXPECT_EQ("add_n_ints = AddNInts[N=0]()", Result());
+ EXPECT_EQ("{{node add_n_ints}} = AddNInts[N=0]()", Result());
}
// Should be able to add an input/output of type N * T
@@ -492,7 +497,7 @@ TEST_F(OpCompatibilityTest, AddNSame) {
TF_ASSERT_OK(
NodeDefBuilder("add_n_same", &old_op.op_def).Finalize(node_def()));
ExpectSuccess(old_op.op_def);
- EXPECT_EQ("add_n_same = AddNSame[N=0, T=DT_BOOL]()", Result());
+ EXPECT_EQ("{{node add_n_same}} = AddNSame[N=0, T=DT_BOOL]()", Result());
}
// Should be able to add an input/output of type N * T
@@ -517,8 +522,10 @@ TEST_F(OpCompatibilityTest, AddNSameAsExisting) {
.Input(FakeInput(DT_STRING))
.Finalize(node_def()));
ExpectSuccess(old_op.op_def);
- EXPECT_EQ("add_n_same_as_existing = AddNSameAsExisting[N=0, T=DT_STRING](a)",
- Result());
+ EXPECT_EQ(
+ "{{node add_n_same_as_existing}} = AddNSameAsExisting[N=0, "
+ "T=DT_STRING](a)",
+ Result());
}
// Should be able to add an input/output of type T
@@ -536,7 +543,7 @@ TEST_F(OpCompatibilityTest, AddAnyList) {
TF_ASSERT_OK(
NodeDefBuilder("add_any_list", &old_op.op_def).Finalize(node_def()));
ExpectSuccess(old_op.op_def);
- EXPECT_EQ("add_any_list = AddAnyList[T=[]]()", Result());
+ EXPECT_EQ("{{node add_any_list}} = AddAnyList[T=[]]()", Result());
}
// Should be able to allow shorter lists.
@@ -557,8 +564,10 @@ TEST_F(OpCompatibilityTest, ShorterAnyList) {
.Input(FakeInput(2, DT_BOOL))
.Finalize(node_def()));
ExpectSuccess(old_op.op_def);
- EXPECT_EQ("shorter_any_list = ShorterAnyList[T=[DT_BOOL, DT_BOOL]](a, a:1)",
- Result());
+ EXPECT_EQ(
+ "{{node shorter_any_list}} = ShorterAnyList[T=[DT_BOOL, DT_BOOL]](a, "
+ "a:1)",
+ Result());
}
REGISTER_OP("ShorterSameList")
@@ -578,7 +587,8 @@ TEST_F(OpCompatibilityTest, ShorterSameList) {
.Input(FakeInput(2))
.Finalize(node_def()));
ExpectSuccess(old_op.op_def);
- EXPECT_EQ("shorter_same_list = ShorterSameList[N=2](a, a:1)", Result());
+ EXPECT_EQ("{{node shorter_same_list}} = ShorterSameList[N=2](a, a:1)",
+ Result());
}
// Can remove a restriction to an attr
@@ -597,7 +607,7 @@ TEST_F(OpCompatibilityTest, AttrRemoveRestriction) {
.Attr("t", DT_INT32)
.Finalize(node_def()));
ExpectSuccess(old_op.op_def);
- EXPECT_EQ("remove_restriction = AttrRemoveRestriction[t=DT_INT32]()",
+ EXPECT_EQ("{{node remove_restriction}} = AttrRemoveRestriction[t=DT_INT32]()",
Result());
}
@@ -619,7 +629,8 @@ TEST_F(OpCompatibilityTest, AttrLessRestrictive) {
.Attr("t", DT_INT32)
.Finalize(node_def()));
ExpectSuccess(old_op.op_def);
- EXPECT_EQ("less_restrictive = AttrLessRestrictive[t=DT_INT32]()", Result());
+ EXPECT_EQ("{{node less_restrictive}} = AttrLessRestrictive[t=DT_INT32]()",
+ Result());
}
// Can remove a minimum from an attr.
@@ -637,7 +648,7 @@ TEST_F(OpCompatibilityTest, AttrRemoveMin) {
.Attr("n", 4)
.Finalize(node_def()));
ExpectSuccess(old_op.op_def);
- EXPECT_EQ("remove_min = AttrRemoveMin[n=4]()", Result());
+ EXPECT_EQ("{{node remove_min}} = AttrRemoveMin[n=4]()", Result());
}
// Can lower the minimum on an attr.
@@ -655,7 +666,7 @@ TEST_F(OpCompatibilityTest, AttrLowerMin) {
.Attr("n", 4)
.Finalize(node_def()));
ExpectSuccess(old_op.op_def);
- EXPECT_EQ("lower_min = AttrLowerMin[n=4]()", Result());
+ EXPECT_EQ("{{node lower_min}} = AttrLowerMin[n=4]()", Result());
}
// Can make a ref input into a non-ref input.
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index 507aa9e447..b53bd8d53d 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -1288,4 +1288,10 @@ void OpKernelContext::CtxFailureWithWarning(const char* file, int line,
SetStatus(s);
}
+void CheckNotInComputeAsync(OpKernelContext* ctx,
+ const char* correct_macro_name) {
+ CHECK_EQ(nullptr, ctx->op_kernel().AsAsync())
+ << "Use " << correct_macro_name << " in AsyncOpKernel implementations.";
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index 1fc5e9908e..2b7cc867da 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -113,6 +113,7 @@ class OpKernel {
// Returns nullptr iff this op kernel is synchronous.
virtual AsyncOpKernel* AsAsync() { return nullptr; }
+ virtual const AsyncOpKernel* AsAsync() const { return nullptr; }
// Returns true iff this op kernel is considered "expensive". The
// runtime may use this flag to optimize graph execution for example
@@ -197,6 +198,7 @@ class AsyncOpKernel : public OpKernel {
virtual void ComputeAsync(OpKernelContext* context, DoneCallback done) = 0;
AsyncOpKernel* AsAsync() final { return this; }
+ const AsyncOpKernel* AsAsync() const final { return this; }
void Compute(OpKernelContext* context) final;
@@ -1542,21 +1544,36 @@ inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) {
// ...
// }
-#define OP_REQUIRES(CTX, EXP, STATUS) \
- do { \
- if (!TF_PREDICT_TRUE(EXP)) { \
- (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS)); \
- return; \
- } \
+// Generate a fatal error if OP_REQUIRES or OP_REQUIRES_OK are used in
+// AsyncOpKernel implementations. If these macros are used and the condition
+// does not hold, the `done` callback will never be called and the system will
+// deadlock, so a crash failure is preferable. Since the OP_REQUIRES[_OK] macros
+// are legal to use in AsyncOpKernel constructors, we use overload resolution
+// to distinguish between OpKernelConstruction* and OpKernelContext* context
+// types.
+class XlaOpKernelContext;
+inline void CheckNotInComputeAsync(XlaOpKernelContext*, const char*) {}
+inline void CheckNotInComputeAsync(OpKernelConstruction*, const char*) {}
+void CheckNotInComputeAsync(OpKernelContext* ctx,
+ const char* correct_macro_name);
+
+#define OP_REQUIRES(CTX, EXP, STATUS) \
+ do { \
+ if (!TF_PREDICT_TRUE(EXP)) { \
+ CheckNotInComputeAsync((CTX), "OP_REQUIRES_ASYNC"); \
+ (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS)); \
+ return; \
+ } \
} while (0)
-#define OP_REQUIRES_OK(CTX, ...) \
- do { \
- ::tensorflow::Status _s(__VA_ARGS__); \
- if (!TF_PREDICT_TRUE(_s.ok())) { \
- (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
- return; \
- } \
+#define OP_REQUIRES_OK(CTX, ...) \
+ do { \
+ ::tensorflow::Status _s(__VA_ARGS__); \
+ if (!TF_PREDICT_TRUE(_s.ok())) { \
+ CheckNotInComputeAsync((CTX), "OP_REQUIRES_OK_ASYNC"); \
+ (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
+ return; \
+ } \
} while (0)
#define OP_REQUIRES_ASYNC(CTX, EXP, STATUS, CALLBACK) \
diff --git a/tensorflow/core/graph/control_flow.cc b/tensorflow/core/graph/control_flow.cc
index 1778e48ef6..8e1e56d29b 100644
--- a/tensorflow/core/graph/control_flow.cc
+++ b/tensorflow/core/graph/control_flow.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <deque>
#include <vector>
+#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -54,10 +55,11 @@ Status ValidateControlFlowInfo(const Graph* graph,
frame.parent = parent;
frame.name = cf.frame_name;
} else if (frame.parent != parent) {
- return errors::InvalidArgument(
+ return errors::Internal(
"Invalid loop structure: Mismatched parent frames for \"",
cf.frame_name, "\": \"", parent->name, "\" vs \"", frame.parent->name,
- "\". This is an internal bug, please file a bug report with "
+ "\". The node giving this error: ", FormatNodeForError(*node),
+ "This is an internal bug, please file a bug report with "
"instructions on how to reproduce the error.");
}
if (IsLoopCond(node)) {
@@ -69,9 +71,9 @@ Status ValidateControlFlowInfo(const Graph* graph,
!str_util::StrContains(node->name(), "LoopCounter")) {
return errors::InvalidArgument(
"Invalid loop structure: Loop \"", cf.frame_name,
- "\" has more than one LoopCond node: \"", node->name(), "\" and \"",
- frame.loop_cond->name(),
- "\". This is an internal bug, please file a bug report with "
+ "\" has more than one LoopCond node: ", FormatNodeForError(*node),
+ " and ", FormatNodeForError(*frame.loop_cond),
+ ". This is an internal bug, please file a bug report with "
"instructions on how to reproduce the error.");
}
frame.loop_cond = node;
@@ -135,12 +137,11 @@ Status BuildControlFlowInfo(const Graph* g, std::vector<ControlFlowInfo>* info,
const string& parent_frame = (*info)[out_parent->id()].frame_name;
if (parent_frame != frame_name) {
return errors::InvalidArgument(
- "The node '", out->name(),
- "' has inputs from different "
- "frames. The input '",
- curr_node->name(), "' is in frame '", frame_name,
- "'. The input '", parent_nodes[out->id()]->name(),
- "' is in frame '", parent_frame, "'.");
+ FormatNodeForError(*out),
+ " has inputs from different frames. The input ",
+ FormatNodeForError(*curr_node), " is in frame '", frame_name,
+ "'. The input ", FormatNodeForError(*parent_nodes[out->id()]),
+ " is in frame '", parent_frame, "'.");
}
} else {
out_info->frame = out;
@@ -148,7 +149,8 @@ Status BuildControlFlowInfo(const Graph* g, std::vector<ControlFlowInfo>* info,
TF_RETURN_IF_ERROR(
GetNodeAttr(out->attrs(), "frame_name", &out_info->frame_name));
if (out_info->frame_name.empty()) {
- return errors::InvalidArgument("The Enter node ", out->name(),
+ return errors::InvalidArgument("The Enter ",
+ FormatNodeForError(*out),
" must have a frame name.");
}
}
@@ -156,12 +158,11 @@ Status BuildControlFlowInfo(const Graph* g, std::vector<ControlFlowInfo>* info,
if (is_visited) {
if (out_info->frame_name != frame_name) {
return errors::InvalidArgument(
- "The node '", out->name(),
- "' has inputs from different "
- "frames. The input '",
- curr_node->name(), "' is in frame '", frame_name,
- "'. The input '", parent_nodes[out->id()]->name(),
- "' is in frame '", out_info->frame_name, "'.");
+ FormatNodeForError(*out),
+ " has inputs from different frames. The input ",
+ FormatNodeForError(*curr_node), " is in frame '", frame_name,
+ "'. The input ", FormatNodeForError(*parent_nodes[out->id()]),
+ " is in frame '", out_info->frame_name, "'.");
}
} else {
out_info->frame = frame;
diff --git a/tensorflow/core/graph/control_flow_test.cc b/tensorflow/core/graph/control_flow_test.cc
index eb7937400f..803c757c3f 100644
--- a/tensorflow/core/graph/control_flow_test.cc
+++ b/tensorflow/core/graph/control_flow_test.cc
@@ -63,6 +63,15 @@ TEST(ValidateControlFlowTest, InputsFromDifferentFrames) {
EXPECT_TRUE(str_util::StrContains(status.error_message(),
"has inputs from different frames"))
<< status.error_message();
+ EXPECT_TRUE(str_util::StrContains(status.error_message(),
+ "{{node outer/body/inner/Merge}}"))
+ << status.error_message();
+ EXPECT_TRUE(str_util::StrContains(status.error_message(),
+ "{{node outer/body/inner/Enter}}"))
+ << status.error_message();
+ EXPECT_TRUE(
+ str_util::StrContains(status.error_message(), "{{node outer/Switch}}"))
+ << status.error_message();
}
TEST(ValidateControlFlowTest, MismatchedParentFrames) {
@@ -102,6 +111,8 @@ TEST(ValidateControlFlowTest, MismatchedParentFrames) {
EXPECT_TRUE(
str_util::StrContains(status.error_message(), "Mismatched parent frames"))
<< status.error_message();
+ EXPECT_TRUE(str_util::StrContains(status.error_message(), "{{node Enter2}}"))
+ << status.error_message();
}
TEST(ValidateControlFlowTest, TwoLoopCond) {
@@ -125,6 +136,12 @@ TEST(ValidateControlFlowTest, TwoLoopCond) {
EXPECT_TRUE(str_util::StrContains(status.error_message(),
"more than one LoopCond node"))
<< status.error_message();
+ EXPECT_TRUE(
+ str_util::StrContains(status.error_message(), "{{node sub/LoopCond}}"))
+ << status.error_message();
+ EXPECT_TRUE(
+ str_util::StrContains(status.error_message(), "{{node LoopCond}}"))
+ << status.error_message();
}
} // namespace
diff --git a/tensorflow/core/graph/mkl_graph_util.h b/tensorflow/core/graph/mkl_graph_util.h
index 5f51d6083b..333bf761b0 100644
--- a/tensorflow/core/graph/mkl_graph_util.h
+++ b/tensorflow/core/graph/mkl_graph_util.h
@@ -17,7 +17,6 @@ limitations under the License.
#define TENSORFLOW_CORE_GRAPH_MKL_GRAPH_UTIL_H_
#ifdef INTEL_MKL
-#include <string>
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index b9667998d6..c22e0a3872 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -22,7 +22,6 @@ limitations under the License.
#include <memory>
#include <queue>
#include <set>
-#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
@@ -2495,13 +2494,13 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
CopyAttrsLRN, LrnRewrite});
rinfo_.push_back({csinfo_.lrn_grad,
mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),
- CopyAttrsLRN, LrnRewrite});
+ CopyAttrsLRN, LrnGradRewrite});
rinfo_.push_back({csinfo_.max_pool,
mkl_op_registry::GetMklOpName(csinfo_.max_pool),
CopyAttrsPooling, NonDepthBatchWisePoolRewrite});
rinfo_.push_back({csinfo_.max_pool_grad,
mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad),
- CopyAttrsPooling, AlwaysRewrite});
+ CopyAttrsPooling, MaxpoolGradRewrite});
rinfo_.push_back({csinfo_.maximum,
mkl_op_registry::GetMklOpName(csinfo_.maximum),
@@ -2887,6 +2886,41 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
return false;
}
+ static bool LrnGradRewrite(const Node* n) {
+ CHECK_NOTNULL(n);
+ bool do_rewrite = false;
+
+ for (const Edge* e : n->in_edges()) {
+ // Rewrite only if there is corresponding LRN, i.e workspace is available
+ if (e->dst()->type_string() == csinfo_.lrn_grad && e->dst_input() == 2 &&
+ e->src()->type_string() ==
+ mkl_op_registry::GetMklOpName(csinfo_.lrn) &&
+ e->src_output() == 0) {
+ do_rewrite = true;
+ break;
+ }
+ }
+ return do_rewrite;
+ }
+
+ static bool MaxpoolGradRewrite(const Node* n) {
+ CHECK_NOTNULL(n);
+ bool do_rewrite = false;
+ for (const Edge* e : n->in_edges()) {
+ // Rewrite only if there is corresponding Maxpool, i.e workspace is
+ // available
+ if (e->dst()->type_string() == csinfo_.max_pool_grad &&
+ e->dst_input() == 1 &&
+ e->src()->type_string() ==
+ mkl_op_registry::GetMklOpName(csinfo_.max_pool) &&
+ e->src_output() == 0) {
+ do_rewrite = true;
+ break;
+ }
+ }
+ return do_rewrite;
+ }
+
static bool AddNRewrite(const Node* n) {
CHECK_NOTNULL(n);
@@ -3421,44 +3455,9 @@ Status MklLayoutRewritePass::SetUpInputs(
// TODO(nhasabni) We should move this to mkl_util.h.
void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
std::unique_ptr<Graph>* g, Node** out, Node* orig_node) {
- // We use a tensor of shape {1} and value 0 to represent
- // dummy float tensor. We need this as a dummy workspace tensor.
- // Workspace tensor has type uint8.
- const DataType dt = DataTypeToEnum<uint8>::v();
- TensorProto proto;
- proto.set_dtype(dt);
- float zero[1] = {0};
- proto.set_tensor_content(string(reinterpret_cast<char*>(&zero), 4));
- TensorShape dummy_shape({1});
- dummy_shape.AsProto(proto.mutable_tensor_shape());
- TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
- .Attr("value", proto)
- .Attr("dtype", dt)
- .Device(orig_node->def().device()) // We place this node on
- // same the device as the
- // device of the original
- // node.
- .Finalize(&**g, out));
-
- // If number of inputs to the original node is > 0, then we add
- // control dependency between 1st input (index 0) of the original node and
- // the dummy Mkl node. This is needed because control-flow ops such as Enter,
- // Merge, etc, require frame_name of the dummy Mkl node to be same as the
- // rewritten node. Adding control edge between 1st input of the original node
- // and the dummy Mkl node ensures that the dummy node is in the same frame
- // as the original node. Choosing 1st input is not necessary - any input of
- // the original node is fine because all the inputs of a node are always in
- // the same frame.
- if (orig_node->num_inputs() > 0) {
- Node* orig_input0 = nullptr;
- TF_CHECK_OK(
- orig_node->input_node(0, const_cast<const Node**>(&orig_input0)));
- // Allow duplicate while adding control edge as it would fail (return
- // NULL) if we try to add duplicate edge.
- CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out, true));
- }
-
- (*out)->set_assigned_device_name(orig_node->assigned_device_name());
+ // We use uint8 tensor of shape 8 with content {0,0,0,0,0,0,0,0} to represent
+ // workspace tensor.
+ GetDummyMklTensorNode(g, out, orig_node);
}
void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc
index fc474c0dc8..a41f5861af 100644
--- a/tensorflow/core/graph/mkl_layout_pass_test.cc
+++ b/tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -19,7 +19,6 @@ limitations under the License.
#include "tensorflow/core/graph/mkl_graph_util.h"
#include <algorithm>
-#include <string>
#include <vector>
#include "tensorflow/core/framework/op.h"
@@ -3015,12 +3014,8 @@ TEST_F(MklLayoutPassTest, LRN_Negative2) {
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(Input);D(_MklLRNGrad);DMT/_0(Const);"
- "DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Zeta)|"
- "A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
- "A:control->DMT/_2:control;A:control->DMT/_3:control;"
- "A:control->DMT/_4:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;"
- "DMT/_1->D:7;DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");
+ "A(Input);B(Input);C(Input);D(LRNGrad);"
+ "E(Zeta)|A->D;A->E;B->D:1;C->D:2;D->E:1");
}
/* Test LRN->LRNGrad negative case, where single LRN feeds
@@ -3058,15 +3053,11 @@ TEST_F(MklLayoutPassTest, LRN_Negative3) {
" input: ['E', 'F'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
- "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);DMT/_5(Const);"
- "DMT/_6(Const);E(_MklLRNGrad);F(_MklLRNGrad);G(Zeta)|A->B;"
- "A:control->DMT/_0:control;B->E:2;"
- "B->F:1;B:1->E:3;B:2->E:6;B:2->F:5;B:3->E:7;C->E;C->F;"
- "C:control->DMT/_1:control;C:control->DMT/_2:control;"
- "C:control->DMT/_3:control;C:control->DMT/_4:control;"
- "C:control->DMT/_5:control;C:control->DMT/_6:control;"
- "D->E:1;D->F:2;DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:5;DMT/_3->F:3;"
- "DMT/_4->F:7;DMT/_5->F:4;DMT/_6->F:6;E->G;F->G:1");
+ "DMT/_2(Const);E(_MklLRNGrad);F(LRNGrad);G(Zeta)|A->B;"
+ "A:control->DMT/_0:control;B->E:2;B->F:1;B:1->E:3;B:2->E:6;"
+ "B:3->E:7;C->E;C->F;C:control->DMT/_1:control;"
+ "C:control->DMT/_2:control;D->E:1;D->F:2;DMT/_0->B:1;"
+ "DMT/_1->E:4;DMT/_2->E:5;E->G;F->G:1");
}
/* Test MaxPool->MaxPoolGrad replacement by workspace+rewrite nodes. */
@@ -3137,12 +3128,8 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative2) {
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(Input);D(_MklMaxPoolGrad);DMT/_0(Const);"
- "DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Zeta)|"
- "A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
- "A:control->DMT/_2:control;A:control->DMT/_3:control;"
- "A:control->DMT/_4:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;"
- "DMT/_1->D:7;DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");
+ "A(Input);B(Input);C(Input);D(MaxPoolGrad);"
+ "E(Zeta)|A->D;A->E;B->D:1;C->D:2;D->E:1");
}
// Test MaxPool handling for batch-wise pooling (NCHW)
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc
index e9ced4d2b6..aa39af637f 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc
@@ -18,7 +18,6 @@ limitations under the License.
#include <memory>
#include <queue>
#include <set>
-#include <string>
#include <utility>
#include <vector>
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
index bbdbe78bbd..ebcb6de551 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
@@ -19,7 +19,6 @@ limitations under the License.
#include "tensorflow/core/graph/mkl_graph_util.h"
#include <algorithm>
-#include <string>
#include <vector>
#include "tensorflow/core/framework/op.h"
diff --git a/tensorflow/core/grappler/clusters/cluster.cc b/tensorflow/core/grappler/clusters/cluster.cc
index 6d84283e68..6ca379323e 100644
--- a/tensorflow/core/grappler/clusters/cluster.cc
+++ b/tensorflow/core/grappler/clusters/cluster.cc
@@ -42,6 +42,11 @@ void Cluster::SetNumWarmupSteps(int num_steps) {
num_steps);
}
+// Set executor type to instantiate
+void Cluster::SetExecutorType(const string* executor_type) {
+ options_.config.mutable_experimental()->set_executor_type(*executor_type);
+}
+
int Cluster::NumWarmupSteps() const {
return options_.config.graph_options().build_cost_model_after();
}
diff --git a/tensorflow/core/grappler/clusters/cluster.h b/tensorflow/core/grappler/clusters/cluster.h
index e94fb900c0..519d5ed875 100644
--- a/tensorflow/core/grappler/clusters/cluster.h
+++ b/tensorflow/core/grappler/clusters/cluster.h
@@ -72,6 +72,9 @@ class Cluster {
// before Provision().
void SetNumWarmupSteps(int num_steps);
+ // Set executor type to instantiate
+ void SetExecutorType(const string* executor_type);
+
// Returns the number of warmup steps.
int NumWarmupSteps() const;
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 2cb54bd973..65a7f8ccf3 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -124,6 +124,7 @@ tf_kernel_library(
":bounds_check",
":dense_update_functor",
":ops_util",
+ ":training_op_helpers",
":variable_ops",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
diff --git a/tensorflow/core/kernels/cwise_op_tan.cc b/tensorflow/core/kernels/cwise_op_tan.cc
index c1a25767d3..90762fb1b0 100644
--- a/tensorflow/core/kernels/cwise_op_tan.cc
+++ b/tensorflow/core/kernels/cwise_op_tan.cc
@@ -16,7 +16,8 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER2(UnaryOp, CPU, "Tan", functor::tan, float, double);
+REGISTER4(UnaryOp, CPU, "Tan", functor::tan, float, double, complex64,
+ complex128);
#if GOOGLE_CUDA
REGISTER2(UnaryOp, GPU, "Tan", functor::tan, float, double);
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc
index f99dd643f7..d89f1592bd 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/fused_batch_norm_op.cc
@@ -45,6 +45,24 @@ struct FusedBatchNorm;
template <typename Device, typename T, typename U>
struct FusedBatchNormGrad;
+template <bool IsSame, typename Y, typename X, typename T>
+struct CastIfNecessary {
+ static inline void process(
+ Y& y, X& x_shifted, const Eigen::DSizes<Eigen::Index, 2>& rest_by_depth,
+ const CPUDevice& d) {
+ y.reshape(rest_by_depth).device(d) = x_shifted.template cast<T>();
+ }
+};
+
+template <typename Y, typename X, typename T>
+struct CastIfNecessary<true, Y, X, T> {
+ static inline void process(
+ Y& y, X& x_shifted, const Eigen::DSizes<Eigen::Index, 2>& rest_by_depth,
+ const CPUDevice& d) {
+ y.reshape(rest_by_depth).device(d) = x_shifted;
+ }
+};
+
template <typename T, typename U>
struct FusedBatchNorm<CPUDevice, T, U> {
void operator()(OpKernelContext* context, const Tensor& x_input,
@@ -125,7 +143,11 @@ struct FusedBatchNorm<CPUDevice, T, U> {
auto x_shifted =
x_scaled + offset.reshape(one_by_depth).broadcast(bcast_spec);
- y.reshape(rest_by_depth).device(d) = x_shifted.template cast<T>();
+ // Explicitly checks the types of T and U and only casts x_shifted when
+ // T != U. (Not doing so caused a 35-50% performance slowdown for
+ // some compiler flags.)
+ CastIfNecessary<std::is_same<T, U>::value, decltype(y), decltype(x_shifted),
+ T>::process(y, x_shifted, rest_by_depth, d);
}
};
diff --git a/tensorflow/core/kernels/mkl_concat_op.cc b/tensorflow/core/kernels/mkl_concat_op.cc
index 6f490cdc23..d8efb1be3e 100644
--- a/tensorflow/core/kernels/mkl_concat_op.cc
+++ b/tensorflow/core/kernels/mkl_concat_op.cc
@@ -308,11 +308,9 @@ class MklConcatOp : public OpKernel {
}
if (invoke_eigen) {
- string msg = std::string("Invoking Eigen version of Concat. Reason:") +
- (!is_concat_dim_channel
- ? std::string("Concat dimension is not channel")
- : std::string("Not all tensors are in Mkl layout"));
- VLOG(1) << "_MklConcatOp: " << msg;
+ VLOG(1) << "_MklConcatOp: Invoking Eigen version of Concat. Reason:"
+ << (!is_concat_dim_channel ? "Concat dimension is not channel"
+ : "Not all tensors are in Mkl layout");
CallEigenVersion(context, input_tensors, input_shapes);
return;
}
diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
index a370037d97..b73a119a88 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
@@ -328,9 +328,8 @@ class MklConv2DBwdFilterPrimitiveFactory : public MklPrimitiveFactory<T> {
return instance_;
}
- static std::string CreateKey(
- const MklConvBwdFilterParams& convBwdFilterDims) {
- std::string prefix = "conv2d_bwd_filter";
+ static string CreateKey(const MklConvBwdFilterParams& convBwdFilterDims) {
+ string prefix = "conv2d_bwd_filter";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
key_creator.AddAsKey(convBwdFilterDims.src_dims);
@@ -346,13 +345,13 @@ class MklConv2DBwdFilterPrimitiveFactory : public MklPrimitiveFactory<T> {
MklPrimitive* GetConv2dBwdFilter(
const MklConvBwdFilterParams& convBwdFilterDims) {
- std::string key = CreateKey(convBwdFilterDims);
+ string key = CreateKey(convBwdFilterDims);
return this->GetOp(key);
}
void SetConv2dBwdFilter(
const MklConvBwdFilterParams& convBwdFilterDims, MklPrimitive* op) {
- std::string key = CreateKey(convBwdFilterDims);
+ string key = CreateKey(convBwdFilterDims);
this->SetOp(key, op);
}
};
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
index b0f7faaa1a..39498f1a80 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
@@ -265,9 +265,8 @@ class MklConv2DBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> {
return instance_;
}
- static std::string CreateKey(
- const MklConvBwdInputParams& convBwdInputDims) {
- std::string prefix = "conv2d_bwd_input";
+ static string CreateKey(const MklConvBwdInputParams& convBwdInputDims) {
+ string prefix = "conv2d_bwd_input";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
key_creator.AddAsKey(convBwdInputDims.diff_src_dims);
@@ -282,13 +281,13 @@ class MklConv2DBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> {
MklPrimitive* GetConv2dBwdInput(
const MklConvBwdInputParams& convBwdInputDims) {
- std::string key = CreateKey(convBwdInputDims);
+ string key = CreateKey(convBwdInputDims);
return this->GetOp(key);
}
void SetConv2dBwdInput(
const MklConvBwdInputParams& convBwdInputDims, MklPrimitive *op) {
- std::string key = CreateKey(convBwdInputDims);
+ string key = CreateKey(convBwdInputDims);
this->SetOp(key, op);
}
};
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index b568973220..62396eeb8b 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -18,7 +18,6 @@ limitations under the License.
#include <string.h>
#include <map>
-#include <string>
#include <vector>
#include <memory>
@@ -35,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/padding.h"
@@ -298,8 +298,8 @@ class MklConv2DFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
return instance_;
}
- static std::string CreateKey(const MklConvFwdParams& convFwdDims) {
- std::string prefix = "conv2d_fwd_";
+ static string CreateKey(const MklConvFwdParams& convFwdDims) {
+ string prefix = "conv2d_fwd_";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
key_creator.AddAsKey(convFwdDims.src_dims);
@@ -314,12 +314,12 @@ class MklConv2DFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
}
MklPrimitive* GetConv2DFwd(const MklConvFwdParams& convFwdDims) {
- std::string key = CreateKey(convFwdDims);
+ string key = CreateKey(convFwdDims);
return this->GetOp(key);
}
void SetConv2DFwd(const MklConvFwdParams& convFwdDims, MklPrimitive* op) {
- std::string key = CreateKey(convFwdDims);
+ string key = CreateKey(convFwdDims);
this->SetOp(key, op);
}
};
@@ -930,10 +930,9 @@ class MklConv2DOp : public OpKernel {
conv2d_fwd->Execute(src_data, filter_data, dst_data);
}
} catch (mkldnn::error &e) {
- string error_msg = "Status: " + std::to_string(e.status) +
- ", message: " + std::string(e.message) +
- ", in file " + std::string(__FILE__) + ":" +
- std::to_string(__LINE__);
+ string error_msg = tensorflow::strings::StrCat(
+ "Status: ", e.status, ", message: ", string(e.message), ", in file ",
+ __FILE__, ":", __LINE__);
OP_REQUIRES_OK(context,
errors::Aborted("Operation received an exception:", error_msg));
}
diff --git a/tensorflow/core/kernels/mkl_conv_ops.h b/tensorflow/core/kernels/mkl_conv_ops.h
index 5e1a5001dc..3f154ff33b 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.h
+++ b/tensorflow/core/kernels/mkl_conv_ops.h
@@ -17,7 +17,6 @@ limitations under the License.
#define TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_
#include <limits>
-#include <string>
#include <vector>
#include <memory>
diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.h b/tensorflow/core/kernels/mkl_pooling_ops_common.h
index c0dfed7d7d..cb1eecb36a 100644
--- a/tensorflow/core/kernels/mkl_pooling_ops_common.h
+++ b/tensorflow/core/kernels/mkl_pooling_ops_common.h
@@ -17,7 +17,6 @@ limitations under the License.
#define TENSORFLOW_CORE_KERNELS_MKL_POOLING_OPS_COMMON_H_
#ifdef INTEL_MKL
-#include <string>
#include <vector>
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/padding.h"
diff --git a/tensorflow/core/kernels/mkl_reshape_op.cc b/tensorflow/core/kernels/mkl_reshape_op.cc
index 02ea9fc068..9c536df215 100644
--- a/tensorflow/core/kernels/mkl_reshape_op.cc
+++ b/tensorflow/core/kernels/mkl_reshape_op.cc
@@ -152,8 +152,12 @@ class MklReshapeOp : public OpKernel {
// If Tensorflow's data format and the underlying format maintained by
// MKLDNN are equivalent (both are NHWC or both are NCHW), then we can
// safely return true.
+ // @todo: Future do not force skip reorder for all blocked format. Use
+ // blocking_desc_is_equal() for checking all the stride arrays in
+ // mkl-dnn/blob/master/src/common/type_helpers.hpp
auto input_mkl_md = mkl_shape_input.GetMklLayout();
- if (mkl_shape_input.GetTfDataFormat() == input_mkl_md.data.format) {
+ if (mkl_shape_input.GetTfDataFormat() == input_mkl_md.data.format &&
+ mkl_shape_input.GetTfDataFormat() != memory::format::blocked) {
ret = true;
}
diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc
index f59843a07a..c7d0d4de0d 100644
--- a/tensorflow/core/kernels/non_max_suppression_op.cc
+++ b/tensorflow/core/kernels/non_max_suppression_op.cc
@@ -121,10 +121,11 @@ static inline std::function<bool(int, int)> CreateOverlapsSuppressCheckFn(
std::placeholders::_1, std::placeholders::_2, threshold);
}
-void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& scores,
- int num_boxes, const Tensor& max_output_size,
- const float score_threshold,
- std::function<bool(int, int)> suppress_check_fn) {
+void DoNonMaxSuppressionOp(
+ OpKernelContext* context, const Tensor& scores, int num_boxes,
+ const Tensor& max_output_size, const float score_threshold,
+ const std::function<bool(int, int)>& suppress_check_fn,
+ bool pad_to_max_output_size = false, int* ptr_num_valid_outputs = nullptr) {
const int output_size = std::min(max_output_size.scalar<int>()(), num_boxes);
std::vector<float> scores_data(num_boxes);
@@ -172,6 +173,15 @@ void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& scores,
}
}
+ int num_valid_outputs = selected.size();
+ if (pad_to_max_output_size) {
+ selected.resize(output_size, 0);
+ selected_scores.resize(output_size, 0);
+ }
+ if (ptr_num_valid_outputs) {
+ *ptr_num_valid_outputs = num_valid_outputs;
+ }
+
// Allocate output tensors
Tensor* output_indices = nullptr;
TensorShape output_shape({static_cast<int>(selected.size())});
@@ -262,54 +272,106 @@ class NonMaxSuppressionV2Op : public OpKernel {
}
};
-template <typename Device>
-class NonMaxSuppressionV3Op : public OpKernel {
+class NonMaxSuppressionV3V4Base : public OpKernel {
public:
- explicit NonMaxSuppressionV3Op(OpKernelConstruction* context)
+ explicit NonMaxSuppressionV3V4Base(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// boxes: [num_boxes, 4]
- const Tensor& boxes = context->input(0);
+ boxes_ = context->input(0);
// scores: [num_boxes]
- const Tensor& scores = context->input(1);
+ scores_ = context->input(1);
// max_output_size: scalar
- const Tensor& max_output_size = context->input(2);
+ max_output_size_ = context->input(2);
OP_REQUIRES(
- context, TensorShapeUtils::IsScalar(max_output_size.shape()),
+ context, TensorShapeUtils::IsScalar(max_output_size_.shape()),
errors::InvalidArgument("max_output_size must be 0-D, got shape ",
- max_output_size.shape().DebugString()));
+ max_output_size_.shape().DebugString()));
// iou_threshold: scalar
const Tensor& iou_threshold = context->input(3);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
iou_threshold.shape().DebugString()));
- const float iou_threshold_val = iou_threshold.scalar<float>()();
-
+ iou_threshold_val_ = iou_threshold.scalar<float>()();
+ OP_REQUIRES(context, iou_threshold_val_ >= 0 && iou_threshold_val_ <= 1,
+ errors::InvalidArgument("iou_threshold must be in [0, 1]"));
// score_threshold: scalar
const Tensor& score_threshold = context->input(4);
OP_REQUIRES(
context, TensorShapeUtils::IsScalar(score_threshold.shape()),
errors::InvalidArgument("score_threshold must be 0-D, got shape ",
score_threshold.shape().DebugString()));
- const float score_threshold_val = score_threshold.scalar<float>()();
+ score_threshold_val_ = score_threshold.scalar<float>()();
- OP_REQUIRES(context, iou_threshold_val >= 0 && iou_threshold_val <= 1,
- errors::InvalidArgument("iou_threshold must be in [0, 1]"));
- int num_boxes = 0;
- ParseAndCheckBoxSizes(context, boxes, &num_boxes);
- CheckScoreSizes(context, num_boxes, scores);
+ num_boxes_ = 0;
+ ParseAndCheckBoxSizes(context, boxes_, &num_boxes_);
+ CheckScoreSizes(context, num_boxes_, scores_);
if (!context->status().ok()) {
return;
}
- auto suppress_check_fn = CreateIOUSuppressCheckFn(boxes, iou_threshold_val);
- DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size,
- score_threshold_val, suppress_check_fn);
+ DoComputeAndPostProcess(context);
+ }
+
+ protected:
+ virtual void DoComputeAndPostProcess(OpKernelContext* context) = 0;
+
+ Tensor boxes_;
+ Tensor scores_;
+ Tensor max_output_size_;
+ int num_boxes_;
+ float iou_threshold_val_;
+ float score_threshold_val_;
+};
+
+template <typename Device>
+class NonMaxSuppressionV3Op : public NonMaxSuppressionV3V4Base {
+ public:
+ explicit NonMaxSuppressionV3Op(OpKernelConstruction* context)
+ : NonMaxSuppressionV3V4Base(context) {}
+
+ protected:
+ void DoComputeAndPostProcess(OpKernelContext* context) override {
+ auto suppress_check_fn =
+ CreateIOUSuppressCheckFn(boxes_, iou_threshold_val_);
+
+ DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_,
+ score_threshold_val_, suppress_check_fn);
}
};
template <typename Device>
+class NonMaxSuppressionV4Op : public NonMaxSuppressionV3V4Base {
+ public:
+ explicit NonMaxSuppressionV4Op(OpKernelConstruction* context)
+ : NonMaxSuppressionV3V4Base(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("pad_to_max_output_size",
+ &pad_to_max_output_size_));
+ }
+
+ protected:
+ void DoComputeAndPostProcess(OpKernelContext* context) override {
+ auto suppress_check_fn =
+ CreateIOUSuppressCheckFn(boxes_, iou_threshold_val_);
+ int num_valid_outputs;
+
+ DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_,
+ score_threshold_val_, suppress_check_fn,
+ pad_to_max_output_size_, &num_valid_outputs);
+
+ // Allocate scalar output tensor for number of indices computed.
+ Tensor* num_outputs_t = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(
+ 1, tensorflow::TensorShape{}, &num_outputs_t));
+ num_outputs_t->scalar<int32>().setConstant(num_valid_outputs);
+ }
+
+ private:
+ bool pad_to_max_output_size_;
+};
+
+template <typename Device>
class NonMaxSuppressionWithOverlapsOp : public OpKernel {
public:
explicit NonMaxSuppressionWithOverlapsOp(OpKernelConstruction* context)
@@ -365,6 +427,9 @@ REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3").Device(DEVICE_CPU),
NonMaxSuppressionV3Op<CPUDevice>);
+REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV4").Device(DEVICE_CPU),
+ NonMaxSuppressionV4Op<CPUDevice>);
+
REGISTER_KERNEL_BUILDER(
Name("NonMaxSuppressionWithOverlaps").Device(DEVICE_CPU),
NonMaxSuppressionWithOverlapsOp<CPUDevice>);
diff --git a/tensorflow/core/kernels/non_max_suppression_op_test.cc b/tensorflow/core/kernels/non_max_suppression_op_test.cc
index 055161a35f..c321849f40 100644
--- a/tensorflow/core/kernels/non_max_suppression_op_test.cc
+++ b/tensorflow/core/kernels/non_max_suppression_op_test.cc
@@ -570,6 +570,61 @@ TEST_F(NonMaxSuppressionV3OpTest, TestEmptyInput) {
}
//
+// NonMaxSuppressionV4Op Tests
+//
+
+class NonMaxSuppressionV4OpTest : public OpsTestBase {
+ protected:
+ void MakeOp() {
+ TF_EXPECT_OK(NodeDefBuilder("non_max_suppression_op", "NonMaxSuppressionV4")
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_INT32))
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Attr("pad_to_max_output_size", true)
+ .Finalize(node_def()));
+ TF_EXPECT_OK(InitOp());
+ }
+};
+
+TEST_F(NonMaxSuppressionV4OpTest, TestSelectFromThreeClustersPadFive) {
+ MakeOp();
+ AddInputFromArray<float>(
+ TensorShape({6, 4}),
+ {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f,
+ 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101});
+ AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f});
+ AddInputFromArray<int>(TensorShape({}), {5});
+ AddInputFromArray<float>(TensorShape({}), {.5f});
+ AddInputFromArray<float>(TensorShape({}), {0.0f});
+ TF_ASSERT_OK(RunOpKernel());
+
+ const auto expected_indices = test::AsTensor<int>({3, 0, 5, 0, 0});
+ test::ExpectTensorEqual<int>(expected_indices, *GetOutput(0));
+ Tensor expected_num_valid = test::AsScalar<int>(3);
+ test::ExpectTensorEqual<int>(expected_num_valid, *GetOutput(1));
+}
+
+TEST_F(NonMaxSuppressionV4OpTest, TestSelectFromThreeClustersPadFiveScoreThr) {
+ MakeOp();
+ AddInputFromArray<float>(
+ TensorShape({6, 4}),
+ {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f,
+ 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101});
+ AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f});
+ AddInputFromArray<int>(TensorShape({}), {6});
+ AddInputFromArray<float>(TensorShape({}), {.5f});
+ AddInputFromArray<float>(TensorShape({}), {0.4f});
+ TF_ASSERT_OK(RunOpKernel());
+
+ const auto expected_indices = test::AsTensor<int>({3, 0, 0, 0, 0, 0});
+ test::ExpectTensorEqual<int>(expected_indices, *GetOutput(0));
+ Tensor expected_num_valid = test::AsScalar<int>(2);
+ test::ExpectTensorEqual<int>(expected_num_valid, *GetOutput(1));
+}
+
+//
// NonMaxSuppressionWithOverlapsOp Tests
//
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index c5292e1ae1..cab9eb729d 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -213,64 +213,32 @@ class AssignVariableOp : public OpKernel {
"Variable and value dtypes don't match; respectively, ",
dtype_, " and ", context->input(1).dtype()));
Var* variable = nullptr;
- OP_REQUIRES_OK(
- context,
- LookupOrCreateResource<Var>(
- context, HandleFromInput(context, 0), &variable,
- [this, context](Var** ptr) {
- *ptr = new Var(dtype_);
- PersistentTensor unused;
- Tensor* tmp;
- AllocatorAttributes attr;
- if (!relax_constraints_) {
- attr.set_gpu_compatible(true);
- attr.set_nic_compatible(true);
- }
- TF_RETURN_IF_ERROR(context->allocate_persistent(
- dtype_, context->input(1).shape(), &unused, &tmp, attr));
- *(*ptr)->tensor() = *tmp;
- return Status::OK();
- }));
+ const Tensor& value = context->input(1);
+ // Note: every resource-variable-manipulating op assumes copy-on-write
+ // semantics, and creates a copy of the variable's Tensor if its refcount is
+ // bigger than 1 when we try to modify it. This means we never need to copy
+ // the original tensor for AssignVariableOp; even if there are other live
+ // users of it we know none can modify it so this is always safe (even in
+ // esoteric cases where the same tensor is used to initialize multiple
+ // variables or the tensor is a constant this is safe, as future writes will
+ // trigger copies).
+ OP_REQUIRES_OK(context, LookupOrCreateResource<Var>(
+ context, HandleFromInput(context, 0), &variable,
+ [this, &value](Var** ptr) {
+ *ptr = new Var(dtype_);
+ *(*ptr)->tensor() = value;
+ (*ptr)->is_initialized = true;
+ return Status::OK();
+ }));
core::ScopedUnref s(variable);
-
OP_REQUIRES(context, variable->tensor()->dtype() == dtype_,
errors::InvalidArgument(
"Trying to assign variable with wrong dtype. Expected ",
DataTypeString(variable->tensor()->dtype()), " got ",
DataTypeString(dtype_)));
-
- const Tensor& value = context->input(1);
- AllocatorAttributes attr;
- if (!relax_constraints_) {
- attr.set_gpu_compatible(true);
- attr.set_nic_compatible(true);
- }
-
- // Copying is unnecessary if we are the last user of the value
- // tensor, we can just adopt the input tensor's buffer instead.
- std::unique_ptr<Tensor> input_alias = context->forward_input(
- 1, OpKernelContext::Params::kNoReservation /*output_index*/, dtype_,
- value.shape(), DEVICE_MEMORY, attr);
mutex_lock ml(*variable->mu());
variable->is_initialized = true;
- if (input_alias) {
- *variable->tensor() = *input_alias;
- return;
- }
-
- // Need to copy, but maybe we can re-use variable's buffer?
- if (!variable->tensor()->RefCountIsOne() ||
- !variable->tensor()->shape().IsSameSize(value.shape())) {
- // Copy to new buffer
- PersistentTensor unused;
- Tensor* tmp;
- OP_REQUIRES_OK(context, context->allocate_persistent(
- dtype_, value.shape(), &unused, &tmp, attr));
- *variable->tensor() = *tmp;
- }
- functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
- copy_functor(context->eigen_device<Device>(), variable->tensor()->flat<T>(),
- value.flat<T>());
+ *variable->tensor() = value;
}
private:
diff --git a/tensorflow/core/kernels/save_restore_tensor.cc b/tensorflow/core/kernels/save_restore_tensor.cc
index 990bd2bff9..7930ce4615 100644
--- a/tensorflow/core/kernels/save_restore_tensor.cc
+++ b/tensorflow/core/kernels/save_restore_tensor.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
@@ -226,43 +227,53 @@ void RestoreTensor(OpKernelContext* context,
#undef READER_COPY
}
-Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix,
- const Tensor& tensor_names,
- const Tensor& shape_and_slices,
- gtl::ArraySlice<DataType> dtypes) {
- const string& prefix_string = prefix.scalar<string>()();
+namespace {
- const auto& tensor_names_flat = tensor_names.flat<string>();
- const auto& shape_and_slices_flat = shape_and_slices.flat<string>();
+// Tensors larger than this threshold will be restored from a thread-pool.
+const int64 kLargeShapeThreshold = 16 << 20; // 16M
- // Sort lookup keys to improve locality when reading multiple tensors.
- std::vector<size_t> sorted_name_idx(tensor_names_flat.size());
- std::iota(sorted_name_idx.begin(), sorted_name_idx.end(), 0);
- std::sort(sorted_name_idx.begin(), sorted_name_idx.end(),
- [&tensor_names_flat](size_t a, size_t b) {
- return tensor_names_flat(a) < tensor_names_flat(b);
- });
+// A restore operation for a single tensor. Small tensors may be restored
+// directly from the op thread to improve read locality. Large tensors can be
+// restored from a thread pool: this requires creating a separate BundleReader
+// for each restore.
+struct RestoreOp {
+ RestoreOp& operator=(const RestoreOp&) = delete;
- BundleReader reader(Env::Default(), prefix_string);
- TF_RETURN_IF_ERROR(reader.status());
+ bool should_run_in_pool(BundleReader* reader) const {
+ TensorShape restored_full_shape;
- // TODO(zongheng): potential optimization: one Seek() in first lookup.
- // TODO(zongheng): consider measuring speed and issuing concurrent lookups
- // within a fixed memory budget.
- TensorShape restored_full_shape;
- Tensor* restored_tensor = nullptr;
- for (auto i : sorted_name_idx) {
- const string& tensor_name = tensor_names_flat(i);
- const string& shape_and_slice = shape_and_slices_flat(i);
+ // Ignore status here; we'll catch the error later.
+ if (!reader->LookupTensorShape(tensor_name, &restored_full_shape).ok()) {
+ return false;
+ }
+ return restored_full_shape.num_elements() > kLargeShapeThreshold;
+ }
+
+ // Run this restore operation using a new BundleReader.
+ void run_with_new_reader() {
+ BundleReader reader(Env::Default(), reader_prefix);
+ if (!reader.status().ok()) {
+ status = reader.status();
+ return;
+ }
+
+ status = run(&reader);
+ }
+
+ Status run(BundleReader* reader) {
+ TensorShape restored_full_shape;
TF_RETURN_IF_ERROR(
- reader.LookupTensorShape(tensor_name, &restored_full_shape));
+ reader->LookupTensorShape(tensor_name, &restored_full_shape));
+ VLOG(1) << "Restoring tensor " << idx << " : " << tensor_name << " : "
+ << restored_full_shape.num_elements();
+ Tensor* restored_tensor;
if (shape_and_slice.empty()) {
// Lookup the full tensor.
TF_RETURN_IF_ERROR(
- context->allocate_output(i, restored_full_shape, &restored_tensor));
- TF_RETURN_IF_ERROR(reader.Lookup(tensor_name, restored_tensor));
+ context->allocate_output(idx, restored_full_shape, &restored_tensor));
+ TF_RETURN_IF_ERROR(reader->Lookup(tensor_name, restored_tensor));
} else {
// Lookup the slice.
TensorShape parsed_full_shape;
@@ -272,6 +283,7 @@ Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix,
TF_RETURN_IF_ERROR(
checkpoint::ParseShapeAndSlice(shape_and_slice, &parsed_full_shape,
&parsed_slice, &parsed_slice_shape));
+
if (!restored_full_shape.IsSameSize(parsed_full_shape)) {
return errors::InvalidArgument(
"tensor_name = ", tensor_name, "; shape in shape_and_slice spec ",
@@ -279,19 +291,93 @@ Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix,
" does not match the shape stored in checkpoint: ",
restored_full_shape.DebugString());
}
-
TF_RETURN_IF_ERROR(
- context->allocate_output(i, parsed_slice_shape, &restored_tensor));
+ context->allocate_output(idx, parsed_slice_shape, &restored_tensor));
TF_RETURN_IF_ERROR(
- reader.LookupSlice(tensor_name, parsed_slice, restored_tensor));
+ reader->LookupSlice(tensor_name, parsed_slice, restored_tensor));
+ }
+ return Status::OK();
+ }
+
+ OpKernelContext* context;
+ size_t idx;
+ string tensor_name;
+ string shape_and_slice;
+ string reader_prefix;
+
+ ::tensorflow::Status status;
+};
+
+} // namespace
+
+Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix,
+ const Tensor& tensor_names,
+ const Tensor& shape_and_slices,
+ gtl::ArraySlice<DataType> dtypes) {
+ const string& prefix_string = prefix.scalar<string>()();
+
+ const auto& tensor_names_flat = tensor_names.flat<string>();
+ const auto& shape_and_slices_flat = shape_and_slices.flat<string>();
+
+ // Sort lookup keys to improve locality when reading multiple tensors.
+ std::vector<size_t> sorted_name_idx(tensor_names_flat.size());
+ std::iota(sorted_name_idx.begin(), sorted_name_idx.end(), 0);
+ std::sort(sorted_name_idx.begin(), sorted_name_idx.end(),
+ [&tensor_names_flat](size_t a, size_t b) {
+ return tensor_names_flat(a) < tensor_names_flat(b);
+ });
+
+ std::vector<std::unique_ptr<RestoreOp> > pool_restore_ops;
+ std::vector<std::unique_ptr<RestoreOp> > direct_restore_ops;
+
+ BundleReader default_reader(Env::Default(), prefix_string);
+ TF_RETURN_IF_ERROR(default_reader.status());
+
+ for (auto i : sorted_name_idx) {
+ const string& tensor_name = tensor_names_flat(i);
+ const string& shape_and_slice = shape_and_slices_flat(i);
+ auto op =
+ new RestoreOp{context, i, tensor_name, shape_and_slice, prefix_string};
+ if (op->should_run_in_pool(&default_reader)) {
+ pool_restore_ops.emplace_back(op);
+ } else {
+ direct_restore_ops.emplace_back(op);
}
- if (dtypes[i] != restored_tensor->dtype()) {
+ }
+
+ {
+ // Schedule any threaded operations first, skipping thread pool creation if
+ // we don't have any expensive operations.
+ std::unique_ptr<thread::ThreadPool> reader_pool;
+ if (!pool_restore_ops.empty()) {
+ reader_pool.reset(
+ new thread::ThreadPool(Env::Default(), "restore_tensors", 8));
+ for (auto& op : pool_restore_ops) {
+ reader_pool->Schedule([&op]() { op->run_with_new_reader(); });
+ }
+ }
+
+ // Read small tensors from the op thread
+ for (auto& op : direct_restore_ops) {
+ TF_RETURN_IF_ERROR(op->run(&default_reader));
+ }
+ }
+
+ // Check status of pool ops; this must come after the pool shuts down.
+ for (auto& op : pool_restore_ops) {
+ TF_RETURN_IF_ERROR(op->status);
+ }
+
+ for (auto i : sorted_name_idx) {
+ const string& tensor_name = tensor_names_flat(i);
+ if (dtypes[i] != context->mutable_output(i)->dtype()) {
return errors::InvalidArgument(
"tensor_name = ", tensor_name, "; expected dtype ",
DataTypeString(dtypes[i]), " does not equal restored dtype ",
- DataTypeString(restored_tensor->dtype()));
+ DataTypeString(context->mutable_output(i)->dtype()));
}
}
+
return Status::OK();
}
diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc
index 1e3e92a68a..59fdc2262a 100644
--- a/tensorflow/core/kernels/strided_slice_op.cc
+++ b/tensorflow/core/kernels/strided_slice_op.cc
@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@@ -304,6 +305,9 @@ class StridedSliceAssignOp : public OpKernel {
Var* v;
OP_REQUIRES_OK(context,
LookupResource(context, HandleFromInput(context, 0), &v));
+ mutex_lock ml(*v->mu());
+ OP_REQUIRES_OK(context,
+ PrepareToUpdateVariable<Device, T>(context, v->tensor()));
old_lhs = *v->tensor();
OP_REQUIRES(context, old_lhs.dtype() == DataTypeToEnum<T>::value,
errors::InvalidArgument(
diff --git a/tensorflow/core/kernels/training_op_helpers.cc b/tensorflow/core/kernels/training_op_helpers.cc
index f288e124ee..d3c4f62071 100644
--- a/tensorflow/core/kernels/training_op_helpers.cc
+++ b/tensorflow/core/kernels/training_op_helpers.cc
@@ -39,8 +39,15 @@ mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input) {
// GetInputTensor which will signal a failure.
std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder(
OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids) {
+ bool any_resource = false;
+ for (auto i : input_ids) {
+ if (ctx->input_dtype(i) == DT_RESOURCE) {
+ any_resource = true;
+ break;
+ }
+ }
std::vector<mutex_lock> locks;
- if (!do_lock) {
+ if (!do_lock && !any_resource) {
return locks;
}
std::vector<mutex*> mutexes;
diff --git a/tensorflow/core/kernels/training_op_helpers.h b/tensorflow/core/kernels/training_op_helpers.h
index 7e56e15450..765335d3a0 100644
--- a/tensorflow/core/kernels/training_op_helpers.h
+++ b/tensorflow/core/kernels/training_op_helpers.h
@@ -80,18 +80,8 @@ Status GetInputTensorFromVariable(OpKernelContext* ctx, int input,
Var* var;
TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, input), &var));
core::ScopedUnref unref_var(var);
- if (lock_held) {
- TF_RETURN_IF_ERROR(
- PrepareToUpdateVariable<Device, T>(ctx, var->tensor()));
- *out = *var->tensor();
- } else {
- mutex_lock ml(*var->mu());
- if (!sparse) {
- TF_RETURN_IF_ERROR(
- PrepareToUpdateVariable<Device, T>(ctx, var->tensor()));
- }
- *out = *var->tensor();
- }
+ TF_RETURN_IF_ERROR(PrepareToUpdateVariable<Device, T>(ctx, var->tensor()));
+ *out = *var->tensor();
return Status::OK();
}
*out = ctx->mutable_input(input, lock_held);
diff --git a/tensorflow/core/lib/core/errors.h b/tensorflow/core/lib/core/errors.h
index 51c09032df..a631d9815a 100644
--- a/tensorflow/core/lib/core/errors.h
+++ b/tensorflow/core/lib/core/errors.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <sstream>
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@@ -118,6 +119,25 @@ DECLARE_ERROR(Unauthenticated, UNAUTHENTICATED)
#undef DECLARE_ERROR
+// Produces a formatted string pattern from the name which can uniquely identify
+// this node upstream to produce an informative error message. The pattern
+// followed is: {{node <name>}}
+// Note: The pattern below determines the regex _NODEDEF_NAME_RE in the file
+// tensorflow/python/client/session.py
+// LINT.IfChange
+inline string FormatNodeNameForError(const string& name) {
+ return strings::StrCat("{{node ", name, "}}");
+}
+// LINT.ThenChange(//tensorflow/python/client/session.py)
+template <typename T>
+string FormatNodeNamesForError(const T& names) {
+ ::tensorflow::str_util::Formatter<string> f(
+ [](string* output, const string& s) {
+ ::tensorflow::strings::StrAppend(output, FormatNodeNameForError(s));
+ });
+ return ::tensorflow::str_util::Join(names, ", ", f);
+}
+
// The CanonicalCode() for non-errors.
using ::tensorflow::error::OK;
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 4ac8e15160..267af8b976 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -35470,6 +35470,44 @@ op {
}
}
op {
+ name: "NonMaxSuppressionV4"
+ input_arg {
+ name: "boxes"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "scores"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "max_output_size"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "iou_threshold"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "score_threshold"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "selected_indices"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "valid_outputs"
+ type: DT_INT32
+ }
+ attr {
+ name: "pad_to_max_output_size"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+}
+op {
name: "NonMaxSuppressionWithOverlaps"
input_arg {
name: "overlaps"
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
index 50ced1ff73..31267f72b8 100644
--- a/tensorflow/core/ops/image_ops.cc
+++ b/tensorflow/core/ops/image_ops.cc
@@ -442,8 +442,9 @@ REGISTER_OP("DrawBoundingBoxes")
if (c->ValueKnown(c->Dim(images, 3))) {
int64 depth = c->Value(c->Dim(images, 3));
if (!(depth == 1 || depth == 3 || depth == 4)) {
- return errors::InvalidArgument("Channel depth should be either 1 (GRY), "
- "3 (RGB), or 4 (RGBA)");
+ return errors::InvalidArgument(
+ "Channel depth should be either 1 (GRY), "
+ "3 (RGB), or 4 (RGBA)");
}
}
@@ -709,6 +710,40 @@ REGISTER_OP("NonMaxSuppressionV3")
return Status::OK();
});
+REGISTER_OP("NonMaxSuppressionV4")
+ .Input("boxes: float")
+ .Input("scores: float")
+ .Input("max_output_size: int32")
+ .Input("iou_threshold: float")
+ .Input("score_threshold: float")
+ .Output("selected_indices: int32")
+ .Output("valid_outputs: int32")
+ .Attr("pad_to_max_output_size: bool = false")
+ .SetShapeFn([](InferenceContext* c) {
+ // Get inputs and validate ranks.
+ ShapeHandle boxes;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
+ ShapeHandle scores;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
+ ShapeHandle max_output_size;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
+ ShapeHandle iou_threshold;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold));
+ ShapeHandle score_threshold;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &score_threshold));
+ // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
+ DimensionHandle unused;
+ // The boxes[0] and scores[0] are both num_boxes.
+ TF_RETURN_IF_ERROR(
+ c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
+ // The boxes[1] is 4.
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
+
+ c->set_output(0, c->Vector(c->UnknownDim()));
+ c->set_output(1, c->MakeShape({}));
+ return Status::OK();
+ });
+
REGISTER_OP("NonMaxSuppressionWithOverlaps")
.Input("overlaps: float")
.Input("scores: float")
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 77697756c4..1667c398f4 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -122,6 +122,7 @@ REGISTER_OP("_HostCast")
.Output("y: DstT")
.Attr("SrcT: type")
.Attr("DstT: type")
+ .Attr("Truncate: bool = false")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Cast x of type SrcT to y of DstT.
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 22a2f423c2..7973be88e0 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -17007,6 +17007,44 @@ op {
}
}
op {
+ name: "NonMaxSuppressionV4"
+ input_arg {
+ name: "boxes"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "scores"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "max_output_size"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "iou_threshold"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "score_threshold"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "selected_indices"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "valid_outputs"
+ type: DT_INT32
+ }
+ attr {
+ name: "pad_to_max_output_size"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+}
+op {
name: "NonMaxSuppressionWithOverlaps"
input_arg {
name: "overlaps"
diff --git a/tensorflow/core/platform/default/build_config_root.bzl b/tensorflow/core/platform/default/build_config_root.bzl
index 09029a4b25..3a012c23fd 100644
--- a/tensorflow/core/platform/default/build_config_root.bzl
+++ b/tensorflow/core/platform/default/build_config_root.bzl
@@ -58,3 +58,9 @@ def if_static(extra_deps, otherwise=[]):
str(Label("//tensorflow:framework_shared_object")): otherwise,
"//conditions:default": extra_deps,
})
+
+def if_dynamic_kernels(extra_deps, otherwise=[]):
+ return select({
+ str(Label("//tensorflow:dynamic_loaded_kernels")): extra_deps,
+ "//conditions:default": otherwise,
+ })
diff --git a/tensorflow/core/platform/env_test.cc b/tensorflow/core/platform/env_test.cc
index c461a40086..305a9a682f 100644
--- a/tensorflow/core/platform/env_test.cc
+++ b/tensorflow/core/platform/env_test.cc
@@ -86,7 +86,7 @@ TEST_F(DefaultEnvTest, IncompleteReadOutOfRange) {
TEST_F(DefaultEnvTest, ReadFileToString) {
for (const int length : {0, 1, 1212, 2553, 4928, 8196, 9000, (1 << 20) - 1,
- 1 << 20, (1 << 20) + 1}) {
+ 1 << 20, (1 << 20) + 1, (256 << 20) + 100}) {
const string filename = strings::StrCat(BaseDir(), "/bar/..//file", length);
// Write a file with the given length
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index d701ce8e12..da3a99565e 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -393,6 +393,10 @@ message ConfigProto {
// Whether the client will format templated errors. For example, the string:
// "The node was defined on ^^node:Foo:${file}:${line}^^".
bool client_handles_error_formatting = 2;
+
+ // Which executor to use, the default executor will be used
+ // if it is an empty string or "DEFAULT"
+ string executor_type = 3;
};
Experimental experimental = 16;
diff --git a/tensorflow/core/util/equal_graph_def_test.cc b/tensorflow/core/util/equal_graph_def_test.cc
index c54540332e..77ca8eaec3 100644
--- a/tensorflow/core/util/equal_graph_def_test.cc
+++ b/tensorflow/core/util/equal_graph_def_test.cc
@@ -85,7 +85,7 @@ TEST_F(EqualGraphDefTest, NoMatch) {
Input(e_.opts().WithName("A"));
Input(a_.opts().WithName("B"));
EXPECT_FALSE(Match());
- EXPECT_EQ("Did not find expected node 'A = Input[]()'", diff_);
+ EXPECT_EQ("Did not find expected node '{{node A}} = Input[]()'", diff_);
}
TEST_F(EqualGraphDefTest, MissingNode) {
@@ -93,7 +93,7 @@ TEST_F(EqualGraphDefTest, MissingNode) {
Input(e_.opts().WithName("B"));
Input(a_.opts().WithName("A"));
EXPECT_FALSE(Match());
- EXPECT_EQ("Did not find expected node 'B = Input[]()'", diff_);
+ EXPECT_EQ("Did not find expected node '{{node B}} = Input[]()'", diff_);
}
TEST_F(EqualGraphDefTest, ExtraNode) {
@@ -101,7 +101,7 @@ TEST_F(EqualGraphDefTest, ExtraNode) {
Input(a_.opts().WithName("A"));
Input(a_.opts().WithName("B"));
EXPECT_FALSE(Match());
- EXPECT_EQ("Found unexpected node 'B = Input[]()'", diff_);
+ EXPECT_EQ("Found unexpected node '{{node B}} = Input[]()'", diff_);
}
TEST_F(EqualGraphDefTest, NodeOrder) {
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index bb447e0393..566a42dbd5 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -17,7 +17,6 @@ limitations under the License.
#define TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
#ifdef INTEL_MKL
-#include <string>
#include <vector>
#include <unordered_map>
#include <utility>
@@ -1896,16 +1895,17 @@ class MklPrimitiveFactory {
MklPrimitiveFactory() {}
~MklPrimitiveFactory() {}
- MklPrimitive* GetOp(const std::string& key) {
+ MklPrimitive* GetOp(const string& key) {
auto stream_iter = MklPrimitiveFactory<T>::GetHashMap().find(key);
if (stream_iter == MklPrimitiveFactory<T>::GetHashMap().end()) {
return nullptr;
} else {
+ CHECK(stream_iter->second != nullptr) << "nullptr present in map";
return stream_iter->second;
}
}
- void SetOp(const std::string& key, MklPrimitive* op) {
+ void SetOp(const string& key, MklPrimitive* op) {
auto stream_iter = MklPrimitiveFactory<T>::GetHashMap().find(key);
CHECK(stream_iter == MklPrimitiveFactory<T>::GetHashMap().end());
@@ -1914,8 +1914,8 @@ class MklPrimitiveFactory {
}
private:
- static inline std::unordered_map<std::string, MklPrimitive*>& GetHashMap() {
- static thread_local std::unordered_map<std::string, MklPrimitive*> map_;
+ static inline std::unordered_map<string, MklPrimitive*>& GetHashMap() {
+ static thread_local std::unordered_map<string, MklPrimitive*> map_;
return map_;
}
};
@@ -1943,9 +1943,7 @@ class FactoryKeyCreator {
Append(StringPiece(buffer, sizeof(T)));
}
- std::string GetKey() {
- return key_;
- }
+ string GetKey() { return key_; }
private:
string key_;
@@ -2020,8 +2018,8 @@ class MklReorderPrimitiveFactory : public MklPrimitiveFactory<T> {
MklReorderPrimitiveFactory() {};
~MklReorderPrimitiveFactory() {};
- static std::string CreateKey(const memory* from, const memory* to) {
- std::string prefix = "reorder";
+ static string CreateKey(const memory* from, const memory* to) {
+ string prefix = "reorder";
FactoryKeyCreator key_creator;
auto const &from_desc = from->get_primitive_desc().desc().data;
auto const &to_desc = to->get_primitive_desc().desc().data;
@@ -2038,12 +2036,12 @@ class MklReorderPrimitiveFactory : public MklPrimitiveFactory<T> {
}
MklPrimitive* GetReorder(const memory* from, const memory* to) {
- std::string key = CreateKey(from, to);
+ string key = CreateKey(from, to);
return this->GetOp(key);
}
void SetReorder(const memory* from, const memory* to, MklPrimitive* op) {
- std::string key = CreateKey(from, to);
+ string key = CreateKey(from, to);
this->SetOp(key, op);
}
};
diff --git a/tensorflow/docs_src/deploy/distributed.md b/tensorflow/docs_src/deploy/distributed.md
index 8e2c818e39..fc3a60603f 100644
--- a/tensorflow/docs_src/deploy/distributed.md
+++ b/tensorflow/docs_src/deploy/distributed.md
@@ -314,7 +314,7 @@ serve multiple clients.
**Cluster**
-A TensorFlow cluster comprises a one or more "jobs", each divided into lists of
+A TensorFlow cluster comprises one or more "jobs", each divided into lists of
one or more "tasks". A cluster is typically dedicated to a particular high-level
objective, such as training a neural network, using many machines in parallel. A
cluster is defined by
diff --git a/tensorflow/docs_src/guide/using_gpu.md b/tensorflow/docs_src/guide/using_gpu.md
index c429ca4750..c0218fd12e 100644
--- a/tensorflow/docs_src/guide/using_gpu.md
+++ b/tensorflow/docs_src/guide/using_gpu.md
@@ -143,7 +143,7 @@ If the device you have specified does not exist, you will get
```
InvalidArgumentError: Invalid argument: Cannot assign a device to node 'b':
Could not satisfy explicit device specification '/device:GPU:2'
- [[Node: b = Const[dtype=DT_FLOAT, value=Tensor<type: float shape: [3,2]
+ [[{{node b}} = Const[dtype=DT_FLOAT, value=Tensor<type: float shape: [3,2]
values: 1 2 3...>, _device="/device:GPU:2"]()]]
```
diff --git a/tensorflow/docs_src/performance/xla/broadcasting.md b/tensorflow/docs_src/performance/xla/broadcasting.md
index eaa709c2f8..7018ded53f 100644
--- a/tensorflow/docs_src/performance/xla/broadcasting.md
+++ b/tensorflow/docs_src/performance/xla/broadcasting.md
@@ -99,7 +99,7 @@ dimensions 1 and 2 of the cuboid.
This type of broadcast is used in the binary ops in `XlaBuilder`, if the
`broadcast_dimensions` argument is given. For example, see
-[XlaBuilder::Add](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.cc).
+[XlaBuilder::Add](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.cc).
In the XLA source code, this type of broadcasting is sometimes called "InDim"
broadcasting.
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md
index fe9afc4ecb..5f7482f90f 100644
--- a/tensorflow/docs_src/performance/xla/operation_semantics.md
+++ b/tensorflow/docs_src/performance/xla/operation_semantics.md
@@ -1,7 +1,7 @@
# Operation Semantics
The following describes the semantics of operations defined in the
-[`XlaBuilder`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h)
+[`XlaBuilder`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
interface. Typically, these operations map one-to-one to operations defined in
the RPC interface in
[`xla_data.proto`](https://www.tensorflow.org/code/tensorflow/compiler/xla/xla_data.proto).
@@ -16,7 +16,7 @@ and familiar names; for example a *vector* is a 1-dimensional array and a
## BatchNormGrad
See also
-[`XlaBuilder::BatchNormGrad`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h)
+[`XlaBuilder::BatchNormGrad`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
and [the original batch normalization paper](https://arxiv.org/abs/1502.03167)
for a detailed description of the algorithm.
@@ -80,7 +80,7 @@ The output type is a tuple of three handles:
## BatchNormInference
See also
-[`XlaBuilder::BatchNormInference`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h)
+[`XlaBuilder::BatchNormInference`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
and [the original batch normalization paper](https://arxiv.org/abs/1502.03167)
for a detailed description of the algorithm.
@@ -115,7 +115,7 @@ The output is an n-dimensional, normalized array with the same shape as input
## BatchNormTraining
See also
-[`XlaBuilder::BatchNormTraining`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h)
+[`XlaBuilder::BatchNormTraining`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
and [`the original batch normalization paper`](https://arxiv.org/abs/1502.03167)
for a detailed description of the algorithm.
@@ -167,7 +167,7 @@ spatial dimensions using the formulas above.
## BitcastConvertType
See also
-[`XlaBuilder::BitcastConvertType`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::BitcastConvertType`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
Similar to a `tf.bitcast` in TensorFlow, performs an element-wise bitcast
operation from a data shape to a target shape. The dimensions must match, and
@@ -189,7 +189,7 @@ and destination element types must not be tuples.
## Broadcast
See also
-[`XlaBuilder::Broadcast`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::Broadcast`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
Adds dimensions to an array by duplicating the data in the array.
@@ -217,7 +217,7 @@ For example, if `operand` is a scalar `f32` with value `2.0f`, and
## Call
See also
-[`XlaBuilder::Call`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::Call`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
Invokes a computation with the given arguments.
@@ -236,7 +236,7 @@ The arity and types of the `args` must match the parameters of the
## Clamp
See also
-[`XlaBuilder::Clamp`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::Clamp`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
Clamps an operand to within the range between a minimum and maximum value.
@@ -269,7 +269,7 @@ Clamp(min, operand, max) = s32[3]{0, 5, 6};
## Collapse
See also
-[`XlaBuilder::Collapse`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h)
+[`XlaBuilder::Collapse`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
and the @{tf.reshape} operation.
Collapses dimensions of an array into one dimension.
@@ -332,7 +332,7 @@ then v12 == f32[8x3] {{10, 11, 12},
## Concatenate
See also
-[`XlaBuilder::ConcatInDim`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::ConcatInDim`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
Concatenate composes an array from multiple array operands. The array is of the
same rank as each of the input array operands (which must be of the same rank as
@@ -388,7 +388,7 @@ Diagram:
## Conditional
See also
-[`XlaBuilder::Conditional`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::Conditional`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
<b> `Conditional(pred, true_operand, true_computation, false_operand,
false_computation)` </b>
@@ -416,7 +416,7 @@ executed depending on the value of `pred`.
## Conv (convolution)
See also
-[`XlaBuilder::Conv`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::Conv`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
As ConvWithGeneralPadding, but the padding is specified in a short-hand way as
either SAME or VALID. SAME padding pads the input (`lhs`) with zeroes so that
@@ -426,7 +426,7 @@ account. VALID padding simply means no padding.
## ConvWithGeneralPadding (convolution)
See also
-[`XlaBuilder::ConvWithGeneralPadding`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::ConvWithGeneralPadding`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
Computes a convolution of the kind used in neural networks. Here, a convolution
can be thought of as a n-dimensional window moving across a n-dimensional base
@@ -538,7 +538,7 @@ for (b, oz, oy, ox) { // output coordinates
## ConvertElementType
See also
-[`XlaBuilder::ConvertElementType`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::ConvertElementType`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
Similar to an element-wise `static_cast` in C++, performs an element-wise
conversion operation from a data shape to a target shape. The dimensions must
@@ -572,7 +572,7 @@ then b == f32[3]{0.0, 1.0, 2.0}
## CrossReplicaSum
See also
-[`XlaBuilder::CrossReplicaSum`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::CrossReplicaSum`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
Computes a sum across replicas.
@@ -607,7 +607,7 @@ than another.
## CustomCall
See also
-[`XlaBuilder::CustomCall`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::CustomCall`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
Call a user-provided function within a computation.
@@ -668,7 +668,7 @@ idempotent.
## Dot
See also
-[`XlaBuilder::Dot`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::Dot`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
<b> `Dot(lhs, rhs)` </b>
@@ -697,7 +697,7 @@ multiplications or matrix/matrix multiplications.
## DotGeneral
See also
-[`XlaBuilder::DotGeneral`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::DotGeneral`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
<b> `DotGeneral(lhs, rhs, dimension_numbers)` </b>
@@ -784,7 +784,7 @@ non-contracting/non-batch dimension.
## DynamicSlice
See also
-[`XlaBuilder::DynamicSlice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::DynamicSlice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
DynamicSlice extracts a sub-array from the input array at dynamic
`start_indices`. The size of the slice in each dimension is passed in
@@ -848,7 +848,7 @@ DynamicSlice(b, s, {2, 2}) produces:
## DynamicUpdateSlice
See also
-[`XlaBuilder::DynamicUpdateSlice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::DynamicUpdateSlice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
DynamicUpdateSlice generates a result which is the value of the input array
`operand`, with a slice `update` overwritten at `start_indices`.
@@ -920,7 +920,7 @@ DynamicUpdateSlice(b, u, s) produces:
## Element-wise binary arithmetic operations
See also
-[`XlaBuilder::Add`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::Add`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
A set of element-wise binary arithmetic operations is supported.
@@ -965,7 +965,7 @@ shapes of both operands. The semantics are described in detail on the
## Element-wise comparison operations
See also
-[`XlaBuilder::Eq`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::Eq`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
A set of standard element-wise binary comparison operations is supported. Note
that standard IEEE 754 floating-point comparison semantics apply when comparing
@@ -1051,7 +1051,7 @@ potentially different runtime offset) of an input tensor into an output tensor.
### General Semantics
See also
-[`XlaBuilder::Gather`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::Gather`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
For a more intuitive description, see the "Informal Description" section below.
<b> `gather(operand, gather_indices, output_window_dims, elided_window_dims, window_bounds, gather_dims_to_operand_dims)` </b>
@@ -1254,7 +1254,7 @@ concatenation of all these rows.
## GetTupleElement
See also
-[`XlaBuilder::GetTupleElement`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::GetTupleElement`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
Indexes into a tuple with a compile-time-constant value.
@@ -1275,7 +1275,7 @@ See also @{tf.tuple}.
## Infeed
See also
-[`XlaBuilder::Infeed`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::Infeed`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
<b> `Infeed(shape)` </b>
@@ -1327,7 +1327,7 @@ Arguments | Type | Semantics
## Map
See also
-[`XlaBuilder::Map`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::Map`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
<b> `Map(operands..., computation)` </b>
@@ -1356,7 +1356,7 @@ input arrays to produce the output array.
## Pad
See also
-[`XlaBuilder::Pad`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::Pad`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
<b> `Pad(operand, padding_value, padding_config)` </b>
@@ -1395,7 +1395,7 @@ are all 0. The figure below shows examples of different `edge_padding` and
## Recv
See also
-[`XlaBuilder::Recv`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::Recv`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
<b> `Recv(shape, channel_handle)` </b>
@@ -1429,7 +1429,7 @@ complete and returns the received data.
## Reduce
See also
-[`XlaBuilder::Reduce`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::Reduce`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
Applies a reduction function to an array.
@@ -1546,7 +1546,7 @@ Reducing the 3D array over all its dimensions produces the scalar `84`.
## ReducePrecision
See also
-[`XlaBuilder::ReducePrecision`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::ReducePrecision`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
Models the effect of converting floating-point values to a lower-precision
format (such as IEEE-FP16) and back to the original format. The number of
@@ -1577,7 +1577,7 @@ portion of the conversion is then simply a no-op.
## ReduceWindow
See also
-[`XlaBuilder::ReduceWindow`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::ReduceWindow`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
Applies a reduction function to all elements in each window of the input
multi-dimensional array, producing an output multi-dimensional array with the
@@ -1660,7 +1660,7 @@ context of [`Reduce`](#reduce) for more details.
## Reshape
See also
-[`XlaBuilder::Reshape`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h)
+[`XlaBuilder::Reshape`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
and the [`Collapse`](#collapse) operation.
Reshapes the dimensions of an array into a new configuration.
@@ -1741,7 +1741,7 @@ Reshape(5, {}, {1,1}) == f32[1x1] {{5}};
## Rev (reverse)
See also
-[`XlaBuilder::Rev`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::Rev`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
<b>`Rev(operand, dimensions)`</b>
@@ -1763,7 +1763,7 @@ the two window dimensions during the gradient computation in neural networks.
## RngNormal
See also
-[`XlaBuilder::RngNormal`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::RngNormal`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
Constructs an output of a given shape with random numbers generated following
the $$N(\mu, \sigma)$$ normal distribution. The parameters `mu` and `sigma`, and
@@ -1783,7 +1783,7 @@ be scalar valued.
## RngUniform
See also
-[`XlaBuilder::RngUniform`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::RngUniform`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
Constructs an output of a given shape with random numbers generated following
the uniform distribution over the interval $$[a,b)$$. The parameters and output
@@ -1804,7 +1804,7 @@ is implementation-defined.
## Select
See also
-[`XlaBuilder::Select`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::Select`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
Constructs an output array from elements of two input arrays, based on the
values of a predicate array.
@@ -1855,7 +1855,7 @@ the same shape!) then `pred` has to be a scalar of type `PRED`.
## SelectAndScatter
See also
-[`XlaBuilder::SelectAndScatter`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::SelectAndScatter`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
This operation can be considered as a composite operation that first computes
`ReduceWindow` on the `operand` array to select an element from each window, and
@@ -1935,7 +1935,7 @@ context of [`Reduce`](#reduce) for more details.
## Send
See also
-[`XlaBuilder::Send`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::Send`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
<b> `Send(operand, channel_handle)` </b>
@@ -1990,7 +1990,7 @@ computations. For example, below schedules lead to deadlocks.
## Slice
See also
-[`XlaBuilder::Slice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::Slice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
Slicing extracts a sub-array from the input array. The sub-array is of the same
rank as the input and contains the values inside a bounding box within the input
@@ -2039,7 +2039,7 @@ Slice(b, {2, 1}, {4, 3}) produces:
## Sort
See also
-[`XlaBuilder::Sort`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::Sort`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
There are two versions of the Sort instruction: a single-operand and a
two-operand version.
@@ -2099,7 +2099,7 @@ This is the same as Reshape(operand, permutation,
## Tuple
See also
-[`XlaBuilder::Tuple`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::Tuple`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
A tuple containing a variable number of data handles, each of which has its own
shape.
@@ -2118,7 +2118,7 @@ Tuples can be deconstructed (accessed) via the [`GetTupleElement`]
## While
See also
-[`XlaBuilder::While`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::While`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
<b> `While(condition, body, init)` </b>
diff --git a/tensorflow/examples/saved_model/saved_model_half_plus_two.py b/tensorflow/examples/saved_model/saved_model_half_plus_two.py
index 0d6f1ef655..2d1e0c6f6d 100644
--- a/tensorflow/examples/saved_model/saved_model_half_plus_two.py
+++ b/tensorflow/examples/saved_model/saved_model_half_plus_two.py
@@ -33,6 +33,13 @@ where `a`, `b` and `c` are variables with `a=0.5` and `b=2` and `c=3`.
Output from this program is typically used to exercise SavedModel load and
execution code.
+
+To create a CPU model:
+ bazel run -c opt saved_half_plus_two -- --device=cpu
+
+To create GPU model:
+ bazel run --config=cuda -c opt saved_half_plus_two -- \
+ --device=gpu
"""
from __future__ import absolute_import
@@ -105,42 +112,52 @@ def _build_classification_signature(input_tensor, scores_tensor):
def _generate_saved_model_for_half_plus_two(export_dir,
as_text=False,
- use_main_op=False):
+ use_main_op=False,
+ device_type="cpu"):
"""Generates SavedModel for half plus two.
Args:
export_dir: The directory to which the SavedModel should be written.
as_text: Writes the SavedModel protocol buffer in text format to disk.
use_main_op: Whether to supply a main op during SavedModel build time.
+ device_name: Device to force ops to run on.
"""
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
- with tf.Session(graph=tf.Graph()) as sess:
- # Set up the model parameters as variables to exercise variable loading
- # functionality upon restore.
- a = tf.Variable(0.5, name="a")
- b = tf.Variable(2.0, name="b")
- c = tf.Variable(3.0, name="c")
-
- # Create a placeholder for serialized tensorflow.Example messages to be fed.
- serialized_tf_example = tf.placeholder(tf.string, name="tf_example")
-
- # Parse the tensorflow.Example looking for a feature named "x" with a single
- # floating point value.
- feature_configs = {
- "x": tf.FixedLenFeature(
- [1], dtype=tf.float32),
- "x2": tf.FixedLenFeature(
- [1], dtype=tf.float32, default_value=[0.0])
- }
- tf_example = tf.parse_example(serialized_tf_example, feature_configs)
- # Use tf.identity() to assign name
- x = tf.identity(tf_example["x"], name="x")
- y = tf.add(tf.multiply(a, x), b, name="y")
- y2 = tf.add(tf.multiply(a, x), c, name="y2")
-
- x2 = tf.identity(tf_example["x2"], name="x2")
- y3 = tf.add(tf.multiply(a, x2), c, name="y3")
+ device_name = "/cpu:0"
+ if device_type == "gpu":
+ device_name = "/gpu:0"
+
+ with tf.Session(
+ graph=tf.Graph(),
+ config=tf.ConfigProto(log_device_placement=True)) as sess:
+ with tf.device(device_name):
+ # Set up the model parameters as variables to exercise variable loading
+ # functionality upon restore.
+ a = tf.Variable(0.5, name="a")
+ b = tf.Variable(2.0, name="b")
+ c = tf.Variable(3.0, name="c")
+
+ # Create a placeholder for serialized tensorflow.Example messages to be
+ # fed.
+ serialized_tf_example = tf.placeholder(tf.string, name="tf_example")
+
+ # Parse the tensorflow.Example looking for a feature named "x" with a
+ # single floating point value.
+ feature_configs = {
+ "x": tf.FixedLenFeature([1], dtype=tf.float32),
+ "x2": tf.FixedLenFeature([1], dtype=tf.float32, default_value=[0.0])
+ }
+ # parse_example only works on CPU
+ with tf.device("/cpu:0"):
+ tf_example = tf.parse_example(serialized_tf_example, feature_configs)
+ # Use tf.identity() to assign name
+ x = tf.identity(tf_example["x"], name="x")
+ y = tf.add(tf.multiply(a, x), b, name="y")
+ y2 = tf.add(tf.multiply(a, x), c, name="y2")
+
+ x2 = tf.identity(tf_example["x2"], name="x2")
+ y3 = tf.add(tf.multiply(a, x2), c, name="y3")
# Create an assets file that can be saved and restored as part of the
# SavedModel.
@@ -185,20 +202,7 @@ def _generate_saved_model_for_half_plus_two(export_dir,
}
# Initialize all variables and then save the SavedModel.
sess.run(tf.global_variables_initializer())
- signature_def_map = {
- "regress_x_to_y":
- _build_regression_signature(serialized_tf_example, y),
- "regress_x_to_y2":
- _build_regression_signature(serialized_tf_example, y2),
- "regress_x2_to_y3":
- _build_regression_signature(x2, y3),
- "classify_x_to_y":
- _build_classification_signature(serialized_tf_example, y),
- "classify_x2_to_y3":
- _build_classification_signature(x2, y3),
- tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
- predict_signature_def
- }
+
if use_main_op:
builder.add_meta_graph_and_variables(
sess, [tf.saved_model.tag_constants.SERVING],
@@ -212,19 +216,30 @@ def _generate_saved_model_for_half_plus_two(export_dir,
signature_def_map=signature_def_map,
assets_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS),
legacy_init_op=tf.group(assign_filename_op))
- builder.save(as_text)
+ builder.save(as_text)
def main(_):
- _generate_saved_model_for_half_plus_two(FLAGS.output_dir)
- print("SavedModel generated at: %s" % FLAGS.output_dir)
+ _generate_saved_model_for_half_plus_two(
+ FLAGS.output_dir, device_type=FLAGS.device)
+ print("SavedModel generated for %(device)s at: %(dir)s" % {
+ "device": FLAGS.device,
+ "dir": FLAGS.output_dir
+ })
- _generate_saved_model_for_half_plus_two(FLAGS.output_dir_pbtxt, as_text=True)
- print("SavedModel generated at: %s" % FLAGS.output_dir_pbtxt)
+ _generate_saved_model_for_half_plus_two(
+ FLAGS.output_dir_pbtxt, as_text=True, device_type=FLAGS.device)
+ print("SavedModel generated for %(device)s at: %(dir)s" % {
+ "device": FLAGS.device,
+ "dir": FLAGS.output_dir_pbtxt
+ })
_generate_saved_model_for_half_plus_two(
- FLAGS.output_dir_main_op, use_main_op=True)
- print("SavedModel generated at: %s" % FLAGS.output_dir_main_op)
+ FLAGS.output_dir_main_op, use_main_op=True, device_type=FLAGS.device)
+ print("SavedModel generated for %(device)s at: %(dir)s " % {
+ "device": FLAGS.device,
+ "dir": FLAGS.output_dir_main_op
+ })
if __name__ == "__main__":
@@ -244,5 +259,10 @@ if __name__ == "__main__":
type=str,
default="/tmp/saved_model_half_plus_two_main_op",
help="Directory where to output the SavedModel with a main op.")
+ parser.add_argument(
+ "--device",
+ type=str,
+ default="cpu",
+ help="Force model to run on 'cpu' or 'gpu'")
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 6c9bf1e714..1e765d1cd7 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -4877,6 +4877,146 @@ func Rsqrt(scope *Scope, x tf.Output) (y tf.Output) {
return op.Output(0)
}
+// AudioSpectrogramAttr is an optional argument to AudioSpectrogram.
+type AudioSpectrogramAttr func(optionalAttr)
+
+// AudioSpectrogramMagnitudeSquared sets the optional magnitude_squared attribute to value.
+//
+// value: Whether to return the squared magnitude or just the
+// magnitude. Using squared magnitude can avoid extra calculations.
+// If not specified, defaults to false
+func AudioSpectrogramMagnitudeSquared(value bool) AudioSpectrogramAttr {
+ return func(m optionalAttr) {
+ m["magnitude_squared"] = value
+ }
+}
+
+// Produces a visualization of audio data over time.
+//
+// Spectrograms are a standard way of representing audio information as a series of
+// slices of frequency information, one slice for each window of time. By joining
+// these together into a sequence, they form a distinctive fingerprint of the sound
+// over time.
+//
+// This op expects to receive audio data as an input, stored as floats in the range
+// -1 to 1, together with a window width in samples, and a stride specifying how
+// far to move the window between slices. From this it generates a three
+// dimensional output. The lowest dimension has an amplitude value for each
+// frequency during that time slice. The next dimension is time, with successive
+// frequency slices. The final dimension is for the channels in the input, so a
+// stereo audio input would have two here for example.
+//
+// This means the layout when converted and saved as an image is rotated 90 degrees
+// clockwise from a typical spectrogram. Time is descending down the Y axis, and
+// the frequency decreases from left to right.
+//
+// Each value in the result represents the square root of the sum of the real and
+// imaginary parts of an FFT on the current window of samples. In this way, the
+// lowest dimension represents the power of each frequency in the current window,
+// and adjacent windows are concatenated in the next dimension.
+//
+// To get a more intuitive and visual look at what this operation does, you can run
+// tensorflow/examples/wav_to_spectrogram to read in an audio file and save out the
+// resulting spectrogram as a PNG image.
+//
+// Arguments:
+// input: Float representation of audio data.
+// window_size: How wide the input window is in samples. For the highest efficiency
+// this should be a power of two, but other values are accepted.
+// stride: How widely apart the center of adjacent sample windows should be.
+//
+// Returns 3D representation of the audio frequencies as an image.
+func AudioSpectrogram(scope *Scope, input tf.Output, window_size int64, stride int64, optional ...AudioSpectrogramAttr) (spectrogram tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"window_size": window_size, "stride": stride}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "AudioSpectrogram",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// CTCBeamSearchDecoderAttr is an optional argument to CTCBeamSearchDecoder.
+type CTCBeamSearchDecoderAttr func(optionalAttr)
+
+// CTCBeamSearchDecoderMergeRepeated sets the optional merge_repeated attribute to value.
+//
+// value: If true, merge repeated classes in output.
+// If not specified, defaults to true
+func CTCBeamSearchDecoderMergeRepeated(value bool) CTCBeamSearchDecoderAttr {
+ return func(m optionalAttr) {
+ m["merge_repeated"] = value
+ }
+}
+
+// Performs beam search decoding on the logits given in input.
+//
+// A note about the attribute merge_repeated: For the beam search decoder,
+// this means that if consecutive entries in a beam are the same, only
+// the first of these is emitted. That is, when the top path is "A B B B B",
+// "A B" is returned if merge_repeated = True but "A B B B B" is
+// returned if merge_repeated = False.
+//
+// Arguments:
+// inputs: 3-D, shape: `(max_time x batch_size x num_classes)`, the logits.
+// sequence_length: A vector containing sequence lengths, size `(batch)`.
+// beam_width: A scalar >= 0 (beam search beam width).
+// top_paths: A scalar >= 0, <= beam_width (controls output size).
+//
+// Returns A list (length: top_paths) of indices matrices. Matrix j,
+// size `(total_decoded_outputs[j] x 2)`, has indices of a
+// `SparseTensor<int64, 2>`. The rows store: [batch, time].A list (length: top_paths) of values vectors. Vector j,
+// size `(length total_decoded_outputs[j])`, has the values of a
+// `SparseTensor<int64, 2>`. The vector stores the decoded classes for beam j.A list (length: top_paths) of shape vector. Vector j,
+// size `(2)`, stores the shape of the decoded `SparseTensor[j]`.
+// Its values are: `[batch_size, max_decoded_length[j]]`.A matrix, shaped: `(batch_size x top_paths)`. The
+// sequence log-probabilities.
+func CTCBeamSearchDecoder(scope *Scope, inputs tf.Output, sequence_length tf.Output, beam_width int64, top_paths int64, optional ...CTCBeamSearchDecoderAttr) (decoded_indices []tf.Output, decoded_values []tf.Output, decoded_shape []tf.Output, log_probability tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"beam_width": beam_width, "top_paths": top_paths}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "CTCBeamSearchDecoder",
+ Input: []tf.Input{
+ inputs, sequence_length,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if decoded_indices, idx, err = makeOutputList(op, idx, "decoded_indices"); err != nil {
+ scope.UpdateErr("CTCBeamSearchDecoder", err)
+ return
+ }
+ if decoded_values, idx, err = makeOutputList(op, idx, "decoded_values"); err != nil {
+ scope.UpdateErr("CTCBeamSearchDecoder", err)
+ return
+ }
+ if decoded_shape, idx, err = makeOutputList(op, idx, "decoded_shape"); err != nil {
+ scope.UpdateErr("CTCBeamSearchDecoder", err)
+ return
+ }
+ log_probability = op.Output(idx)
+ return decoded_indices, decoded_values, decoded_shape, log_probability
+}
+
// MatrixInverseAttr is an optional argument to MatrixInverse.
type MatrixInverseAttr func(optionalAttr)
@@ -6029,146 +6169,6 @@ func Dilation2DBackpropInput(scope *Scope, input tf.Output, filter tf.Output, ou
return op.Output(0)
}
-// CTCBeamSearchDecoderAttr is an optional argument to CTCBeamSearchDecoder.
-type CTCBeamSearchDecoderAttr func(optionalAttr)
-
-// CTCBeamSearchDecoderMergeRepeated sets the optional merge_repeated attribute to value.
-//
-// value: If true, merge repeated classes in output.
-// If not specified, defaults to true
-func CTCBeamSearchDecoderMergeRepeated(value bool) CTCBeamSearchDecoderAttr {
- return func(m optionalAttr) {
- m["merge_repeated"] = value
- }
-}
-
-// Performs beam search decoding on the logits given in input.
-//
-// A note about the attribute merge_repeated: For the beam search decoder,
-// this means that if consecutive entries in a beam are the same, only
-// the first of these is emitted. That is, when the top path is "A B B B B",
-// "A B" is returned if merge_repeated = True but "A B B B B" is
-// returned if merge_repeated = False.
-//
-// Arguments:
-// inputs: 3-D, shape: `(max_time x batch_size x num_classes)`, the logits.
-// sequence_length: A vector containing sequence lengths, size `(batch)`.
-// beam_width: A scalar >= 0 (beam search beam width).
-// top_paths: A scalar >= 0, <= beam_width (controls output size).
-//
-// Returns A list (length: top_paths) of indices matrices. Matrix j,
-// size `(total_decoded_outputs[j] x 2)`, has indices of a
-// `SparseTensor<int64, 2>`. The rows store: [batch, time].A list (length: top_paths) of values vectors. Vector j,
-// size `(length total_decoded_outputs[j])`, has the values of a
-// `SparseTensor<int64, 2>`. The vector stores the decoded classes for beam j.A list (length: top_paths) of shape vector. Vector j,
-// size `(2)`, stores the shape of the decoded `SparseTensor[j]`.
-// Its values are: `[batch_size, max_decoded_length[j]]`.A matrix, shaped: `(batch_size x top_paths)`. The
-// sequence log-probabilities.
-func CTCBeamSearchDecoder(scope *Scope, inputs tf.Output, sequence_length tf.Output, beam_width int64, top_paths int64, optional ...CTCBeamSearchDecoderAttr) (decoded_indices []tf.Output, decoded_values []tf.Output, decoded_shape []tf.Output, log_probability tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"beam_width": beam_width, "top_paths": top_paths}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "CTCBeamSearchDecoder",
- Input: []tf.Input{
- inputs, sequence_length,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- if scope.Err() != nil {
- return
- }
- var idx int
- var err error
- if decoded_indices, idx, err = makeOutputList(op, idx, "decoded_indices"); err != nil {
- scope.UpdateErr("CTCBeamSearchDecoder", err)
- return
- }
- if decoded_values, idx, err = makeOutputList(op, idx, "decoded_values"); err != nil {
- scope.UpdateErr("CTCBeamSearchDecoder", err)
- return
- }
- if decoded_shape, idx, err = makeOutputList(op, idx, "decoded_shape"); err != nil {
- scope.UpdateErr("CTCBeamSearchDecoder", err)
- return
- }
- log_probability = op.Output(idx)
- return decoded_indices, decoded_values, decoded_shape, log_probability
-}
-
-// AudioSpectrogramAttr is an optional argument to AudioSpectrogram.
-type AudioSpectrogramAttr func(optionalAttr)
-
-// AudioSpectrogramMagnitudeSquared sets the optional magnitude_squared attribute to value.
-//
-// value: Whether to return the squared magnitude or just the
-// magnitude. Using squared magnitude can avoid extra calculations.
-// If not specified, defaults to false
-func AudioSpectrogramMagnitudeSquared(value bool) AudioSpectrogramAttr {
- return func(m optionalAttr) {
- m["magnitude_squared"] = value
- }
-}
-
-// Produces a visualization of audio data over time.
-//
-// Spectrograms are a standard way of representing audio information as a series of
-// slices of frequency information, one slice for each window of time. By joining
-// these together into a sequence, they form a distinctive fingerprint of the sound
-// over time.
-//
-// This op expects to receive audio data as an input, stored as floats in the range
-// -1 to 1, together with a window width in samples, and a stride specifying how
-// far to move the window between slices. From this it generates a three
-// dimensional output. The lowest dimension has an amplitude value for each
-// frequency during that time slice. The next dimension is time, with successive
-// frequency slices. The final dimension is for the channels in the input, so a
-// stereo audio input would have two here for example.
-//
-// This means the layout when converted and saved as an image is rotated 90 degrees
-// clockwise from a typical spectrogram. Time is descending down the Y axis, and
-// the frequency decreases from left to right.
-//
-// Each value in the result represents the square root of the sum of the real and
-// imaginary parts of an FFT on the current window of samples. In this way, the
-// lowest dimension represents the power of each frequency in the current window,
-// and adjacent windows are concatenated in the next dimension.
-//
-// To get a more intuitive and visual look at what this operation does, you can run
-// tensorflow/examples/wav_to_spectrogram to read in an audio file and save out the
-// resulting spectrogram as a PNG image.
-//
-// Arguments:
-// input: Float representation of audio data.
-// window_size: How wide the input window is in samples. For the highest efficiency
-// this should be a power of two, but other values are accepted.
-// stride: How widely apart the center of adjacent sample windows should be.
-//
-// Returns 3D representation of the audio frequencies as an image.
-func AudioSpectrogram(scope *Scope, input tf.Output, window_size int64, stride int64, optional ...AudioSpectrogramAttr) (spectrogram tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"window_size": window_size, "stride": stride}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "AudioSpectrogram",
- Input: []tf.Input{
- input,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Compute the polygamma function \\(\psi^{(n)}(x)\\).
//
// The polygamma function is defined as:
diff --git a/tensorflow/java/maven/README.md b/tensorflow/java/maven/README.md
index 3e030dcd09..cbc64a284f 100644
--- a/tensorflow/java/maven/README.md
+++ b/tensorflow/java/maven/README.md
@@ -151,16 +151,6 @@ conducted in a [Docker](https://www.docker.com) container.
7. Upon successful release, commit changes to all the `pom.xml` files
(which should have the updated version number).
-### Snapshots
-
-If the `TF_VERSION` provided to the `release.sh` script ends in `-SNAPSHOT`,
-then instead of using official release files, the nightly build artifacts from
-https://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow/,
-https://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow-windows/ and
-https://ci.tensorflow.org/view/Nightly/job/nightly-android
-will be used to upload to the Maven Central snapshots repository. (Note that
-snapshots are only uploaded to Maven Central, not Bintray.)
-
### Skip deploying to a repository
Should you need, setting environment variables `DEPLOY_OSSRH=0` or
@@ -173,12 +163,12 @@ cannot skip deploying to OSSRH for a `-SNAPSHOT` version.
This section provides some pointers around how artifacts are currently
assembled.
-All native and java code is first built and tested on
-a [Tensorflow Jenkins server](https://ci.tensorflow.org/) which run various
-scripts under the [`tools/ci_build`](../../tools/ci_build/) directory. Of
-particular interest may be `tools/ci_build/builds/libtensorflow.sh` which
-bundles Java-related build sources and outputs into archives, and
-`tools/ci_build/builds/android_full.sh` which produces an Android AAR package.
+All native and java code is first built and tested by the release process
+which run various scripts under the [`tools/ci_build`](../../tools/ci_build/)
+directory. Of particular interest may be
+`tools/ci_build/builds/libtensorflow.sh` which bundles Java-related build
+sources and outputs into archives, and `tools/ci_build/builds/android_full.sh`
+which produces an Android AAR package.
Maven artifacts however are not created in Jenkins. Instead, artifacts are
created and deployed externally on-demand, when a maintainer runs the
diff --git a/tensorflow/java/maven/run_inside_container.sh b/tensorflow/java/maven/run_inside_container.sh
index 2240d6b7b9..f4794d68a9 100644
--- a/tensorflow/java/maven/run_inside_container.sh
+++ b/tensorflow/java/maven/run_inside_container.sh
@@ -26,12 +26,6 @@ TF_ECOSYSTEM_URL="https://github.com/tensorflow/ecosystem.git"
DEPLOY_BINTRAY="${DEPLOY_BINTRAY:-true}"
DEPLOY_OSSRH="${DEPLOY_OSSRH:-true}"
-IS_SNAPSHOT="false"
-if [[ "${TF_VERSION}" == *"-SNAPSHOT" ]]; then
- IS_SNAPSHOT="true"
- # Bintray does not allow snapshots.
- DEPLOY_BINTRAY="false"
-fi
PROTOC_RELEASE_URL="https://github.com/google/protobuf/releases/download/v3.5.1/protoc-3.5.1-linux-x86_64.zip"
if [[ "${DEPLOY_BINTRAY}" != "true" && "${DEPLOY_OSSRH}" != "true" ]]; then
echo "Must deploy to at least one of Bintray or OSSRH" >&2
@@ -69,11 +63,7 @@ mvn_property() {
}
download_libtensorflow() {
- if [[ "${IS_SNAPSHOT}" == "true" ]]; then
- URL="http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow/TYPE=cpu-slave/lastSuccessfulBuild/artifact/lib_package/libtensorflow-src.jar"
- else
- URL="${RELEASE_URL_PREFIX}/libtensorflow-src-${TF_VERSION}.jar"
- fi
+ URL="${RELEASE_URL_PREFIX}/libtensorflow-src-${TF_VERSION}.jar"
curl -L "${URL}" -o /tmp/src.jar
cd "${DIR}/libtensorflow"
jar -xvf /tmp/src.jar
@@ -101,17 +91,9 @@ download_libtensorflow_jni() {
mkdir windows-x86_64
mkdir darwin-x86_64
- if [[ "${IS_SNAPSHOT}" == "true" ]]; then
- # Nightly builds from http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow/
- # and http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow-windows/
- curl -L "http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow/TYPE=cpu-slave/lastSuccessfulBuild/artifact/lib_package/libtensorflow_jni-cpu-linux-x86_64.tar.gz" | tar -xvz -C linux-x86_64
- curl -L "http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow/TYPE=mac-slave/lastSuccessfulBuild/artifact/lib_package/libtensorflow_jni-cpu-darwin-x86_64.tar.gz" | tar -xvz -C darwin-x86_64
- curl -L "http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow-windows/lastSuccessfulBuild/artifact/lib_package/libtensorflow_jni-cpu-windows-x86_64.zip" -o /tmp/windows.zip
- else
- curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-cpu-linux-x86_64-${TF_VERSION}.tar.gz" | tar -xvz -C linux-x86_64
- curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-cpu-darwin-x86_64-${TF_VERSION}.tar.gz" | tar -xvz -C darwin-x86_64
- curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-cpu-windows-x86_64-${TF_VERSION}.zip" -o /tmp/windows.zip
- fi
+ curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-cpu-linux-x86_64-${TF_VERSION}.tar.gz" | tar -xvz -C linux-x86_64
+ curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-cpu-darwin-x86_64-${TF_VERSION}.tar.gz" | tar -xvz -C darwin-x86_64
+ curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-cpu-windows-x86_64-${TF_VERSION}.zip" -o /tmp/windows.zip
unzip /tmp/windows.zip -d windows-x86_64
rm -f /tmp/windows.zip
@@ -129,13 +111,7 @@ download_libtensorflow_jni_gpu() {
mkdir linux-x86_64
- if [[ "${IS_SNAPSHOT}" == "true" ]]; then
- # Nightly builds from http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow/
- # and http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow-windows/
- curl -L "http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow/TYPE=gpu-linux/lastSuccessfulBuild/artifact/lib_package/libtensorflow_jni-gpu-linux-x86_64.tar.gz" | tar -xvz -C linux-x86_64
- else
- curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-gpu-linux-x86_64-${TF_VERSION}.tar.gz" | tar -xvz -C linux-x86_64
- fi
+ curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-gpu-linux-x86_64-${TF_VERSION}.tar.gz" | tar -xvz -C linux-x86_64
# Updated timestamps seem to be required to get Maven to pick up the file.
touch linux-x86_64/*
@@ -165,11 +141,7 @@ generate_java_protos() {
rm -f "/tmp/protoc.zip"
# Download the release archive of TensorFlow protos.
- if [[ "${IS_SNAPSHOT}" == "true" ]]; then
- URL="http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow/TYPE=cpu-slave/lastSuccessfulBuild/artifact/lib_package/libtensorflow_proto.zip"
- else
- URL="${RELEASE_URL_PREFIX}/libtensorflow_proto-${TF_VERSION}.zip"
- fi
+ URL="${RELEASE_URL_PREFIX}/libtensorflow_proto-${TF_VERSION}.zip"
curl -L "${URL}" -o /tmp/libtensorflow_proto.zip
mkdir -p "${DIR}/proto/tmp/src"
unzip -d "${DIR}/proto/tmp/src" "/tmp/libtensorflow_proto.zip"
@@ -238,11 +210,7 @@ deploy_profile() {
# Determine the correct pom file property to use
# for the repository url.
local rtype
- if [[ "${IS_SNAPSHOT}" == "true" ]]; then
- rtype='snapshotRepository'
- else
- rtype='repository'
- fi
+ rtype='repository'
local url=$(mvn_property "${profile}" "project.distributionManagement.${rtype}.url")
local repositoryId=$(mvn_property "${profile}" "project.distributionManagement.${rtype}.id")
mvn gpg:sign-and-deploy-file \
@@ -300,17 +268,13 @@ mvn verify
deploy_artifacts
set +ex
-if [[ "${IS_SNAPSHOT}" == "false" ]]; then
- echo "Uploaded to the staging repository"
- echo "After validating the release: "
- if [[ "${DEPLOY_OSSRH}" == "true" ]]; then
- echo "* Login to https://oss.sonatype.org/#stagingRepositories"
- echo "* Find the 'org.tensorflow' staging release and click either 'Release' to release or 'Drop' to abort"
- fi
- if [[ "${DEPLOY_BINTRAY}" == "true" ]]; then
- echo "* Login to https://bintray.com/google/tensorflow/tensorflow"
- echo "* Either 'Publish' unpublished items to release, or 'Discard' to abort"
- fi
-else
- echo "Uploaded to the snapshot repository"
+echo "Uploaded to the staging repository"
+echo "After validating the release: "
+if [[ "${DEPLOY_OSSRH}" == "true" ]]; then
+ echo "* Login to https://oss.sonatype.org/#stagingRepositories"
+ echo "* Find the 'org.tensorflow' staging release and click either 'Release' to release or 'Drop' to abort"
+fi
+if [[ "${DEPLOY_BINTRAY}" == "true" ]]; then
+ echo "* Login to https://bintray.com/google/tensorflow/tensorflow"
+ echo "* Either 'Publish' unpublished items to release, or 'Discard' to abort"
fi
diff --git a/tensorflow/java/maven/tensorflow-android/update.py b/tensorflow/java/maven/tensorflow-android/update.py
index 2206d800ca..c620564072 100644
--- a/tensorflow/java/maven/tensorflow-android/update.py
+++ b/tensorflow/java/maven/tensorflow-android/update.py
@@ -86,19 +86,10 @@ def read_template(path):
def main():
args = get_args()
- # Artifacts are downloaded from the ci build. A SNAPSHOT release is
- # associated with artifacts from the last successful nightly build. Otherwise,
- # it comes from the officially blessed release artifacts.
- if args.version.endswith('SNAPSHOT'):
- info_url = ('https://ci.tensorflow.org/view/Nightly/job/nightly-android'
- '/lastSuccessfulBuild/api/json')
- aar_url = None
- build_type = 'nightly-android'
- else:
- release_prefix = 'https://storage.googleapis.com/tensorflow/libtensorflow'
- info_url = '%s/android_buildinfo-%s.json' % (release_prefix, args.version)
- aar_url = '%s/tensorflow-%s.aar' % (release_prefix, args.version)
- build_type = 'release-android'
+ release_prefix = 'https://storage.googleapis.com/tensorflow/libtensorflow'
+ info_url = '%s/android_buildinfo-%s.json' % (release_prefix, args.version)
+ aar_url = '%s/tensorflow-%s.aar' % (release_prefix, args.version)
+ build_type = 'release-android'
# Retrieve build information
build_info = get_json(info_url)
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
index abca956b97..752b49af04 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
@@ -144,22 +144,23 @@ public final class Graph implements AutoCloseable {
}
/**
- * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s,
- * i.e., {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...}
- * <p>
- * {@code dx} are used as initial gradients (which represent the symbolic partial derivatives of some loss function
- * {@code L} w.r.t. {@code y}). {@code dx} must be null or have size of {@code y}.
- * <p>
- * If {@code dx} is null, the implementation will use dx of {@link org.tensorflow.op.core.OnesLike OnesLike} for all
- * shapes in {@code y}.
- * <p>
- * {@code prefix} is used as the name prefix applied to all nodes added to the graph to compute gradients. It must
- * be unique within the provided graph or the operation will fail.
- * <p>
- * If {@code prefix} is null, then one will be chosen automatically.
- *
- * @param prefix unique string prefix applied before the names of nodes added to the graph to compute gradients.
- * If null, a default one will be chosen.
+ * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, i.e.,
+ * {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...}
+ *
+ * <p>{@code dx} are used as initial gradients (which represent the symbolic partial derivatives
+ * of some loss function {@code L} w.r.t. {@code y}). {@code dx} must be null or have size of
+ * {@code y}.
+ *
+ * <p>If {@code dx} is null, the implementation will use dx of {@link
+ * org.tensorflow.op.core.OnesLike OnesLike} for all shapes in {@code y}.
+ *
+ * <p>{@code prefix} is used as the name prefix applied to all nodes added to the graph to compute
+ * gradients. It must be unique within the provided graph or the operation will fail.
+ *
+ * <p>If {@code prefix} is null, then one will be chosen automatically.
+ *
+ * @param prefix unique string prefix applied before the names of nodes added to the graph to
+ * compute gradients. If null, a default one will be chosen.
* @param y output of the function to derive
* @param x inputs of the function for which partial derivatives are computed
* @param dx if not null, the partial derivatives of some loss function {@code L} w.r.t. {@code y}
@@ -192,12 +193,21 @@ public final class Graph implements AutoCloseable {
dxIndices[i] = dx[i].index();
}
}
- // Gradient outputs are returned in two continuous arrays concatenated into one. The first holds the native
- // handles of the gradient operations while the second holds the index of their output e.g. given
- // xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain
+ // Gradient outputs are returned in two continuous arrays concatenated into one. The first
+ // holds the native handles of the gradient operations while the second holds the index of
+ // their output e.g. given
+ // xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain
// dy = [dy0Handle, dy1Handle, ..., dy0Index, dy1Index, ...]
long[] dyHandlesAndIndices =
- addGradients(ref.nativeHandle(), prefix, yHandles, yIndices, xHandles, xIndices, dxHandles, dxIndices);
+ addGradients(
+ ref.nativeHandle(),
+ prefix,
+ yHandles,
+ yIndices,
+ xHandles,
+ xIndices,
+ dxHandles,
+ dxIndices);
int ndy = dyHandlesAndIndices.length >> 1;
if (ndy != dy.length) {
throw new IllegalStateException(String.valueOf(ndy) + " gradients were added to the graph when " + dy.length
@@ -214,16 +224,16 @@ public final class Graph implements AutoCloseable {
/**
* Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s,
* i.e., {@code dy/dx_1, dy/dx_2...}
- * <p>
+ * <p>
* This is a simplified version of {@link #addGradients(Output[], Output[], Output[]) where {@code y} is
* a single output, {@code dx} is null and {@code prefix} is null.
- *
+ *
* @param y output of the function to derive
* @param x inputs of the function for which partial derivatives are computed
* @return the partial derivatives {@code dy} with the size of {@code x}
*/
public Output<?>[] addGradients(Output<?> y, Output<?>[] x) {
- return addGradients(null, new Output<?>[]{y}, x, null);
+ return addGradients(null, new Output<?>[] {y}, x, null);
}
private final Object nativeHandleLock = new Object();
@@ -337,8 +347,15 @@ public final class Graph implements AutoCloseable {
private static native byte[] toGraphDef(long handle);
- private static native long[] addGradients(long handle, String prefix, long[] inputHandles, int[] inputIndices,
- long[] outputHandles, int[] outputIndices, long[] gradInputHandles, int[] gradInputIndices);
+ private static native long[] addGradients(
+ long handle,
+ String prefix,
+ long[] inputHandles,
+ int[] inputIndices,
+ long[] outputHandles,
+ int[] outputIndices,
+ long[] gradInputHandles,
+ int[] gradInputIndices);
static {
TensorFlow.init();
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java
index 563ea66ef1..5a233bcc98 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java
@@ -135,8 +135,8 @@ public final class Scope {
* }</pre>
*
* <p><b>Note:</b> if you provide a composite operator building class (i.e, a class that adds a
- * set of related operations to the graph by calling other operator building code), the provided name
- * will act as a subscope to all underlying operators.
+ * set of related operations to the graph by calling other operator building code), the provided
+ * name will act as a subscope to all underlying operators.
*
* @param defaultName name for the underlying operator.
* @return unique name for the operator.
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
index 5432ff244e..eea9dc1c47 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
@@ -18,7 +18,6 @@ package org.tensorflow.op.core;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
-
import org.tensorflow.Operand;
import org.tensorflow.Output;
import org.tensorflow.op.Op;
@@ -54,7 +53,7 @@ public class Gradients implements Op, Iterable<Operand<?>> {
* Optional attributes for {@link Gradients}
*/
public static class Options {
-
+
/**
* @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y}
* @return this option builder
@@ -63,23 +62,27 @@ public class Gradients implements Op, Iterable<Operand<?>> {
this.dx = dx;
return this;
}
-
+
private Iterable<? extends Operand<?>> dx;
-
+
private Options() {
}
}
/**
* Adds gradients computation ops to the graph according to scope.
- *
+ *
* @param scope current graph scope
* @param y outputs of the function to derive
* @param x inputs of the function for which partial derivatives are computed
* @param options carries optional attributes values
* @return a new instance of {@code Gradients}
*/
- public static Gradients create(Scope scope, Iterable<? extends Operand<?>> y, Iterable<? extends Operand<?>> x, Options... options) {
+ public static Gradients create(
+ Scope scope,
+ Iterable<? extends Operand<?>> y,
+ Iterable<? extends Operand<?>> x,
+ Options... options) {
Output<?>[] dx = null;
if (options != null) {
for (Options opts : options) {
@@ -88,20 +91,20 @@ public class Gradients implements Op, Iterable<Operand<?>> {
}
}
}
- Output<?>[] dy = scope.graph().addGradients(
- scope.makeOpName("Gradients"),
- Operands.asOutputs(y),
- Operands.asOutputs(x),
- dx);
+ Output<?>[] dy =
+ scope
+ .graph()
+ .addGradients(
+ scope.makeOpName("Gradients"), Operands.asOutputs(y), Operands.asOutputs(x), dx);
return new Gradients(Arrays.asList(dy));
}
/**
* Adds gradients computation ops to the graph according to scope.
- *
- * This is a simplified version of {@link #create(Scope, Iterable, Iterable, Options...)} where {@code y} is
- * a single output.
- *
+ *
+ * <p>This is a simplified version of {@link #create(Scope, Iterable, Iterable, Options...)} where
+ * {@code y} is a single output.
+ *
* @param scope current graph scope
* @param y output of the function to derive
* @param x inputs of the function for which partial derivatives are computed
@@ -109,7 +112,8 @@ public class Gradients implements Op, Iterable<Operand<?>> {
* @return a new instance of {@code Gradients}
*/
@SuppressWarnings({"unchecked", "rawtypes"})
- public static Gradients create(Scope scope, Operand<?> y, Iterable<? extends Operand<?>> x, Options... options) {
+ public static Gradients create(
+ Scope scope, Operand<?> y, Iterable<? extends Operand<?>> x, Options... options) {
return create(scope, (Iterable) Arrays.asList(y), x, options);
}
@@ -133,11 +137,11 @@ public class Gradients implements Op, Iterable<Operand<?>> {
public List<Output<?>> dy() {
return dy;
}
-
+
/**
* Returns a symbolic handle to one of the gradient operation output
- * <p>
- * Warning: Does not check that the type of the tensor matches T. It is recommended to call
+ *
+ * <p>Warning: Does not check that the type of the tensor matches T. It is recommended to call
* this method with an explicit type parameter rather than letting it be inferred, e.g. {@code
* gradients.<Float>dy(0)}
*
diff --git a/tensorflow/java/src/main/native/graph_jni.cc b/tensorflow/java/src/main/native/graph_jni.cc
index 1bbda52641..f1744d8769 100644
--- a/tensorflow/java/src/main/native/graph_jni.cc
+++ b/tensorflow/java/src/main/native/graph_jni.cc
@@ -133,12 +133,10 @@ Java_org_tensorflow_Graph_toGraphDef(JNIEnv* env, jclass clazz, jlong handle) {
return ret;
}
-JNIEXPORT jlongArray JNICALL
-Java_org_tensorflow_Graph_addGradients(JNIEnv* env, jclass clazz, jlong handle,
- jstring prefix, jlongArray y_handles, jintArray y_indices,
- jlongArray x_handles, jintArray x_indices,
- jlongArray dx_handles, jintArray dx_indices) {
-
+JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(
+ JNIEnv* env, jclass clazz, jlong handle, jstring prefix,
+ jlongArray y_handles, jintArray y_indices, jlongArray x_handles,
+ jintArray x_indices, jlongArray dx_handles, jintArray dx_indices) {
TF_Graph* g = requireHandle(env, handle);
if (g == nullptr) return nullptr;
diff --git a/tensorflow/java/src/main/native/graph_jni.h b/tensorflow/java/src/main/native/graph_jni.h
index e483bf953b..215695cdfd 100644
--- a/tensorflow/java/src/main/native/graph_jni.h
+++ b/tensorflow/java/src/main/native/graph_jni.h
@@ -78,9 +78,9 @@ JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Graph_toGraphDef(JNIEnv *,
* Method: name
* Signature: (JLjava/lang/String;[J[I[J[I[J[I)[J
*/
-JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(JNIEnv *,
- jclass, jlong, jstring, jlongArray, jintArray, jlongArray, jintArray,
- jlongArray, jintArray);
+JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(
+ JNIEnv *, jclass, jlong, jstring, jlongArray, jintArray, jlongArray,
+ jintArray, jlongArray, jintArray);
#ifdef __cplusplus
} // extern "C"
diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
index 56c8f22daa..7c05c1deaf 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
@@ -22,7 +22,6 @@ import static org.junit.Assert.assertTrue;
import java.util.HashSet;
import java.util.Iterator;
-
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -180,7 +179,7 @@ public class GraphTest {
Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
Output<Float> y0 = TestUtil.square(g, "y0", x);
Output<Float> y1 = TestUtil.square(g, "y1", y0);
-
+
Output<?>[] grad = g.addGradients(null, toArray(y0, y1), toArray(x), null);
assertNotNull(grad);
assertEquals(1, grad.length);
@@ -228,14 +227,14 @@ public class GraphTest {
}
}
}
-
+
@Test
public void validateGradientsNames() {
try (Graph g = new Graph()) {
Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
Output<Float> y0 = TestUtil.square(g, "y0", x);
-
+
Output<?>[] grad0 = g.addGradients(null, toArray(y0), toArray(x), null);
assertTrue(grad0[0].op().name().startsWith("gradients/"));
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java
index 0e9c7df697..125de73554 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java
@@ -21,7 +21,6 @@ import static org.junit.Assert.fail;
import java.util.HashMap;
import java.util.Map;
-
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java
index b75f79a421..3f49790b29 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java
@@ -1,3 +1,18 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
package org.tensorflow.op.core;
import static org.junit.Assert.assertEquals;
@@ -5,7 +20,6 @@ import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import java.util.Arrays;
-
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -22,28 +36,25 @@ public class GradientsTest {
@Test
public void createGradients() {
- try (Graph g = new Graph();
+ try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
Output<Float> x = TestUtil.placeholder(g, "x1", Float.class);
Output<Float> y0 = TestUtil.square(g, "y0", x);
Output<Float> y1 = TestUtil.square(g, "y1", y0);
-
+
Gradients grads = Gradients.create(scope, y1, Arrays.asList(x, y0));
-
+
assertNotNull(grads);
assertNotNull(grads.dy());
assertEquals(2, grads.dy().size());
try (Tensor<Float> c = Tensors.create(3.0f);
- TestUtil.AutoCloseableList<Tensor<?>> outputs = new TestUtil.AutoCloseableList<>(
- sess.runner()
- .feed(x, c)
- .fetch(grads.dy(0))
- .fetch(grads.dy(1))
- .run())) {
-
+ TestUtil.AutoCloseableList<Tensor<?>> outputs =
+ new TestUtil.AutoCloseableList<>(
+ sess.runner().feed(x, c).fetch(grads.dy(0)).fetch(grads.dy(1)).run())) {
+
assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f);
assertEquals(18.0f, outputs.get(1).floatValue(), 0.0f);
}
@@ -59,20 +70,17 @@ public class GradientsTest {
Output<Float> x = TestUtil.placeholder(g, "x1", Float.class);
Output<Float> y0 = TestUtil.square(g, "y0", x);
Output<Float> y1 = TestUtil.square(g, "y1", y0);
-
+
Gradients grads = Gradients.create(scope, Arrays.asList(y0, y1), Arrays.asList(x));
-
+
assertNotNull(grads);
assertNotNull(grads.dy());
assertEquals(1, grads.dy().size());
try (Tensor<Float> c = Tensors.create(3.0f);
- TestUtil.AutoCloseableList<Tensor<?>> outputs = new TestUtil.AutoCloseableList<>(
- sess.runner()
- .feed(x, c)
- .fetch(grads.dy(0))
- .run())) {
-
+ TestUtil.AutoCloseableList<Tensor<?>> outputs =
+ new TestUtil.AutoCloseableList<>(sess.runner().feed(x, c).fetch(grads.dy(0)).run())) {
+
assertEquals(114.0f, outputs.get(0).floatValue(), 0.0f);
}
}
@@ -87,21 +95,19 @@ public class GradientsTest {
Output<Float> x = TestUtil.placeholder(g, "x1", Float.class);
Output<Float> y0 = TestUtil.square(g, "y0", x);
Output<Float> y1 = TestUtil.square(g, "y1", y0);
-
+
Gradients grads0 = Gradients.create(scope, y1, Arrays.asList(y0));
Gradients grads1 = Gradients.create(scope, y0, Arrays.asList(x), Gradients.dx(grads0.dy()));
-
+
assertNotNull(grads1);
assertNotNull(grads1.dy());
assertEquals(1, grads1.dy().size());
try (Tensor<Float> c = Tensors.create(3.0f);
- TestUtil.AutoCloseableList<Tensor<?>> outputs = new TestUtil.AutoCloseableList<>(
- sess.runner()
- .feed(x, c)
- .fetch(grads1.dy(0))
- .run())) {
-
+ TestUtil.AutoCloseableList<Tensor<?>> outputs =
+ new TestUtil.AutoCloseableList<>(
+ sess.runner().feed(x, c).fetch(grads1.dy(0)).run())) {
+
assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f);
}
}
@@ -114,10 +120,10 @@ public class GradientsTest {
Output<Float> x = TestUtil.placeholder(g, "x1", Float.class);
Output<Float> y = TestUtil.square(g, "y", x);
-
+
Gradients grad0 = Gradients.create(scope, y, Arrays.asList(x));
assertTrue(grad0.dy(0).op().name().startsWith("sub/Gradients/"));
-
+
Gradients grad1 = Gradients.create(scope.withName("MyGradients"), y, Arrays.asList(x));
assertTrue(grad1.dy(0).op().name().startsWith("sub/MyGradients/"));
}
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index b5876c3457..d35731d3cd 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3658,6 +3658,7 @@ tf_cuda_library(
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:session_ref",
"//third_party/py/numpy:headers",
"//third_party/python_runtime:headers",
],
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index 180bb74d00..58a002c776 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -29,6 +29,7 @@ import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import pywrap_tensorflow as tf_session
from tensorflow.python.framework import device
+from tensorflow.python.framework import error_interpolation
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
@@ -630,7 +631,7 @@ class BaseSession(SessionInterface):
opts = tf_session.TF_NewSessionOptions(target=self._target, config=config)
try:
# pylint: disable=protected-access
- self._session = tf_session.TF_NewSession(self._graph._c_graph, opts)
+ self._session = tf_session.TF_NewSessionRef(self._graph._c_graph, opts)
# pylint: enable=protected-access
finally:
tf_session.TF_DeleteSessionOptions(opts)
@@ -1235,8 +1236,12 @@ class BaseSession(SessionInterface):
return _fetch_handler_run
- # Captures the name of a node in an error status.
- _NODEDEF_NAME_RE = re.compile(r'\[\[Node: ([^ ]*?) =')
+ # Captures the name of a node in an error status. The regex below matches
+ # both the old and the new formats:
+ # Old format: [[Node: <node_name> = ...]]
+ # New format: [[{{node <node_name>}} = ...]]
+ _NODEDEF_NAME_RE = re.compile(
+ r'\[\[(Node: )?(\{\{node )?([^\} ]*)(\}\})?\s*=')
def _do_run(self, handle, target_list, fetch_list, feed_dict, options,
run_metadata):
@@ -1291,12 +1296,15 @@ class BaseSession(SessionInterface):
node_def = None
op = None
if m is not None:
- node_name = m.group(1)
+ node_name = m.group(3)
try:
op = self._graph.get_operation_by_name(node_name)
node_def = op.node_def
except KeyError:
pass
+ if (self._config is not None and
+ self._config.experimental.client_handles_error_formatting):
+ message = error_interpolation.interpolate(message, self._graph)
raise type(e)(node_def, op, message)
def _extend_graph(self):
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index 1cdd8e0b6a..39a2922ac0 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -777,6 +777,7 @@ def TF_Reset(target, containers=None, config=None):
$1 = &types_local;
}
+%unignore TF_NewSessionRef;
%unignore SetRequireShapeInferenceFns;
%unignore TF_TryEvaluateConstant_wrapper;
%noexception TF_TryEvaluateConstant_wrapper;
diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc
index b6481e7e29..bcd4af2912 100644
--- a/tensorflow/python/client/tf_session_helper.cc
+++ b/tensorflow/python/client/tf_session_helper.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/core/common_runtime/session_ref.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
@@ -42,6 +43,19 @@ static const char* kFeedDictErrorMsg =
"feed_dict must be a dictionary mapping strings to NumPy arrays.";
} // end namespace
+TF_Session* TF_NewSessionRef(TF_Graph* graph, const TF_SessionOptions* opts,
+ TF_Status* status) {
+ TF_Session* tf_session = TF_NewSession(graph, opts, status);
+ if (tf_session == nullptr) {
+ return nullptr;
+ }
+
+ Session* session = reinterpret_cast<Session*>(tf_session->session);
+ SessionRef* session_ref = new SessionRef(session);
+ tf_session->session = session_ref;
+ return tf_session;
+}
+
void TF_Run_wrapper_helper(TF_DeprecatedSession* session, const char* handle,
const TF_Buffer* run_options, PyObject* feed_dict,
const NameVector& output_names,
diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h
index cfd27c2bee..dab7e71aac 100644
--- a/tensorflow/python/client/tf_session_helper.h
+++ b/tensorflow/python/client/tf_session_helper.h
@@ -40,6 +40,9 @@ typedef tensorflow::gtl::InlinedVector<PyObject*, 8> PyObjectVector;
// A TF_TensorVector is a vector of borrowed pointers to TF_Tensors.
typedef gtl::InlinedVector<TF_Tensor*, 8> TF_TensorVector;
+TF_Session* TF_NewSessionRef(TF_Graph* graph, const TF_SessionOptions* opts,
+ TF_Status* status);
+
// Run the graph associated with the session starting with the
// supplied inputs[]. Regardless of success or failure, inputs[] are
// stolen by the implementation (i.e. the implementation will
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 38505c0a01..b66b87ce6c 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -318,7 +318,7 @@ tf_py_test(
],
)
-tf_py_test(
+cuda_py_test(
name = "iterator_ops_test",
size = "small",
srcs = ["iterator_ops_test.py"],
diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py
index 3ef22cf981..494df178df 100644
--- a/tensorflow/python/data/ops/iterator_ops.py
+++ b/tensorflow/python/data/ops/iterator_ops.py
@@ -57,6 +57,13 @@ GET_NEXT_CALL_WARNING_MESSAGE = (
GLOBAL_ITERATORS = "iterators"
+def _device_stack_is_empty():
+ # pylint: disable=protected-access
+ device_stack = ops.get_default_graph()._device_functions_outer_to_inner
+ # pylint: enable=protected-access
+ return not bool(device_stack)
+
+
@tf_export("data.Iterator")
class Iterator(object):
"""Represents the state of iterating through a `Dataset`."""
@@ -174,7 +181,7 @@ class Iterator(object):
if shared_name is None:
shared_name = ""
if compat.forward_compatible(2018, 8, 3):
- if not ops.get_default_graph()._graph_device_function_stack: # pylint: disable=protected-access
+ if _device_stack_is_empty():
with ops.device("/cpu:0"):
iterator_resource = gen_dataset_ops.iterator_v2(
container="",
@@ -263,7 +270,7 @@ class Iterator(object):
nest.assert_same_structure(output_types, output_shapes)
string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string)
if compat.forward_compatible(2018, 8, 3):
- if not ops.get_default_graph()._graph_device_function_stack: # pylint: disable=protected-access
+ if _device_stack_is_empty():
with ops.device("/cpu:0"):
iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
string_handle,
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
new file mode 100644
index 0000000000..2bd0b4320a
--- /dev/null
+++ b/tensorflow/python/distribute/BUILD
@@ -0,0 +1,43 @@
+package(
+ default_visibility = ["//tensorflow:internal"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+py_library(
+ name = "distribute_coordinator",
+ srcs = [
+ "distribute_coordinator.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:training",
+ ],
+)
+
+py_test(
+ name = "distribute_coordinator_test",
+ size = "small",
+ srcs = ["distribute_coordinator_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":distribute_coordinator",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:distributed_framework_test_lib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ ],
+)
diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py
new file mode 100644
index 0000000000..04c50dbafc
--- /dev/null
+++ b/tensorflow/python/distribute/distribute_coordinator.py
@@ -0,0 +1,361 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A unified and split coordinator for distributed TensorFlow."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+import json
+import os
+import threading
+
+from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.python.training import server_lib
+
+
+class _TaskType(object):
+ PS = "ps"
+ WORKER = "worker"
+ CHIEF = "chief"
+ EVALUATOR = "evaluator"
+
+
+_coordinator_context = threading.local()
+
+
+def get_current_coordinator_context():
+ """Returns the current coordinator context."""
+ try:
+ return _coordinator_context.current
+ except AttributeError:
+ return None
+
+
+class _Barrier(object):
+ """A reusable barrier class for worker synchronization."""
+
+ def __init__(self, num_participants):
+ """Initializes the barrier object.
+
+ Args:
+ num_participants: an integer which is the expected number of calls of
+ `wait` pass to through this barrier.
+ """
+ self._num_participants = num_participants
+ self._counter = 0
+ self._flag = False
+ self._local_sense = threading.local()
+ self._lock = threading.Lock()
+ self._condition = threading.Condition()
+
+ def wait(self):
+ """Waits until all other callers reach the same wait call."""
+ if not hasattr(self._local_sense, "value"):
+ self._local_sense.value = False
+ self._local_sense.value = not self._flag
+ with self._lock:
+ self._counter += 1
+ if self._counter == self._num_participants:
+ self._counter = 0
+ self._flag = self._local_sense.value
+ with self._condition:
+ while self._flag != self._local_sense.value:
+ self._condition.wait()
+ self._condition.notify_all()
+
+
+def _get_num_workers(cluster_spec):
+ """Gets number of workers including chief."""
+ if not cluster_spec:
+ return 0
+ return len(cluster_spec.as_dict().get(_TaskType.WORKER, [])) + len(
+ cluster_spec.as_dict().get(_TaskType.CHIEF, []))
+
+
+class _CoordinatorContext(object):
+ """The coordinator context class.
+
+ This context object provides configuration information for each task. One
+ context manager with a coordinator context object will be created per
+ invocation to the `worker_fn` where `get_current_coordinator_context` can be
+ called to access the coordinator context object.
+ """
+
+ def __init__(self,
+ cluster_spec,
+ task_type,
+ task_id,
+ between_graph=False,
+ rpc_layer="grpc",
+ worker_barrier=None):
+ """Initialize the coordinator context object.
+
+ Args:
+ cluster_spec: a ClusterSpec object. It can be empty or None in the local
+ training case.
+ task_type: a string indicating the role of the corresponding task, such as
+ "worker" or "ps". It can be None if it is local training or
+ `between_graph` is False.
+ task_id: an integer indicating id of the corresponding task. It can be
+ None if it is local training or `between_graph` is False.
+ between_graph: whether it is between-graph replication or not.
+ rpc_layer: optional string specifying the RPC protocol for communication
+ with worker masters. If None or empty, hosts in the `cluster_spec` will
+ be used directly.
+ worker_barrier: optional, the barrier object for worker synchronization.
+
+ Raises:
+ ValueError: if task_type or task_id is Node or empty and it is distributed
+ between-graph replicated training.
+ """
+ if cluster_spec and between_graph:
+ if not task_type or task_id is None:
+ raise ValueError("`task_type` and `task_id` must be set in the "
+ "distributed between-graph replicated training.")
+ if task_type not in cluster_spec.jobs:
+ raise ValueError("`task_type` %r not found in the `cluster_spec` %r" %
+ (task_type, cluster_spec))
+ self._cluster_spec = cluster_spec
+ self._task_type = task_type
+ self._task_id = task_id
+ self._worker_barrier = worker_barrier
+ self._rpc_layer = rpc_layer
+ self._master_target = self._get_master_target()
+ self._num_workers = _get_num_workers(cluster_spec)
+ self._is_chief_node = self._is_chief()
+
+ def __enter__(self):
+ old_context = get_current_coordinator_context()
+ if old_context:
+ raise ValueError(
+ "You cannot run distribute coordinator in a `worker_fn`.")
+ _coordinator_context.current = self
+
+ def __exit__(self, unused_exception_type, unused_exception_value,
+ unused_traceback):
+ _coordinator_context.current = None
+
+ def _get_master_target(self):
+ """Return the master target for a task."""
+ # If cluster_spec is None or empty, we use local master.
+ if not self._cluster_spec:
+ return "local"
+
+ # If task_type is None, then it is in-graph replicated training. In this
+ # case we use the chief or first worker's master target.
+ if not self._task_type:
+ if _TaskType.CHIEF in self._cluster_spec.jobs:
+ assert not self.between_graph
+ task_type = _TaskType.CHIEF
+ task_id = 0
+ else:
+ assert _TaskType.WORKER in self._cluster_spec.jobs
+ task_type = _TaskType.WORKER
+ task_id = 0
+ else:
+ task_type = self._task_type
+ task_id = self._task_id
+
+ prefix = ""
+ if self._rpc_layer:
+ prefix = self._rpc_layer + "://"
+ return prefix + self._cluster_spec.job_tasks(task_type)[task_id or 0]
+
+ def _is_chief(self):
+ """Return whether the task is the chief worker."""
+ if (not self._cluster_spec or self._task_type in [_TaskType.CHIEF, None]):
+ return True
+
+ # If not local and chief not in the cluster_spec, use the first worker as
+ # chief.
+ if (_TaskType.CHIEF not in self._cluster_spec.jobs and
+ self._task_type == _TaskType.WORKER and self._task_id == 0):
+ return True
+ return False
+
+ def wait_for_other_workers(self):
+ """Waits for other workers to reach the same call to this method.
+
+ Raises:
+ ValueError: if `worker_barrier` is not passed to the __init__ method.
+ """
+ if not self._worker_barrier:
+ raise ValueError(
+ "`worker_barrier is not set in the coordinator context.`")
+ self._worker_barrier.wait()
+
+ @property
+ def distributed_mode(self):
+ """Whether it is distributed training or not."""
+ return bool(self._cluster_spec)
+
+ @property
+ def cluster_spec(self):
+ """Returns a copy of the cluster_spec object."""
+ return copy.deepcopy(self._cluster_spec)
+
+ @property
+ def task_type(self):
+ """Returns the role of the corresponing task."""
+ return self._task_type
+
+ @property
+ def task_id(self):
+ """Returns the id or index of the corresponing task."""
+ return self._task_id
+
+ @property
+ def master_target(self):
+ """Returns the session master for the corresponding task to connect to."""
+ return self._master_target
+
+ @property
+ def is_chief(self):
+ """Returns whether the task is a chief node."""
+ return self._is_chief_node
+
+ @property
+ def num_workers(self):
+ """Returns number of workers in the cluster, including chief."""
+ return self._num_workers
+
+
+def _run(worker_fn, cluster_spec, task_type, task_id, between_graph, rpc_layer,
+ worker_barrier):
+ with _CoordinatorContext(cluster_spec, task_type, task_id, between_graph,
+ rpc_layer, worker_barrier):
+ worker_fn()
+
+
+def run_distribute_coordinator(worker_fn,
+ cluster_spec=None,
+ between_graph=False,
+ rpc_layer=None):
+ """Run the coordinator for distributed TensorFlow.
+
+ This function runs a unified and split coordinator for distributed TensorFlow.
+ Given a `cluster_spec` specifying server addresses and their roles in a
+ cluster, this coordinator will figure out how to set them up, give the
+ underlying function the right targets for master sessions and coordinate their
+ training.
+
+ In addition to be the distribute coordinator, this is also the source of
+ configurations for each job in the distributed training. As there are multiple
+ ways to configure a distributed TensorFlow cluster, its context object
+ provides these configurations so that users or higher-level APIs don't have to
+ figure out the configuration for each job by themselves.
+
+ In the between-graph replicated training, this coordinator will create
+ multiple threads and each calls the `worker_fn` which is supposed to create
+ its own graph and connect to one worker master given by its coordinator
+ context. In the in-graph replicated training, it has only one thread calling
+ this `worker_fn`.
+
+ The `worker_fn` defines the training logic and is called under a its own
+ coordinator context which can be accessed to via
+ `get_current_coordinator_context`. A coordinator context provides access to
+ configurations for each task, e.g. the task_type, task_id, master target and
+ so on. Since `worker_fn` will be called in a thread and possibly multiple
+ times, caller should be careful when it accesses global data. For example, it
+ is unsafe to define flags in a `worker_fn` or to define different environment
+ variables for different `worker_fn`s.
+
+ The `worker_fn` for the between-graph replication is defined as if there are
+ only one worker corresponding to the `worker_fn` and possibly ps jobs. It
+ assigns variables to parameter servers and all other operations to that
+ worker. In the in-graph replication case, the `worker_fn` has to define
+ operations for all worker jobs. Using a distribution strategy can simplify the
+ `worker_fn` by not having to worry about the replication and device assignment
+ of variables and operations.
+
+ This method is intended to be invoked by high-level APIs so that users don't
+ have to explictly call it to run this coordinator. For those who don't use
+ high-level APIs, to change a program to use this coordinator, wrap everything
+ in a the program after global data definitions such as commandline flag
+ definition into the `worker_fn` and get task-specific configurations from
+ the coordinator context.
+
+ The `cluster_spec` can be either passed by the argument or parsed from the
+ "TF_CONFIG" envrionment variable. Example of a TF_CONFIG:
+ ```
+ cluster = {'chief': ['host0:2222'],
+ 'ps': ['host1:2222', 'host2:2222'],
+ 'worker': ['host3:2222', 'host4:2222', 'host5:2222']}
+ os.environ['TF_CONFIG'] = json.dumps({'cluster': cluster})
+ ```
+
+ If `cluster_spec` is not given in any format, it becomes local training and
+ this coordinator will connect to a local session.
+
+ For evaluation, if "evaluator" exist in the cluster_spec, a separate thread
+ will be created with its `task_type` set to "evaluator". If "evaluator" is not
+ set in the cluster_spec, it entirely depends on the `worker_fn` for how to do
+ evaluation.
+
+ Args:
+ worker_fn: the function to be called and given the access to a coordinator
+ context object.
+ cluster_spec: a dict, ClusterDef or ClusterSpec specifying servers and roles
+ in a cluster. If not set or empty, fall back to local training.
+ between_graph: a boolean. It is only useful when `cluster_spec` is set and
+ not empty. If true, it will use between-graph replicated training;
+ otherwise it will use in-graph replicated training.
+ rpc_layer: optional string, the protocol for RPC, e.g. "grpc".
+
+ Raises:
+ ValueError: if `cluster_spec` is supplied but not a dict or a ClusterDef or
+ a ClusterSpec.
+ """
+ if not cluster_spec:
+ tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
+ cluster_spec = tf_config.get("cluster", {})
+
+ if cluster_spec:
+ if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)):
+ cluster_spec = server_lib.ClusterSpec(cluster_spec)
+ elif not isinstance(cluster_spec, server_lib.ClusterSpec):
+ raise ValueError(
+ "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a "
+ "`tf.train.ClusterDef` object")
+ # TODO(yuefengz): validate cluster_spec.
+
+ threads = []
+ if cluster_spec and _TaskType.EVALUATOR in cluster_spec.jobs:
+ t = threading.Thread(
+ target=_run,
+ args=(worker_fn, cluster_spec, _TaskType.EVALUATOR, 0, between_graph,
+ rpc_layer, None))
+ t.start()
+ threads.append(t)
+
+ if cluster_spec and between_graph:
+ worker_barrier = _Barrier(_get_num_workers(cluster_spec))
+ for task_type in [_TaskType.CHIEF, _TaskType.WORKER]:
+ for task_id in range(len(cluster_spec.as_dict().get(task_type, []))):
+ t = threading.Thread(
+ target=_run,
+ args=(worker_fn, cluster_spec, task_type, task_id, between_graph,
+ rpc_layer, worker_barrier))
+ t.start()
+ threads.append(t)
+ else:
+ # Local or in-graph replicated training.
+ _run(worker_fn, cluster_spec, None, None, between_graph, rpc_layer, None)
+
+ # TODO(yuefengz): wrapper threads into thread coordinator?
+ for t in threads:
+ t.join()
diff --git a/tensorflow/python/distribute/distribute_coordinator_test.py b/tensorflow/python/distribute/distribute_coordinator_test.py
new file mode 100644
index 0000000000..82fd823352
--- /dev/null
+++ b/tensorflow/python/distribute/distribute_coordinator_test.py
@@ -0,0 +1,291 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for distribute coordinator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import copy
+import threading
+import six
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session
+from tensorflow.python.distribute import distribute_coordinator
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+CHIEF = distribute_coordinator._TaskType.CHIEF
+WORKER = distribute_coordinator._TaskType.WORKER
+PS = distribute_coordinator._TaskType.PS
+EVALUATOR = distribute_coordinator._TaskType.EVALUATOR
+
+NUM_WORKERS = 3
+NUM_PS = 2
+
+
+def _bytes_to_str(maybe_bytes):
+ if isinstance(maybe_bytes, six.string_types):
+ return maybe_bytes
+ else:
+ return str(maybe_bytes, "utf-8")
+
+
+class DistributeCoordinatorTest(test.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ # We have to create a global in-process cluster because once an in-process
+ # tensorflow server is created, there is no way to terminate it. Please see
+ # multi_worker_test_base.py for more details.
+ cls._workers, cls._ps = test_util.create_local_cluster(
+ NUM_WORKERS, num_ps=NUM_PS)
+ cls._cluster_spec = {
+ WORKER: [_bytes_to_str(w.target) for w in cls._workers],
+ PS: [_bytes_to_str(ps.target) for ps in cls._ps]
+ }
+
+ def setUp(self):
+ self._result_correct = 0
+ self._lock = threading.Lock()
+ self._task_context = {}
+
+ @contextlib.contextmanager
+ def _test_session(self, target):
+ config = config_pb2.ConfigProto(allow_soft_placement=True)
+ config.graph_options.optimizer_options.opt_level = -1
+ with session.Session(graph=None, config=config, target=target) as sess:
+ yield sess
+
+ def _in_graph_worker_fn(self):
+ context = distribute_coordinator.get_current_coordinator_context()
+ self.assertTrue(context is not None)
+ with self._test_session(target=context.master_target) as sess:
+ xs = []
+ expected = 0.0
+ for i in range(context.num_workers):
+ with ops.device("/job:worker/task:%d" % i):
+ x = variable_scope.get_variable("x_%d" % i, initializer=10.0)
+ x_add = x.assign_add(float(i))
+ xs.append(x_add)
+ expected += i + 10.0
+
+ with ops.device("/job:worker/task:0"):
+ result = math_ops.add_n(xs)
+
+ variables.global_variables_initializer().run()
+ result_value = sess.run(result)
+ self.assertEqual(result_value, expected)
+ if result_value == expected:
+ self._result_correct += 1
+
+ def testInGraph(self):
+ """Test it runs in-graph replicated training correctly."""
+ distribute_coordinator.run_distribute_coordinator(
+ self._in_graph_worker_fn,
+ cluster_spec=self._cluster_spec,
+ between_graph=False)
+ self.assertEqual(self._result_correct, 1)
+
+ def _between_graph_worker_fn(self):
+ context = distribute_coordinator.get_current_coordinator_context()
+ self.assertTrue(context is not None)
+ with self._test_session(target=context.master_target) as sess:
+ with ops.device("/job:ps/task:0"):
+ # TODO(yuefengz): investigate why not using resource variable will make
+ # the test flaky.
+ x = variable_scope.get_variable(
+ "x", initializer=10.0, use_resource=True)
+ with ops.device("/job:ps/task:1"):
+ y = variable_scope.get_variable(
+ "y", initializer=20.0, use_resource=True)
+
+ x_add = x.assign_add(2.0)
+ y_sub = y.assign_sub(2.0)
+ train_op = control_flow_ops.group([x_add, y_sub])
+
+ if context.is_chief:
+ variables.global_variables_initializer().run()
+
+ # Synchronize workers after initializaton.
+ context.wait_for_other_workers()
+
+ sess.run(train_op)
+
+ # Synchronize workers after one step to make sure they all have finished
+ # training.
+ context.wait_for_other_workers()
+
+ x_val, y_val = sess.run([x, y])
+
+ self.assertEqual(x_val, 16.0)
+ self.assertEqual(y_val, 14.0)
+ if x_val == 16.0 and y_val == 14.0:
+ with self._lock:
+ self._result_correct += 1
+
+ def testBetweenGraph(self):
+ """Test it runs between-graph replicated training correctly."""
+ distribute_coordinator.run_distribute_coordinator(
+ self._between_graph_worker_fn,
+ cluster_spec=self._cluster_spec,
+ between_graph=True)
+
+ # Each finished worker will increment self._result_correct.
+ self.assertEqual(self._result_correct, NUM_WORKERS)
+
+ def _dump_task_context(self):
+ """Dumps the propoerties of each coordinator context.
+
+ It dumps the context properties to a dict mapping from task_type to a list
+ of tuples of master_target, num_workers, is_chief and distribute_mode, where
+ the list is indexed by the task_id.
+ """
+ context = distribute_coordinator.get_current_coordinator_context()
+ self.assertTrue(context is not None)
+ task_type = str(context.task_type)
+ task_id = context.task_id or 0
+ with self._lock:
+ if task_type not in self._task_context:
+ self._task_context[task_type] = []
+ while len(self._task_context[task_type]) <= task_id:
+ self._task_context[task_type].append(None)
+ self._task_context[task_type][task_id] = (context.master_target,
+ context.num_workers,
+ context.is_chief,
+ context.distributed_mode)
+
+ def testBetweenGraphContext(self):
+ # Dumps the task contexts to the self._task_context dict.
+ distribute_coordinator.run_distribute_coordinator(
+ self._dump_task_context,
+ cluster_spec=self._cluster_spec,
+ between_graph=True)
+
+ # There is only one type of task and there three such tasks.
+ self.assertEqual(len(self._task_context), 1)
+ self.assertTrue(WORKER in self._task_context)
+ self.assertEqual(len(self._task_context[WORKER]), NUM_WORKERS)
+
+ # Check whether each task has the right master_target, num_workers, is_chief
+ # and distributed_mode.
+ self.assertEqual(
+ self._task_context[WORKER][0],
+ (_bytes_to_str(self._workers[0].target), NUM_WORKERS, True, True))
+ self.assertEqual(
+ self._task_context[WORKER][1],
+ (_bytes_to_str(self._workers[1].target), NUM_WORKERS, False, True))
+ self.assertEqual(
+ self._task_context[WORKER][2],
+ (_bytes_to_str(self._workers[2].target), NUM_WORKERS, False, True))
+
+ def testInGraphContext(self):
+ # Dumps the task contexts to the self._task_context dict.
+ distribute_coordinator.run_distribute_coordinator(
+ self._dump_task_context,
+ cluster_spec=self._cluster_spec,
+ between_graph=False)
+
+ # There is only a "None" task in the dumped task context.
+ self.assertEqual(len(self._task_context), 1)
+ self.assertTrue("None" in self._task_context)
+ self.assertEqual(len(self._task_context["None"]), 1)
+
+ # Check whether each task has the right master_target, num_workers, is_chief
+ # and distributed_mode.
+ self.assertEqual(
+ self._task_context["None"][0],
+ (_bytes_to_str(self._workers[0].target), NUM_WORKERS, True, True))
+
+ def testLocalContext(self):
+ # Dumps the task contexts to the self._task_context dict.
+ distribute_coordinator.run_distribute_coordinator(
+ self._dump_task_context, cluster_spec=None, between_graph=True)
+
+ # There is only a "None" task.
+ self.assertEqual(len(self._task_context), 1)
+ self.assertTrue("None" in self._task_context)
+ self.assertEqual(len(self._task_context["None"]), 1)
+
+ # Check whether each task has the right master_target, num_workers, is_chief
+ # and distributed_mode.
+ self.assertEqual(self._task_context["None"][0], ("local", 0, True, False))
+
+ def testBetweenGraphContextWithChief(self):
+ # Adds a chief node, so there are NUM_WORKERS + 1 workers in total.
+ cluster_spec = copy.deepcopy(self._cluster_spec)
+ cluster_spec[CHIEF] = ["fake_chief"]
+
+ # Dumps the task contexts to the self._task_context dict.
+ distribute_coordinator.run_distribute_coordinator(
+ self._dump_task_context,
+ cluster_spec=cluster_spec,
+ between_graph=True,
+ rpc_layer="grpc")
+
+ # There are one CHIEF and three workers.
+ self.assertEqual(len(self._task_context), 2)
+ self.assertTrue(CHIEF in self._task_context)
+ self.assertTrue(WORKER in self._task_context)
+ self.assertEqual(len(self._task_context[CHIEF]), 1)
+ self.assertEqual(len(self._task_context[WORKER]), NUM_WORKERS)
+
+ # Check whether each task has the right master_target, num_workers, is_chief
+ # and distributed_mode.
+ self.assertEqual(self._task_context[CHIEF][0],
+ ("grpc://fake_chief", 4, True, True))
+ self.assertEqual(self._task_context[WORKER][0],
+ ("grpc://" + _bytes_to_str(self._workers[0].target),
+ NUM_WORKERS + 1, False, True))
+ self.assertEqual(self._task_context[WORKER][1],
+ ("grpc://" + _bytes_to_str(self._workers[1].target),
+ NUM_WORKERS + 1, False, True))
+ self.assertEqual(self._task_context[WORKER][2],
+ ("grpc://" + _bytes_to_str(self._workers[2].target),
+ NUM_WORKERS + 1, False, True))
+
+ def testInGraphContextWithEval(self):
+ # Adds a EVALUATOR job.
+ cluster_spec = copy.deepcopy(self._cluster_spec)
+ cluster_spec[EVALUATOR] = ["fake_evaluator"]
+
+ # Dumps the task contexts to the self._task_context dict.
+ distribute_coordinator.run_distribute_coordinator(
+ self._dump_task_context, cluster_spec=cluster_spec, between_graph=False)
+
+ # There are one "None" task and one EVALUATOR task.
+ self.assertEqual(len(self._task_context), 2)
+ self.assertTrue("None" in self._task_context)
+ self.assertTrue(EVALUATOR in self._task_context)
+ self.assertEqual(len(self._task_context["None"]), 1)
+ self.assertEqual(len(self._task_context[EVALUATOR]), 1)
+
+ # Check whether each task has the right master_target, num_workers, is_chief
+ # and distributed_mode.
+ self.assertEqual(self._task_context["None"][0],
+ (_bytes_to_str(self._workers[0].target), 3, True, True))
+ self.assertEqual(self._task_context[EVALUATOR][0],
+ ("fake_evaluator", 3, False, True))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 5e4f9e29da..99129c2537 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -93,10 +93,11 @@ def capture_value(tensor_map, value, dtype, name):
class CapturingGraph(ops.Graph):
"""Graph used when constructing eager functions."""
- def __init__(self, captures):
+ def __init__(self):
super(CapturingGraph, self).__init__()
self._building_function = True
- self.captures = captures
+ # Maps external tensor id -> internal tensor (e.g. input placeholder).
+ self.captures = {}
# Map from resource tensor name to last op (in program order) which uses
# this tensor. Used to enforce that execution order matches program order
# for resource tensors.
@@ -471,8 +472,7 @@ class GraphModeFunction(object):
def _construct_backprop_function(self):
"""Constructs the backprop function object for this function."""
filtered_outputs = [x for x in self._python_returns if x is not None]
- captures = {}
- backwards_graph = CapturingGraph(captures)
+ backwards_graph = CapturingGraph()
backwards_graph._graph_key = self._graph._graph_key # pylint: disable=protected-access
for collection in self._graph.collections:
backwards_graph.get_collection_ref(
@@ -491,6 +491,7 @@ class GraphModeFunction(object):
grad for grad in _flatten(in_gradients) if grad is not None)
output_shapes = tuple(grad.shape for grad in backward_outputs)
+ captures = backwards_graph.captures
ids = list(sorted(captures.keys()))
if ids:
extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids])
@@ -660,8 +661,7 @@ def _deterministic_dict_values(kwds):
def _trace_and_define_function(name, func, compiled, args, kwds):
"""Defines and returns graph-mode version of func."""
graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
- captures = {}
- tmp_graph = CapturingGraph(captures)
+ tmp_graph = CapturingGraph()
# Inherit the graph key, since this is used for matching variables in
# optimizers.
tmp_graph._graph_key = graph_key # pylint: disable=protected-access
@@ -703,6 +703,7 @@ def _trace_and_define_function(name, func, compiled, args, kwds):
if x is not None
]
+ captures = tmp_graph.captures
ids = list(sorted(captures.keys()))
if ids:
extra_inputs, extra_placeholders = zip(* [captures[x] for x in ids])
diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py
index 2c6f04d8ad..2dc5060984 100644
--- a/tensorflow/python/eager/graph_callable.py
+++ b/tensorflow/python/eager/graph_callable.py
@@ -280,8 +280,7 @@ def _graph_callable_internal(func, shape_and_dtypes):
# This graph will store both the initialization and the call version of the
# wrapped function. It will later be used by the backprop code to build the
# backprop graph, if necessary.
- captures = {}
- tmp_graph = function.CapturingGraph(captures)
+ tmp_graph = function.CapturingGraph()
# Inherit the graph key from the original graph to ensure optimizers don't
# misbehave.
tmp_graph._container = container # pylint: disable=protected-access
@@ -331,6 +330,7 @@ def _graph_callable_internal(func, shape_and_dtypes):
sorted_variables = sorted(variable_captures.variables.values(),
key=lambda x: x.name)
+ captures = tmp_graph.captures
ids = list(sorted(captures.keys()))
if ids:
extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids])
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 4d28e98961..0eabea321c 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -845,11 +845,9 @@ int64_t get_uid() {
PyObject* TFE_Py_UID() { return PyLong_FromLongLong(get_uid()); }
void TFE_DeleteContextCapsule(PyObject* context) {
- TF_Status* status = TF_NewStatus();
TFE_Context* ctx =
reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(context, nullptr));
- TFE_DeleteContext(ctx, status);
- TF_DeleteStatus(status);
+ TFE_DeleteContext(ctx);
}
static tensorflow::int64 MakeInt(PyObject* integer) {
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index fd46163050..817c8e6848 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -171,6 +171,7 @@ py_test(
name = "baseline_test",
size = "medium",
srcs = ["canned/baseline_test.py"],
+ shard_count = 4,
srcs_version = "PY2AND3",
tags = [
"no_pip",
@@ -207,6 +208,7 @@ py_test(
name = "boosted_trees_test",
size = "medium",
srcs = ["canned/boosted_trees_test.py"],
+ shard_count = 2,
srcs_version = "PY2AND3",
tags = [
"optonly",
@@ -676,6 +678,7 @@ py_test(
name = "keras_test",
size = "large",
srcs = ["keras_test.py"],
+ shard_count = 4,
srcs_version = "PY2AND3",
tags = [
"no_windows",
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index 3292e2724d..8b423f76de 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -46,7 +46,7 @@ from tensorflow.python.util.tf_export import estimator_export
# TODO(nponomareva): Reveal pruning params here.
_TreeHParams = collections.namedtuple('TreeHParams', [
'n_trees', 'max_depth', 'learning_rate', 'l1', 'l2', 'tree_complexity',
- 'min_node_weight', 'center_bias'
+ 'min_node_weight', 'center_bias', 'pruning_mode'
])
_HOLD_FOR_MULTI_CLASS_SUPPORT = object()
@@ -410,9 +410,20 @@ class _EnsembleGrower(object):
Args:
tree_ensemble: A TreeEnsemble variable.
tree_hparams: TODO. collections.namedtuple for hyper parameters.
+ Raises:
+ ValueError: when pruning mode is invalid or pruning is used and no tree
+ complexity is set.
"""
self._tree_ensemble = tree_ensemble
self._tree_hparams = tree_hparams
+ # pylint: disable=protected-access
+ self._pruning_mode_parsed = boosted_trees_ops.PruningMode.from_str(
+ tree_hparams.pruning_mode)
+
+ if (self._pruning_mode_parsed != boosted_trees_ops.PruningMode.NO_PRUNING
+ and tree_hparams.tree_complexity <= 0):
+ raise ValueError('For pruning, tree_complexity must be positive.')
+ # pylint: enable=protected-access
@abc.abstractmethod
def center_bias(self, center_bias_var, gradients, hessians):
@@ -500,7 +511,7 @@ class _EnsembleGrower(object):
right_node_contribs=right_node_contribs_list,
learning_rate=self._tree_hparams.learning_rate,
max_depth=self._tree_hparams.max_depth,
- pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING)
+ pruning_mode=self._pruning_mode_parsed)
return grow_op
@@ -675,6 +686,7 @@ def _bt_model_fn(
is_single_machine = (config.num_worker_replicas <= 1)
sorted_feature_columns = sorted(feature_columns, key=lambda tc: tc.name)
center_bias = tree_hparams.center_bias
+
if train_in_memory:
assert n_batches_per_layer == 1, (
'When train_in_memory is enabled, input_fn should return the entire '
@@ -925,7 +937,8 @@ class BoostedTreesClassifier(estimator.Estimator):
tree_complexity=0.,
min_node_weight=0.,
config=None,
- center_bias=False):
+ center_bias=False,
+ pruning_mode='none'):
"""Initializes a `BoostedTreesClassifier` instance.
Example:
@@ -999,7 +1012,11 @@ class BoostedTreesClassifier(estimator.Estimator):
regression problems, the first node will return the mean of the labels.
For binary classification problems, it will return a logit for a prior
probability of label 1.
-
+ pruning_mode: one of 'none', 'pre', 'post' to indicate no pruning, pre-
+ pruning (do not split a node if not enough gain is observed) and post
+ pruning (build the tree up to a max depth and then prune branches with
+ negative gain). For pre and post pruning, you MUST provide
+ tree_complexity >0.
Raises:
ValueError: when wrong arguments are given or unsupported functionalities
@@ -1012,9 +1029,9 @@ class BoostedTreesClassifier(estimator.Estimator):
n_classes, weight_column, label_vocabulary=label_vocabulary)
# HParams for the model.
- tree_hparams = _TreeHParams(n_trees, max_depth, learning_rate,
- l1_regularization, l2_regularization,
- tree_complexity, min_node_weight, center_bias)
+ tree_hparams = _TreeHParams(
+ n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
+ tree_complexity, min_node_weight, center_bias, pruning_mode)
def _model_fn(features, labels, mode, config):
return _bt_model_fn( # pylint: disable=protected-access
@@ -1058,7 +1075,8 @@ class BoostedTreesRegressor(estimator.Estimator):
tree_complexity=0.,
min_node_weight=0.,
config=None,
- center_bias=False):
+ center_bias=False,
+ pruning_mode='none'):
"""Initializes a `BoostedTreesRegressor` instance.
Example:
@@ -1125,6 +1143,11 @@ class BoostedTreesRegressor(estimator.Estimator):
regression problems, the first node will return the mean of the labels.
For binary classification problems, it will return a logit for a prior
probability of label 1.
+ pruning_mode: one of 'none', 'pre', 'post' to indicate no pruning, pre-
+ pruning (do not split a node if not enough gain is observed) and post
+ pruning (build the tree up to a max depth and then prune branches with
+ negative gain). For pre and post pruning, you MUST provide
+ tree_complexity >0.
Raises:
ValueError: when wrong arguments are given or unsupported functionalities
@@ -1136,9 +1159,9 @@ class BoostedTreesRegressor(estimator.Estimator):
head = _create_regression_head(label_dimension, weight_column)
# HParams for the model.
- tree_hparams = _TreeHParams(n_trees, max_depth, learning_rate,
- l1_regularization, l2_regularization,
- tree_complexity, min_node_weight, center_bias)
+ tree_hparams = _TreeHParams(
+ n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
+ tree_complexity, min_node_weight, center_bias, pruning_mode)
def _model_fn(features, labels, mode, config):
return _bt_model_fn( # pylint: disable=protected-access
diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py
index f807641057..ec597e4686 100644
--- a/tensorflow/python/estimator/canned/boosted_trees_test.py
+++ b/tensorflow/python/estimator/canned/boosted_trees_test.py
@@ -1508,7 +1508,8 @@ class ModelFnTests(test_util.TensorFlowTestCase):
l2=0.01,
tree_complexity=0.,
min_node_weight=0.,
- center_bias=center_bias)
+ center_bias=center_bias,
+ pruning_mode='none')
estimator_spec = boosted_trees._bt_model_fn( # pylint:disable=protected-access
features=features,
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index cc5a61b54e..52b19466eb 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -179,45 +179,16 @@ class Estimator(object):
"""
Estimator._assert_members_are_not_overridden(self)
- if config is None:
- self._config = run_config.RunConfig()
- logging.info('Using default config.')
- else:
- if not isinstance(config, run_config.RunConfig):
- raise ValueError(
- 'config must be an instance of RunConfig, but provided %s.' %
- config)
- self._config = config
+ config = maybe_overwrite_model_dir_and_session_config(config, model_dir)
+ self._config = config
# The distribute field contains an instance of DistributionStrategy.
self._distribution = self._config.train_distribute
-
# Model directory.
- model_dir = compat_internal.path_to_str(model_dir)
- if (model_dir is not None) and (self._config.model_dir is not None):
- if model_dir != self._config.model_dir:
- # TODO(alanyee): remove this suppression after it is no longer needed
- # pylint: disable=g-doc-exception
- raise ValueError(
- "model_dir are set both in constructor and RunConfig, but with "
- "different values. In constructor: '{}', in RunConfig: "
- "'{}' ".format(model_dir, self._config.model_dir))
- # pylint: enable=g-doc-exception
-
- self._model_dir = model_dir or self._config.model_dir
- if self._model_dir is None:
- self._model_dir = tempfile.mkdtemp()
- logging.warning('Using temporary folder as model directory: %s',
- self._model_dir)
- if self._config.model_dir is None:
- self._config = self._config.replace(model_dir=self._model_dir)
+ self._model_dir = self._config.model_dir
+ self._session_config = self._config.session_config
logging.info('Using config: %s', str(vars(self._config)))
- if self._config.session_config is None:
- self._session_config = run_config.get_default_session_config()
- else:
- self._session_config = self._config.session_config
-
self._device_fn = (
self._config.device_fn or _get_replica_device_setter(self._config))
@@ -1542,6 +1513,49 @@ class Estimator(object):
warm_starting_util.warm_start(*self._warm_start_settings)
+def maybe_overwrite_model_dir_and_session_config(config, model_dir):
+ """Overwrite estimator config by `model_dir` and `session_config` if needed.
+
+ Args:
+ config: Original estimator config.
+ model_dir: Estimator model checkpoint directory.
+
+ Returns:
+ Overwritten estimator config.
+
+ Raises:
+ ValueError: Model directory inconsistent between `model_dir` and `config`.
+ """
+
+ if config is None:
+ config = run_config.RunConfig()
+ logging.info('Using default config.')
+ if not isinstance(config, run_config.RunConfig):
+ raise ValueError(
+ 'config must be an instance of `RunConfig`, but provided %s.' % config)
+
+ if config.session_config is None:
+ session_config = run_config.get_default_session_config()
+ config = run_config.RunConfig.replace(config, session_config=session_config)
+
+ model_dir = compat_internal.path_to_str(model_dir)
+ if model_dir is not None:
+ if (getattr(config, 'model_dir', None) is not None and
+ config.model_dir != model_dir):
+ raise ValueError(
+ "`model_dir` are set both in constructor and `RunConfig`, but with "
+ "different values. In constructor: '{}', in `RunConfig`: "
+ "'{}' ".format(model_dir, config.model_dir))
+ if model_dir:
+ config = run_config.RunConfig.replace(config, model_dir=model_dir)
+ if getattr(config, 'model_dir', None) is None:
+ model_dir = tempfile.mkdtemp()
+ logging.warning('Using temporary folder as model directory: %s', model_dir)
+ config = run_config.RunConfig.replace(config, model_dir=model_dir)
+
+ return config
+
+
def create_per_tower_ready_op(scaffold):
"""Create a Scaffold.ready_op inside a tower."""
if scaffold.ready_op:
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 8bc410ba0b..68fc5bcadf 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -175,7 +175,7 @@ class EstimatorInheritanceConstraintTest(test.TestCase):
class EstimatorConstructorTest(test.TestCase):
def test_config_must_be_a_run_config(self):
- with self.assertRaisesRegexp(ValueError, 'an instance of RunConfig'):
+ with self.assertRaisesRegexp(ValueError, 'an instance of `RunConfig`'):
estimator.Estimator(model_fn=None, config='NotARunConfig')
def test_model_fn_must_be_provided(self):
@@ -228,6 +228,15 @@ class EstimatorConstructorTest(test.TestCase):
self.assertEqual(_TMP_DIR, est.config.model_dir)
self.assertEqual(_TMP_DIR, est.model_dir)
+ def test_empty_model_dir(self):
+ def model_fn(features, labels):
+ _, _ = features, labels
+
+ with test.mock.patch.object(tempfile, 'mkdtemp', return_value=_TMP_DIR):
+ est = estimator.Estimator(model_fn=model_fn, model_dir='')
+ self.assertEqual(_TMP_DIR, est.config.model_dir)
+ self.assertEqual(_TMP_DIR, est.model_dir)
+
def test_model_dir_in_run_config(self):
class FakeConfig(run_config.RunConfig):
@@ -272,7 +281,7 @@ class EstimatorConstructorTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError,
- 'model_dir are set both in constructor and RunConfig, but '
+ '`model_dir` are set both in constructor and `RunConfig`, but '
'with different values'):
estimator.Estimator(
model_fn=model_fn, config=FakeConfig(), model_dir=_ANOTHER_TMP_DIR)
diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py
index 70517ae278..079560c495 100644
--- a/tensorflow/python/estimator/keras.py
+++ b/tensorflow/python/estimator/keras.py
@@ -21,14 +21,11 @@ from __future__ import print_function
import os
import re
-import tempfile
from tensorflow.python.client import session
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import export as export_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
-from tensorflow.python.estimator import run_config as run_config_lib
-from tensorflow.python.estimator.run_config import RunConfig
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
@@ -473,43 +470,6 @@ def _save_first_checkpoint(keras_model, custom_objects, config):
return latest_path
-def _maybe_overwrite_model_dir_and_session_config(config, model_dir):
- """Overwrite estimator config by `model_dir` and `session_config` if needed.
-
- Args:
- config: Original estimator config.
- model_dir: Estimator model checkpoint directory.
-
- Returns:
- Overwritten estimator config.
-
- Raises:
- ValueError: Model directory inconsistent between `model_dir` and `config`.
- """
-
- default_session_config = run_config_lib.get_default_session_config()
- if isinstance(config, dict):
- config = RunConfig(**config)
- elif config is None:
- config = RunConfig(session_config=default_session_config)
- if config.session_config is None:
- config = RunConfig.replace(config, session_config=default_session_config)
-
- if model_dir is not None:
- if (getattr(config, 'model_dir', None) is not None and
- config.model_dir != model_dir):
- raise ValueError(
- "`model_dir` are set both in constructor and `RunConfig`, but with "
- "different values. In constructor: '{}', in `RunConfig`: "
- "'{}' ".format(model_dir, config.model_dir))
- config = RunConfig.replace(config, model_dir=model_dir)
- elif getattr(config, 'model_dir', None) is None:
- model_dir = tempfile.mkdtemp()
- config = RunConfig.replace(config, model_dir=model_dir)
-
- return config
-
-
def model_to_estimator(keras_model=None,
keras_model_path=None,
custom_objects=None,
@@ -527,9 +487,9 @@ def model_to_estimator(keras_model=None,
format, which can be generated with the `save()` method of a Keras model.
This argument is mutually exclusive with `keras_model`.
custom_objects: Dictionary for custom objects.
- model_dir: Directory to save Estimator model parameters, graph, summary
+ model_dir: Directory to save `Estimator` model parameters, graph, summary
files for TensorBoard, etc.
- config: Configuration object.
+ config: `RunConfig` to config `Estimator`.
Returns:
An Estimator from given keras model.
@@ -566,7 +526,8 @@ def model_to_estimator(keras_model=None,
'Please compile the model with `model.compile()` '
'before calling `model_to_estimator()`.')
- config = _maybe_overwrite_model_dir_and_session_config(config, model_dir)
+ config = estimator_lib.maybe_overwrite_model_dir_and_session_config(config,
+ model_dir)
keras_model_fn = _create_keras_model_fn(keras_model, custom_objects)
if _any_weight_initialized(keras_model):
diff --git a/tensorflow/python/estimator/keras_test.py b/tensorflow/python/estimator/keras_test.py
index cf4ec7f4da..332e385726 100644
--- a/tensorflow/python/estimator/keras_test.py
+++ b/tensorflow/python/estimator/keras_test.py
@@ -275,11 +275,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
with self.test_session():
est_keras = keras_lib.model_to_estimator(
keras_model=keras_model,
- # Also use dict config argument to get test coverage for that line.
- config={
- 'tf_random_seed': _RANDOM_SEED,
- 'model_dir': self._base_dir,
- })
+ config=self._config)
before_eval_results = est_keras.evaluate(
input_fn=eval_input_fn, steps=1)
est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py
index a9fd8f8e1a..9db9ccd01d 100644
--- a/tensorflow/python/estimator/model_fn.py
+++ b/tensorflow/python/estimator/model_fn.py
@@ -380,15 +380,12 @@ def _maybe_add_default_serving_output(export_outputs):
return export_outputs
-class _TPUEstimatorSpec(collections.namedtuple('TPUEstimatorSpec', [
- 'mode',
- 'predictions',
- 'loss',
- 'train_op',
- 'eval_metrics',
- 'export_outputs',
- 'scaffold_fn',
- 'host_call'])):
+class _TPUEstimatorSpec(
+ collections.namedtuple('TPUEstimatorSpec', [
+ 'mode', 'predictions', 'loss', 'train_op', 'eval_metrics',
+ 'export_outputs', 'scaffold_fn', 'host_call', 'training_hooks',
+ 'evaluation_hooks', 'prediction_hooks'
+ ])):
"""Ops and objects returned from a `model_fn` and passed to `TPUEstimator`.
This is a simplified implementation of `tf.contrib.tpu.EstimatorSpec`. See
@@ -404,17 +401,24 @@ class _TPUEstimatorSpec(collections.namedtuple('TPUEstimatorSpec', [
eval_metrics=None,
export_outputs=None,
scaffold_fn=None,
- host_call=None):
+ host_call=None,
+ training_hooks=None,
+ evaluation_hooks=None,
+ prediction_hooks=None):
"""Creates a `_TPUEstimatorSpec` instance."""
- return super(_TPUEstimatorSpec, cls).__new__(cls,
- mode=mode,
- predictions=predictions,
- loss=loss,
- train_op=train_op,
- eval_metrics=eval_metrics,
- export_outputs=export_outputs,
- scaffold_fn=scaffold_fn,
- host_call=host_call)
+ return super(_TPUEstimatorSpec, cls).__new__(
+ cls,
+ mode=mode,
+ predictions=predictions,
+ loss=loss,
+ train_op=train_op,
+ eval_metrics=eval_metrics,
+ export_outputs=export_outputs,
+ scaffold_fn=scaffold_fn,
+ host_call=host_call,
+ training_hooks=training_hooks,
+ evaluation_hooks=evaluation_hooks,
+ prediction_hooks=prediction_hooks)
def as_estimator_spec(self):
"""Creates an equivalent `EstimatorSpec` used by CPU train/eval."""
@@ -423,12 +427,16 @@ class _TPUEstimatorSpec(collections.namedtuple('TPUEstimatorSpec', [
else:
metric_fn, tensors = self.eval_metrics
eval_metric_ops = metric_fn(**tensors)
- return EstimatorSpec(mode=self.mode,
- predictions=self.predictions,
- loss=self.loss,
- train_op=self.train_op,
- eval_metric_ops=eval_metric_ops,
- export_outputs=self.export_outputs)
+ return EstimatorSpec(
+ mode=self.mode,
+ predictions=self.predictions,
+ loss=self.loss,
+ train_op=self.train_op,
+ eval_metric_ops=eval_metric_ops,
+ export_outputs=self.export_outputs,
+ training_hooks=self.training_hooks,
+ evaluation_hooks=self.evaluation_hooks,
+ prediction_hooks=self.prediction_hooks)
def _check_is_tensor_or_operation(x, name):
diff --git a/tensorflow/python/framework/error_interpolation.py b/tensorflow/python/framework/error_interpolation.py
index a79073b748..7719d03019 100644
--- a/tensorflow/python/framework/error_interpolation.py
+++ b/tensorflow/python/framework/error_interpolation.py
@@ -87,6 +87,53 @@ def _parse_message(message):
return seps, tags
+def _compute_device_summary_from_list(device_assignment_list, prefix=""):
+ """Return a summary of an op's device function stack.
+
+ Args:
+ device_assignment_list: The op._device_assignments list.
+ prefix: An optional string prefix used before each line of the multi-
+ line string returned by this function.
+
+ Returns:
+ A multi-line string similar to:
+ Device assignments active during op creation:
+ with tf.device(/cpu:0): <test_1.py:27>
+ with tf.device(some_func<foo.py, 123>): <test_2.py:38>
+ The first line will have no padding to its left by default. Subsequent
+ lines will have two spaces of left-padding. Use the prefix argument
+ to increase indentation.
+ """
+ if not device_assignment_list:
+ message = "No device assignments were active during op creation."
+ return prefix + message
+
+ str_list = []
+ str_list.append("%sDevice assignments active during op creation:" % prefix)
+
+ for traceable_obj in device_assignment_list:
+ location_summary = "<{file}:{line}>".format(file=traceable_obj.filename,
+ line=traceable_obj.lineno)
+ subs = {
+ "prefix": prefix,
+ "indent": " ",
+ "dev_name": traceable_obj.obj,
+ "loc": location_summary,
+ }
+ str_list.append(
+ "{prefix}{indent}with tf.device({dev_name}): {loc}".format(**subs))
+
+ return "\n".join(str_list)
+
+
+def _compute_device_assignment_summary_from_op(op, prefix=""):
+ if not op:
+ return ""
+ # pylint: disable=protected-access
+ return _compute_device_summary_from_list(op._device_assignments, prefix)
+ # pylint: enable=protected-access
+
+
def _compute_colocation_summary_from_dict(colocation_dict, prefix=""):
"""Return a summary of an op's colocation stack.
@@ -203,6 +250,7 @@ def _compute_field_dict(op):
"file": default_value,
"line": default_value,
"colocations": default_value,
+ "devices": default_value,
}
frame = _get_defining_frame_from_op(op)
if frame:
@@ -211,6 +259,9 @@ def _compute_field_dict(op):
colocation_summary = _compute_colocation_summary_from_op(op)
if colocation_summary:
field_dict["colocations"] = colocation_summary
+ device_summary = _compute_device_assignment_summary_from_op(op)
+ if device_summary:
+ field_dict["devices"] = device_summary
return field_dict
@@ -233,6 +284,8 @@ def interpolate(error_message, graph):
node_name_to_substitution_dict = {}
for name in [t.name for t in tags]:
+ if name in node_name_to_substitution_dict:
+ continue
try:
op = graph.get_operation_by_name(name)
except KeyError:
diff --git a/tensorflow/python/framework/error_interpolation_test.py b/tensorflow/python/framework/error_interpolation_test.py
index 1e5cb73854..fbf182879b 100644
--- a/tensorflow/python/framework/error_interpolation_test.py
+++ b/tensorflow/python/framework/error_interpolation_test.py
@@ -57,13 +57,32 @@ def _modify_op_stack_with_filenames(op, num_user_frames, user_filename,
op._traceback = stack
-def assert_node_in_colocation_summary(test_obj, colocation_summary_string,
- name, filename="", lineno=""):
- lineno = str(lineno)
- name_phrase = "colocate_with(%s)" % name
- for term in [name_phrase, filename, lineno]:
- test_obj.assertIn(term, colocation_summary_string)
- test_obj.assertNotIn("loc:@", colocation_summary_string)
+class ComputeDeviceSummaryFromOpTest(test.TestCase):
+
+ def testCorrectFormatWithActiveDeviceAssignments(self):
+ assignments = []
+ assignments.append(
+ traceable_stack.TraceableObject("/cpu:0",
+ filename="hope.py",
+ lineno=24))
+ assignments.append(
+ traceable_stack.TraceableObject("/gpu:2",
+ filename="please.py",
+ lineno=42))
+
+ summary = error_interpolation._compute_device_summary_from_list(
+ assignments, prefix=" ")
+
+ self.assertIn("tf.device(/cpu:0)", summary)
+ self.assertIn("<hope.py:24>", summary)
+ self.assertIn("tf.device(/gpu:2)", summary)
+ self.assertIn("<please.py:42>", summary)
+
+ def testCorrectFormatWhenNoColocationsWereActive(self):
+ device_assignment_list = []
+ summary = error_interpolation._compute_device_summary_from_list(
+ device_assignment_list, prefix=" ")
+ self.assertIn("No device assignments", summary)
class ComputeColocationSummaryFromOpTest(test.TestCase):
@@ -81,15 +100,10 @@ class ComputeColocationSummaryFromOpTest(test.TestCase):
}
summary = error_interpolation._compute_colocation_summary_from_dict(
colocation_dict, prefix=" ")
- assert_node_in_colocation_summary(self,
- summary,
- name="test_node_1",
- filename="test_1.py",
- lineno=27)
- assert_node_in_colocation_summary(self, summary,
- name="test_node_2",
- filename="test_2.py",
- lineno=38)
+ self.assertIn("colocate_with(test_node_1)", summary)
+ self.assertIn("<test_1.py:27>", summary)
+ self.assertIn("colocate_with(test_node_2)", summary)
+ self.assertIn("<test_2.py:38>", summary)
def testCorrectFormatWhenNoColocationsWereActive(self):
colocation_dict = {}
@@ -98,9 +112,10 @@ class ComputeColocationSummaryFromOpTest(test.TestCase):
self.assertIn("No node-device colocations", summary)
-class InterpolateTest(test.TestCase):
+class InterpolateFilenamesAndLineNumbersTest(test.TestCase):
def setUp(self):
+ ops.reset_default_graph()
# Add nodes to the graph for retrieval by name later.
constant_op.constant(1, name="One")
constant_op.constant(2, name="Two")
@@ -177,9 +192,57 @@ class InterpolateTest(test.TestCase):
self.assertRegexpMatches(interpolated_string, expected_regex)
+class InterpolateDeviceSummaryTest(test.TestCase):
+
+ def _fancy_device_function(self, unused_op):
+ return "/cpu:*"
+
+ def setUp(self):
+ ops.reset_default_graph()
+ self.zero = constant_op.constant([0.0], name="zero")
+ with ops.device("/cpu"):
+ self.one = constant_op.constant([1.0], name="one")
+ with ops.device("/cpu:0"):
+ self.two = constant_op.constant([2.0], name="two")
+ with ops.device(self._fancy_device_function):
+ self.three = constant_op.constant(3.0, name="three")
+
+ self.graph = self.three.graph
+
+ def testNodeZeroHasNoDeviceSummaryInfo(self):
+ message = "^^node:zero:${devices}^^"
+ result = error_interpolation.interpolate(message, self.graph)
+ self.assertIn("No device assignments were active", result)
+
+ def testNodeOneHasExactlyOneInterpolatedDevice(self):
+ message = "^^node:one:${devices}^^"
+ result = error_interpolation.interpolate(message, self.graph)
+ num_devices = result.count("tf.device")
+ self.assertEqual(1, num_devices)
+ self.assertIn("tf.device(/cpu)", result)
+
+ def testNodeTwoHasTwoInterpolatedDevice(self):
+ message = "^^node:two:${devices}^^"
+ result = error_interpolation.interpolate(message, self.graph)
+ num_devices = result.count("tf.device")
+ self.assertEqual(2, num_devices)
+ self.assertIn("tf.device(/cpu)", result)
+ self.assertIn("tf.device(/cpu:0)", result)
+
+ def testNodeThreeHasFancyFunctionDisplayNameForInterpolatedDevice(self):
+ message = "^^node:three:${devices}^^"
+ result = error_interpolation.interpolate(message, self.graph)
+ num_devices = result.count("tf.device")
+ self.assertEqual(1, num_devices)
+ name_re = r"_fancy_device_function<.*error_interpolation_test.py, [0-9]+>"
+ expected_re = r"with tf.device\(.*%s\)" % name_re
+ self.assertRegexpMatches(result, expected_re)
+
+
class InterpolateColocationSummaryTest(test.TestCase):
def setUp(self):
+ ops.reset_default_graph()
# Add nodes to the graph for retrieval by name later.
node_one = constant_op.constant(1, name="One")
node_two = constant_op.constant(2, name="Two")
@@ -203,12 +266,12 @@ class InterpolateColocationSummaryTest(test.TestCase):
def testNodeThreeHasColocationInterpolation(self):
message = "^^node:Three_with_one:${colocations}^^"
result = error_interpolation.interpolate(message, self.graph)
- assert_node_in_colocation_summary(self, result, name="One")
+ self.assertIn("colocate_with(One)", result)
def testNodeFourHasColocationInterpolationForNodeThreeOnly(self):
message = "^^node:Four_with_three:${colocations}^^"
result = error_interpolation.interpolate(message, self.graph)
- assert_node_in_colocation_summary(self, result, name="Three_with_one")
+ self.assertIn("colocate_with(Three_with_one)", result)
self.assertNotIn(
"One", result,
"Node One should not appear in Four_with_three's summary:\n%s"
@@ -217,8 +280,8 @@ class InterpolateColocationSummaryTest(test.TestCase):
def testNodeFiveHasColocationInterpolationForNodeOneAndTwo(self):
message = "^^node:Five_with_one_with_two:${colocations}^^"
result = error_interpolation.interpolate(message, self.graph)
- assert_node_in_colocation_summary(self, result, name="One")
- assert_node_in_colocation_summary(self, result, name="Two")
+ self.assertIn("colocate_with(One)", result)
+ self.assertIn("colocate_with(Two)", result)
def testColocationInterpolationForNodeLackingColocation(self):
message = "^^node:One:${colocations}^^"
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 6525607fae..c76743d2c6 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -38,8 +38,8 @@ from tensorflow.python.ops import cond_v2_impl
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import compat
+from tensorflow.python.util import function_utils
from tensorflow.python.util import tf_contextlib
-from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
# This is to avoid a circular dependency with cond_v2_impl.
@@ -255,9 +255,12 @@ class _DefinedFunction(object):
# Constructed only when C API is enabled, lazily
self._c_func = None
self._sub_functions = dict() # Constructed with _definition or _c_func
- device_stack = ops.get_default_graph()._device_function_stack # pylint: disable=protected-access
+ # pylint: disable=protected-access
+ device_funcs = ops.get_default_graph()._device_functions_outer_to_inner
+ # pylint: enable=protected-access
+
# Get the innermost device if possbile.
- self._caller_device = device_stack[-1] if device_stack else None
+ self._caller_device = device_funcs[-1] if device_funcs else None
# Cached OpDef for this function. When C API is enabled, this is
# the only part of FunctionDef that we cache in Python. When C API
@@ -354,7 +357,7 @@ class _DefinedFunction(object):
if self._func_name:
base_func_name = self._func_name
else:
- base_func_name = _get_func_name(self._func)
+ base_func_name = function_utils.get_func_name(self._func)
if self._grad_func:
base_func_name += ("_%s" % self._grad_func.name)
kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs)
@@ -841,7 +844,7 @@ def func_graph_from_py_func(func, arg_names, arg_types, name=None,
ValueError: if func returns None.
"""
if not name:
- name = _get_func_name(func)
+ name = function_utils.get_func_name(func)
func_graph = _FuncGraph(name, capture_by_value)
with func_graph.as_default(), ops.device(device):
@@ -1139,19 +1142,6 @@ def _parse_kwargs_as_attrs(func_name, **kwargs):
return attrs
-def _get_func_name(func):
- _, func = tf_decorator.unwrap(func)
- if callable(func):
- if tf_inspect.isfunction(func):
- return func.__name__
- elif tf_inspect.ismethod(func):
- return "%s.%s" % (func.__self__.__name__, func.__name__)
- else: # Probably a class instance with __call__
- return type(func)
- else:
- raise ValueError("Argument must be callable")
-
-
def get_extra_vars():
"""Returns the captured variables by the function.
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 0fd028ebf0..c25e29b0f4 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -50,14 +50,15 @@ from tensorflow.python.framework import registry
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import traceable_stack
from tensorflow.python.framework import versions
-from tensorflow.python.util import tf_stack
from tensorflow.python.ops import control_flow_util
from tensorflow.python.platform import app
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import decorator_utils
+from tensorflow.python.util import function_utils
from tensorflow.python.util import lock_util
from tensorflow.python.util import tf_contextlib
+from tensorflow.python.util import tf_stack
from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.tf_export import tf_export
@@ -73,6 +74,31 @@ def tensor_id(tensor):
return tensor._id # pylint: disable=protected-access
+class _UserDeviceSpec(object):
+ """Store user-specified device and provide computation of merged device."""
+
+ def __init__(self, device_name_or_function):
+ self._device_name_or_function = device_name_or_function
+
+ self.display_name = str(self._device_name_or_function)
+ if callable(self._device_name_or_function):
+ dev_func = self._device_name_or_function
+ func_name = function_utils.get_func_name(dev_func)
+ func_code = function_utils.get_func_code(dev_func)
+ if func_code:
+ fname = func_code.co_filename
+ lineno = func_code.co_firstlineno
+ else:
+ fname = "unknown"
+ lineno = -1
+ self.display_name = "%s<%s, %d>" % (func_name, fname, lineno)
+
+ self.function = self._device_name_or_function
+ if not (self._device_name_or_function is None or
+ callable(self._device_name_or_function)):
+ self.function = pydev.merge_device(self._device_name_or_function)
+
+
class _NullContextmanager(object):
def __enter__(self):
@@ -1719,7 +1745,12 @@ class Operation(object):
self._id_value = self._graph._next_id()
self._original_op = original_op
self._traceback = tf_stack.extract_stack()
- # List of traceable_stack.TraceableObjects for colocation context managers.
+
+ # List of _UserDevSpecs holding code location of device context manager
+ # invocations and the users original argument to them.
+ self._device_code_locations = None
+ # Dict mapping op name to file and line information for op colocation
+ # context managers.
self._colocation_code_locations = None
self._control_flow_context = self.graph._get_control_flow_context()
# pylint: enable=protected-access
@@ -1861,6 +1892,37 @@ class Operation(object):
return c_api.TF_OperationDevice(self._c_op)
@property
+ def _device_assignments(self):
+ """Code locations for device context managers active at op creation.
+
+ This property will return a list of traceable_stack.TraceableObject
+ instances where .obj is a string representing the assigned device
+ (or information about the function that would be applied to this op
+ to compute the desired device) and the filename and lineno members
+ record the location of the relevant device context manager.
+
+ For example, suppose file_a contained these lines:
+
+ file_a.py:
+ 15: with tf.device('/gpu:0'):
+ 16: node_b = tf.constant(4, name='NODE_B')
+
+ Then a TraceableObject t_obj representing the device context manager
+ would have these member values:
+
+ t_obj.obj -> '/gpu:0'
+ t_obj.filename = 'file_a.py'
+ t_obj.lineno = 15
+
+ and node_b.op._device_assignments would return the list [t_obj].
+
+ Returns:
+ [str: traceable_stack.TraceableObject, ...] as per this method's
+ description, above.
+ """
+ return self._device_code_locations or []
+
+ @property
def _colocation_dict(self):
"""Code locations for colocation context managers active at op creation.
@@ -1881,11 +1943,10 @@ class Operation(object):
would have these member values:
t_obj.obj -> None
- t_obj.name = 'NODE_A'
t_obj.filename = 'file_a.py'
t_obj.lineno = 15
- and node_b.op._colocation_code_locations would return the dictionary
+ and node_b.op._colocation_dict would return the dictionary
{ 'NODE_A': t_obj }
@@ -2735,7 +2796,7 @@ class Graph(object):
# Functions that will be applied to choose a device if none is specified.
# After switch_to_thread_local(), self._thread_local._device_function_stack
# is used instead.
- self._graph_device_function_stack = []
+ self._graph_device_function_stack = traceable_stack.TraceableStack()
# Default original_op applied to new ops.
self._default_original_op = None
# Current control flow context. It could be either CondContext or
@@ -3779,8 +3840,8 @@ class Graph(object):
Nothing.
"""
old_original_op = self._default_original_op
+ self._default_original_op = op
try:
- self._default_original_op = op
yield
finally:
self._default_original_op = old_original_op
@@ -3897,15 +3958,15 @@ class Graph(object):
# op name regex, which constrains the initial character.
if not _VALID_OP_NAME_REGEX.match(name):
raise ValueError("'%s' is not a valid scope name" % name)
+ old_stack = self._name_stack
+ if not name: # Both for name=None and name="" we re-set to empty scope.
+ new_stack = None
+ elif name[-1] == "/":
+ new_stack = _name_from_scope_name(name)
+ else:
+ new_stack = self.unique_name(name)
+ self._name_stack = new_stack
try:
- old_stack = self._name_stack
- if not name: # Both for name=None and name="" we re-set to empty scope.
- new_stack = None
- elif name[-1] == "/":
- new_stack = _name_from_scope_name(name)
- else:
- new_stack = self.unique_name(name)
- self._name_stack = new_stack
yield "" if new_stack is None else new_stack + "/"
finally:
self._name_stack = old_stack
@@ -3986,8 +4047,8 @@ class Graph(object):
ignore_existing=False):
with self.colocate_with(op, ignore_existing):
if gradient_uid is not None and self._control_flow_context is not None:
+ self._control_flow_context.EnterGradientColocation(op, gradient_uid)
try:
- self._control_flow_context.EnterGradientColocation(op, gradient_uid)
yield
finally:
self._control_flow_context.ExitGradientColocation(op, gradient_uid)
@@ -4029,7 +4090,6 @@ class Graph(object):
Yields:
A context manager that specifies the op with which to colocate
newly created ops.
-
"""
if op is None and not ignore_existing:
raise ValueError("Trying to reset colocation (op is None) but "
@@ -4047,7 +4107,7 @@ class Graph(object):
# In the future, a caller may specify that device_functions win
# over colocation, in which case we can add support.
device_fn_tmp = self._device_function_stack
- self._device_function_stack = []
+ self._device_function_stack = traceable_stack.TraceableStack()
if ignore_existing:
current_stack = self._colocation_stack
@@ -4071,6 +4131,13 @@ class Graph(object):
if ignore_existing:
self._colocation_stack = current_stack
+ def _add_device_to_stack(self, device_name_or_function, offset=0):
+ """Add device to stack manually, separate from a context manager."""
+ total_offset = 1 + offset
+ spec = _UserDeviceSpec(device_name_or_function)
+ self._device_function_stack.push_obj(spec, offset=total_offset)
+ return spec
+
@tf_contextlib.contextmanager
def device(self, device_name_or_function):
# pylint: disable=line-too-long
@@ -4128,31 +4195,26 @@ class Graph(object):
Yields:
A context manager that specifies the default device to use for newly
created ops.
-
"""
- # pylint: enable=line-too-long
- if (device_name_or_function is not None and
- not callable(device_name_or_function)):
- device_function = pydev.merge_device(device_name_or_function)
- else:
- device_function = device_name_or_function
-
+ self._add_device_to_stack(device_name_or_function, offset=2)
try:
- self._device_function_stack.append(device_function)
yield
finally:
- self._device_function_stack.pop()
+ self._device_function_stack.pop_obj()
def _apply_device_functions(self, op):
"""Applies the current device function stack to the given operation."""
- # Apply any device functions in reverse order, so that the most recently
+ # Apply any device functions in LIFO order, so that the most recently
# pushed function has the first chance to apply a device to the op.
# We apply here because the result can depend on the Operation's
# signature, which is computed in the Operation constructor.
- for device_function in reversed(self._device_function_stack):
- if device_function is None:
+ # pylint: disable=protected-access
+ for device_spec in self._device_function_stack.peek_objs():
+ if device_spec.function is None:
break
- op._set_device(device_function(op)) # pylint: disable=protected-access
+ op._set_device(device_spec.function(op))
+ op._device_code_locations = self._snapshot_device_function_stack_metadata()
+ # pylint: enable=protected-access
# pylint: disable=g-doc-return-or-yield
@tf_contextlib.contextmanager
@@ -4201,8 +4263,8 @@ class Graph(object):
yields the container name.
"""
original_container = self._container
+ self._container = container_name
try:
- self._container = container_name
yield self._container
finally:
self._container = original_container
@@ -4676,17 +4738,45 @@ class Graph(object):
if self._stack_state_is_thread_local:
# This may be called from a thread where device_function_stack doesn't yet
# exist.
+ # pylint: disable=protected-access
if not hasattr(self._thread_local, "_device_function_stack"):
- self._thread_local._device_function_stack = (
- self._graph_device_function_stack[:])
+ stack_copy_for_this_thread = self._graph_device_function_stack.copy()
+ self._thread_local._device_function_stack = stack_copy_for_this_thread
return self._thread_local._device_function_stack
+ # pylint: enable=protected-access
else:
return self._graph_device_function_stack
+ @property
+ def _device_functions_outer_to_inner(self):
+ user_device_specs = self._device_function_stack.peek_objs()
+ device_functions = [spec.function for spec in user_device_specs]
+ device_functions_outer_to_inner = list(reversed(device_functions))
+ return device_functions_outer_to_inner
+
+ def _snapshot_device_function_stack_metadata(self):
+ """Return device function stack as a list of TraceableObjects.
+
+ Returns:
+ [traceable_stack.TraceableObject, ...] where each TraceableObject's .obj
+ member is a displayable name for the user's argument to Graph.device, and
+ the filename and lineno members point to the code location where
+ Graph.device was called directly or indirectly by the user.
+ """
+ traceable_objects = self._device_function_stack.peek_traceable_objs()
+ snapshot = []
+ for obj in traceable_objects:
+ obj_copy = obj.copy_metadata()
+ obj_copy.obj = obj.obj.display_name
+ snapshot.append(obj_copy)
+ return snapshot
+
@_device_function_stack.setter
def _device_function_stack(self, device_function_stack):
if self._stack_state_is_thread_local:
+ # pylint: disable=protected-access
self._thread_local._device_function_stack = device_function_stack
+ # pylint: enable=protected-access
else:
self._graph_device_function_stack = device_function_stack
@@ -4696,12 +4786,12 @@ class Graph(object):
if self._stack_state_is_thread_local:
# This may be called from a thread where colocation_stack doesn't yet
# exist.
+ # pylint: disable=protected-access
if not hasattr(self._thread_local, "_colocation_stack"):
stack_copy_for_this_thread = self._graph_colocation_stack.copy()
- # pylint: disable=protected-access
self._thread_local._colocation_stack = stack_copy_for_this_thread
- # pylint: enable=protected-access
return self._thread_local._colocation_stack
+ # pylint: enable=protected-access
else:
return self._graph_colocation_stack
@@ -4713,7 +4803,9 @@ class Graph(object):
@_colocation_stack.setter
def _colocation_stack(self, colocation_stack):
if self._stack_state_is_thread_local:
+ # pylint: disable=protected-access
self._thread_local._colocation_stack = colocation_stack
+ # pylint: enable=protected-access
else:
self._graph_colocation_stack = colocation_stack
@@ -4882,8 +4974,8 @@ class _DefaultStack(threading.local):
@tf_contextlib.contextmanager
def get_controller(self, default):
"""A context manager for manipulating a default stack."""
+ self.stack.append(default)
try:
- self.stack.append(default)
yield default
finally:
# stack may be empty if reset() was called
@@ -5071,13 +5163,15 @@ class _DefaultGraphStack(_DefaultStack): # pylint: disable=protected-access
@tf_contextlib.contextmanager
def get_controller(self, default):
+ context.context().context_switches.push(
+ default.building_function, default.as_default)
try:
- context.context().context_switches.push(
- default.building_function, default.as_default)
with super(_DefaultGraphStack, self).get_controller(
default) as g, context.graph_mode():
yield g
finally:
+ # If an exception is raised here it may be hiding a related exception in
+ # the try-block (just above).
context.context().context_switches.pop()
@@ -5113,6 +5207,9 @@ def init_scope():
`init_scope` will simply install a fresh graph as the default one.
(3) The gradient tape is paused while the scope is active.
+
+ Raises:
+ RuntimeError: if graph state is incompatible with this initialization.
"""
# pylint: enable=g-doc-return-or-yield,line-too-long
@@ -5125,10 +5222,10 @@ def init_scope():
# the name scope of the current context.
default_graph = get_default_graph()
scope = default_graph.get_name_scope()
- if scope and scope[-1] != '/':
+ if scope and scope[-1] != "/":
# Names that end with trailing slashes are treated by `name_scope` as
# absolute.
- scope = scope + '/'
+ scope = scope + "/"
inner_device_stack = default_graph._device_function_stack # pylint: disable=protected-access
outer_context = None
@@ -5173,6 +5270,8 @@ def init_scope():
outer_graph._device_function_stack = inner_device_stack # pylint: disable=protected-access
yield
finally:
+ # If an exception is raised here it may be hiding a related exception in
+ # try-block (just above).
if outer_graph is not None:
outer_graph._device_function_stack = outer_device_stack # pylint: disable=protected-access
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index f848b69782..48328a7f58 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import gc
+import os
import threading
import weakref
@@ -2542,6 +2543,56 @@ class StatisticsTest(test_util.TensorFlowTestCase):
self.assertEqual(3, flops_total.value)
+class DeviceStackTest(test_util.TensorFlowTestCase):
+
+ def testBasicDeviceAssignmentMetadata(self):
+
+ def device_func(unused_op):
+ return "/cpu:*"
+
+ const_zero = constant_op.constant([0.0], name="zero")
+ with ops.device("/cpu"):
+ const_one = constant_op.constant([1.0], name="one")
+ with ops.device("/cpu:0"):
+ const_two = constant_op.constant([2.0], name="two")
+ with ops.device(device_func):
+ const_three = constant_op.constant(3.0, name="three")
+
+ self.assertEqual(0, len(const_zero.op._device_assignments))
+
+ one_list = const_one.op._device_assignments
+ self.assertEqual(1, len(one_list))
+ self.assertEqual("/cpu", one_list[0].obj)
+ self.assertEqual("ops_test.py", os.path.basename(one_list[0].filename))
+
+ two_list = const_two.op._device_assignments
+ self.assertEqual(2, len(two_list))
+ devices = [t.obj for t in two_list]
+ self.assertEqual(set(["/cpu", "/cpu:0"]), set(devices))
+
+ three_list = const_three.op._device_assignments
+ self.assertEqual(1, len(three_list))
+ func_description = three_list[0].obj
+ expected_regex = r"device_func<.*ops_test.py, [0-9]+"
+ self.assertRegexpMatches(func_description, expected_regex)
+
+ def testDeviceAssignmentMetadataForGraphDeviceAndTfDeviceFunctions(self):
+
+ with ops.device("/cpu"):
+ const_one = constant_op.constant([1.0], name="one")
+ with ops.get_default_graph().device("/cpu"):
+ const_two = constant_op.constant([2.0], name="two")
+
+ one_metadata = const_one.op._device_assignments[0]
+ two_metadata = const_two.op._device_assignments[0]
+
+ # Verify both types of device assignment return the right stack info.
+ self.assertRegexpMatches("ops_test.py",
+ os.path.basename(one_metadata.filename))
+ self.assertEqual(one_metadata.filename, two_metadata.filename)
+ self.assertEqual(one_metadata.lineno + 2, two_metadata.lineno)
+
+
class ColocationGroupTest(test_util.TensorFlowTestCase):
def testBasic(self):
@@ -2554,13 +2605,17 @@ class ColocationGroupTest(test_util.TensorFlowTestCase):
with self.assertRaises(ValueError):
c.op.get_attr("_class")
- # Roughly test that stack information is being saved correctly for the op.
- locations_dict = b.op._colocation_dict
- self.assertIn("a", locations_dict)
- metadata = locations_dict["a"]
+ def testBasicColocationMetadata(self):
+ const_two = constant_op.constant([2.0], name="two")
+ with ops.colocate_with(const_two.op):
+ const_three = constant_op.constant(3.0, name="three")
+ locations_dict = const_three.op._colocation_dict
+ self.assertIn("two", locations_dict)
+ metadata = locations_dict["two"]
self.assertIsNone(metadata.obj)
- basename = metadata.filename.split("/")[-1]
- self.assertEqual("ops_test.py", basename)
+ # Check that this test's filename is recorded as the file containing the
+ # colocation statement.
+ self.assertEqual("ops_test.py", os.path.basename(metadata.filename))
def testColocationDeviceInteraction(self):
with ops.device("/cpu:0"):
diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py
index 122c14c847..f983cbef04 100644
--- a/tensorflow/python/framework/test_util_test.py
+++ b/tensorflow/python/framework/test_util_test.py
@@ -73,7 +73,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
test_util.assert_equal_graph_def(def_57, def_75)
# Compare two unequal graphs
with self.assertRaisesRegexp(AssertionError,
- r"^Found unexpected node 'seven"):
+ r"^Found unexpected node '{{node seven}}"):
test_util.assert_equal_graph_def(def_57, def_empty)
def testIsGoogleCudaEnabled(self):
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index d1b9dc27bd..070d41147d 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -31,12 +31,14 @@ import time
import numpy as np
import six
+from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine.training_utils import standardize_input_data
from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary as tf_summary
@@ -716,6 +718,15 @@ class TensorBoard(Callback):
`embeddings_layer_names`. Numpy array (if the model has a single
input) or list of Numpy arrays (if the model has multiple inputs).
Learn [more about embeddings](https://www.tensorflow.org/programmers_guide/embedding)
+
+ Raises:
+ ValueError: If histogram_freq is set and no validation data is provided.
+
+ @compatbility(eager)
+ Using `Tensorboard` callback will work while eager execution is enabled,
+ however outputting histogram summaries of weights and gradients is not
+ supported, and thus `histogram_freq` will be ignored.
+ @end_compatibility
"""
# pylint: enable=line-too-long
@@ -734,6 +745,11 @@ class TensorBoard(Callback):
super(TensorBoard, self).__init__()
self.log_dir = log_dir
self.histogram_freq = histogram_freq
+ if self.histogram_freq and context.executing_eagerly():
+ logging.warning(
+ UserWarning('Weight and gradient histograms not supported for eager'
+ 'execution, setting `histogram_freq` to `0`.'))
+ self.histogram_freq = 0
self.merged = None
self.write_graph = write_graph
self.write_grads = write_grads
@@ -741,18 +757,22 @@ class TensorBoard(Callback):
self.batch_size = batch_size
self._current_batch = 0
self._total_batches_seen = 0
- # abstracted writer class to be able to stub for testing
- self._writer_class = tf_summary.FileWriter
self.embeddings_freq = embeddings_freq
self.embeddings_layer_names = embeddings_layer_names
self.embeddings_metadata = embeddings_metadata
self.embeddings_data = embeddings_data
- def set_model(self, model):
- """Sets Keras model and creates summary ops."""
+ def _init_writer(self):
+ """Sets file writer."""
+ if context.executing_eagerly():
+ self.writer = summary_ops_v2.create_file_writer(self.log_dir)
+ elif self.write_graph:
+ self.writer = tf_summary.FileWriter(self.log_dir, K.get_session().graph)
+ else:
+ self.writer = tf_summary.FileWriter(self.log_dir)
- self.model = model
- self.sess = K.get_session()
+ def _make_histogram_ops(self, model):
+ """Defines histogram ops when histogram_freq > 0."""
# only make histogram summary op if it hasn't already been made
if self.histogram_freq and self.merged is None:
for layer in self.model.layers:
@@ -793,8 +813,10 @@ class TensorBoard(Callback):
def is_indexed_slices(grad):
return type(grad).__name__ == 'IndexedSlices'
- grads = [grad.values if is_indexed_slices(grad) else grad
- for grad in grads]
+ grads = [
+ grad.values if is_indexed_slices(grad) else grad
+ for grad in grads
+ ]
tf_summary.histogram('{}_grad'.format(mapped_weight_name), grads)
if hasattr(layer, 'output'):
@@ -803,12 +825,16 @@ class TensorBoard(Callback):
tf_summary.histogram('{}_out_{}'.format(layer.name, i), output)
else:
tf_summary.histogram('{}_out'.format(layer.name), layer.output)
- self.merged = tf_summary.merge_all()
- if self.write_graph:
- self.writer = self._writer_class(self.log_dir, self.sess.graph)
- else:
- self.writer = self._writer_class(self.log_dir)
+ def set_model(self, model):
+ """Sets Keras model and creates summary ops."""
+
+ self.model = model
+ self._init_writer()
+ # histogram summaries only enabled in graph mode
+ if not context.executing_eagerly():
+ self._make_histogram_ops(model)
+ self.merged = tf_summary.merge_all()
# If both embedding_freq and embeddings_data are available, we will
# visualize embeddings.
@@ -894,17 +920,24 @@ class TensorBoard(Callback):
"""
logs = logs or {}
- for name, value in logs.items():
- summary = tf_summary.Summary()
- summary_value = summary.value.add()
- summary_value.simple_value = value.item()
- summary_value.tag = name
- self.writer.add_summary(summary, step)
+ if context.executing_eagerly():
+ # use v2 summary ops
+ with self.writer.as_default(), summary_ops_v2.always_record_summaries():
+ for name, value in logs.items():
+ summary_ops_v2.scalar(name, value.item(), step=step)
+ else:
+ # use FileWriter from v1 summary
+ for name, value in logs.items():
+ summary = tf_summary.Summary()
+ summary_value = summary.value.add()
+ summary_value.simple_value = value.item()
+ summary_value.tag = name
+ self.writer.add_summary(summary, step)
self.writer.flush()
def on_train_begin(self, logs=None):
"""Checks if histogram summaries can be run."""
-
+ # will never be set when in eager
if self.histogram_freq:
if 'validation_steps' in self.params:
self._validation_batches = self.params['validation_steps']
diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py
index 7d830078ce..d38a753263 100644
--- a/tensorflow/python/keras/callbacks_test.py
+++ b/tensorflow/python/keras/callbacks_test.py
@@ -22,6 +22,7 @@ import csv
import os
import re
import shutil
+import tempfile
import threading
import unittest
@@ -29,10 +30,12 @@ import numpy as np
from tensorflow.core.framework import summary_pb2
from tensorflow.python import keras
+from tensorflow.python.framework import test_util
from tensorflow.python.keras import testing_utils
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.training import adam
try:
import h5py # pylint:disable=g-import-not-at-top
@@ -63,7 +66,7 @@ class KerasCallbacksTest(test.TestCase):
np.random.seed(1337)
temp_dir = self.get_temp_dir()
- self.addCleanup(shutil.rmtree, temp_dir)
+ self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
filepath = os.path.join(temp_dir, 'checkpoint.h5')
(x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
@@ -479,7 +482,7 @@ class KerasCallbacksTest(test.TestCase):
with self.test_session():
np.random.seed(1337)
temp_dir = self.get_temp_dir()
- self.addCleanup(shutil.rmtree, temp_dir)
+ self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
filepath = os.path.join(temp_dir, 'log.tsv')
sep = '\t'
@@ -557,7 +560,7 @@ class KerasCallbacksTest(test.TestCase):
# does not result in invalid CSVs.
np.random.seed(1337)
tmpdir = self.get_temp_dir()
- self.addCleanup(shutil.rmtree, tmpdir)
+ self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)
with self.test_session():
fp = os.path.join(tmpdir, 'test.csv')
@@ -649,7 +652,7 @@ class KerasCallbacksTest(test.TestCase):
np.random.seed(1337)
temp_dir = self.get_temp_dir()
- self.addCleanup(shutil.rmtree, temp_dir)
+ self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
(x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
train_samples=TRAIN_SAMPLES,
@@ -747,7 +750,7 @@ class KerasCallbacksTest(test.TestCase):
def test_TensorBoard_histogram_freq_must_have_validation_data(self):
np.random.seed(1337)
tmpdir = self.get_temp_dir()
- self.addCleanup(shutil.rmtree, tmpdir)
+ self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)
with self.test_session():
filepath = os.path.join(tmpdir, 'logs')
@@ -819,7 +822,7 @@ class KerasCallbacksTest(test.TestCase):
def test_TensorBoard_multi_input_output(self):
np.random.seed(1337)
tmpdir = self.get_temp_dir()
- self.addCleanup(shutil.rmtree, tmpdir)
+ self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)
with self.test_session():
filepath = os.path.join(tmpdir, 'logs')
@@ -917,9 +920,12 @@ class KerasCallbacksTest(test.TestCase):
def close(self):
pass
+ def _init_writer(obj):
+ obj.writer = FileWriterStub(obj.log_dir)
+
np.random.seed(1337)
tmpdir = self.get_temp_dir()
- self.addCleanup(shutil.rmtree, tmpdir)
+ self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)
(x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
train_samples=TRAIN_SAMPLES,
test_samples=TEST_SAMPLES,
@@ -940,13 +946,13 @@ class KerasCallbacksTest(test.TestCase):
loss='categorical_crossentropy',
optimizer='sgd',
metrics=['accuracy'])
+ keras.callbacks.TensorBoard._init_writer = _init_writer
tsb = keras.callbacks.TensorBoard(
log_dir=tmpdir,
histogram_freq=1,
write_images=True,
write_grads=True,
batch_size=5)
- tsb._writer_class = FileWriterStub
cbks = [tsb]
# fit with validation data
@@ -964,7 +970,7 @@ class KerasCallbacksTest(test.TestCase):
def test_Tensorboard_histogram_summaries_with_generator(self):
np.random.seed(1337)
tmpdir = self.get_temp_dir()
- self.addCleanup(shutil.rmtree, tmpdir)
+ self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)
def generator():
x = np.random.randn(10, 100).astype(np.float32)
@@ -1061,7 +1067,7 @@ class KerasCallbacksTest(test.TestCase):
def test_TensorBoard_with_ReduceLROnPlateau(self):
with self.test_session():
temp_dir = self.get_temp_dir()
- self.addCleanup(shutil.rmtree, temp_dir)
+ self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
(x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
train_samples=TRAIN_SAMPLES,
@@ -1118,11 +1124,11 @@ class KerasCallbacksTest(test.TestCase):
def close(self):
pass
- logdir = 'fake_dir'
+ temp_dir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
- # log every batch
- tb_cbk = keras.callbacks.TensorBoard(logdir)
- tb_cbk.writer = FileWriterStub(logdir)
+ tb_cbk = keras.callbacks.TensorBoard(temp_dir)
+ tb_cbk.writer = FileWriterStub(temp_dir)
for batch in range(5):
tb_cbk.on_batch_end(batch, {'acc': np.float32(batch)})
@@ -1150,10 +1156,11 @@ class KerasCallbacksTest(test.TestCase):
def close(self):
pass
- logdir = 'fake_dir'
+ temp_dir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
- tb_cbk = keras.callbacks.TensorBoard(logdir)
- tb_cbk.writer = FileWriterStub(logdir)
+ tb_cbk = keras.callbacks.TensorBoard(temp_dir)
+ tb_cbk.writer = FileWriterStub(temp_dir)
tb_cbk.on_batch_end(0, {'acc': np.float32(5.0)})
tb_cbk.on_epoch_end(0, {'acc': np.float32(10.0)})
@@ -1164,6 +1171,43 @@ class KerasCallbacksTest(test.TestCase):
self.assertEqual(epoch_step, 0)
self.assertEqual(epoch_summary.value[0].simple_value, 10.0)
+ @test_util.run_in_graph_and_eager_modes
+ def test_Tensorboard_eager(self):
+ with self.test_session():
+ temp_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
+ self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
+
+ (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
+ train_samples=TRAIN_SAMPLES,
+ test_samples=TEST_SAMPLES,
+ input_shape=(INPUT_DIM,),
+ num_classes=NUM_CLASSES)
+ y_test = keras.utils.to_categorical(y_test)
+ y_train = keras.utils.to_categorical(y_train)
+
+ model = keras.models.Sequential()
+ model.add(
+ keras.layers.Dense(
+ NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'))
+ model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax'))
+ model.compile(
+ loss='binary_crossentropy',
+ optimizer=adam.AdamOptimizer(0.01),
+ metrics=['accuracy'])
+
+ cbks = [keras.callbacks.TensorBoard(log_dir=temp_dir)]
+
+ model.fit(
+ x_train,
+ y_train,
+ batch_size=BATCH_SIZE,
+ validation_data=(x_test, y_test),
+ callbacks=cbks,
+ epochs=2,
+ verbose=0)
+
+ self.assertTrue(os.path.exists(temp_dir))
+
def test_RemoteMonitorWithJsonPayload(self):
if requests is None:
self.skipTest('`requests` required to run this test')
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index b41f6ee03b..e1214f8103 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -26,6 +26,7 @@ import numpy as np
from six.moves import zip # pylint: disable=redefined-builtin
from tensorflow.python.eager import context
+from tensorflow.python.eager import function as eager_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@@ -174,6 +175,12 @@ class Layer(checkpointable.CheckpointableBase):
self.supports_masking = False
+ call_argspec = tf_inspect.getargspec(self.call)
+ if 'training' in call_argspec.args:
+ self._expects_training_arg = True
+ else:
+ self._expects_training_arg = False
+
# Manage input shape information if passed.
if 'input_shape' in kwargs or 'batch_input_shape' in kwargs:
# In this case we will later create an input layer
@@ -786,17 +793,8 @@ class Layer(checkpointable.CheckpointableBase):
if hasattr(self, '_initial_weights') and self._initial_weights is not None:
self.set_weights(self._initial_weights)
del self._initial_weights
- self._post_build_cleanup()
return outputs
- def _post_build_cleanup(self):
- """Hooks to run after all sub-Layers are built."""
- # Note that in addition to Layer.__call__, this method is called by Model
- # after building a graph network (which skips __call__). It should be called
- # when possible if self.built may have switched from False to True, and is
- # idempotent.
- pass # No-op for Layers which don't override this method.
-
def apply(self, inputs, *args, **kwargs):
"""Apply the layer on a input.
@@ -970,6 +968,39 @@ class Layer(checkpointable.CheckpointableBase):
Returns:
An input shape tuple.
"""
+ if context.executing_eagerly():
+ # In this case we build the model first in order to do shape inference.
+ # This is acceptable because the framework only calls
+ # `compute_output_shape` on shape values that the layer would later be
+ # built for. It would however cause issues in case a user attempts to
+ # use `compute_output_shape` manually (these users will have to
+ # implement `compute_output_shape` themselves).
+ self.build(input_shape)
+
+ with context.graph_mode():
+ graph = eager_function.CapturingGraph()
+ with graph.as_default():
+ if isinstance(input_shape, list):
+ inputs = [generate_placeholders_from_shape(shape)
+ for shape in input_shape]
+ else:
+ inputs = generate_placeholders_from_shape(input_shape)
+
+ try:
+ if self._expects_training_arg:
+ outputs = self(inputs, training=False)
+ else:
+ outputs = self(inputs)
+ except TypeError:
+ raise NotImplementedError('We could not automatically infer '
+ 'the static shape of the layer\'s output.'
+ ' Please implement the '
+ '`compute_output_shape` method on your '
+ 'layer (%s).' % self.__class__.__name__)
+ if isinstance(outputs, list):
+ return [output.shape for output in outputs]
+ else:
+ return outputs.shape
raise NotImplementedError
def compute_mask(self, inputs, mask=None): # pylint: disable=unused-argument
@@ -1925,3 +1956,18 @@ def make_variable(name,
synchronization=synchronization,
aggregation=aggregation)
return v
+
+
+def generate_dummy_data_from_shape(shape):
+ if isinstance(shape, tensor_shape.TensorShape):
+ shape = shape.as_list()
+
+ # Replace Nones in input shape with dummy `1` value
+ shape = [x.value if isinstance(x, tensor_shape.Dimension) else x
+ for x in shape]
+ shape = [1 if x is None else x for x in shape]
+ return array_ops.ones(shape, dtype=backend.floatx())
+
+
+def generate_placeholders_from_shape(shape):
+ return array_ops.placeholder(shape=shape, dtype=backend.floatx())
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index 752e9963ca..20a29dbf20 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -20,7 +20,6 @@ from __future__ import division
from __future__ import print_function
import copy
-import functools
import json
import os
import weakref
@@ -30,6 +29,7 @@ from six.moves import zip # pylint: disable=redefined-builtin
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
+from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@@ -144,10 +144,6 @@ class Network(base_layer.Layer):
self._checkpointable_saver = checkpointable_utils.CheckpointableSaver(
weakref.ref(self))
- # A zero-argument function which should be called and set back to None as
- # soon as the network is built (only applicable to subclassed Models). Runs
- # restore operations when graph building.
- self._in_progress_restore_finalizer = None
@checkpointable.no_automatic_dependency_tracking
def _init_graph_network(self, inputs, outputs, name=None):
@@ -739,6 +735,86 @@ class Network(base_layer.Layer):
return specs[0]
return specs
+ def build(self, input_shape):
+ """Builds the model based on input shapes received.
+
+ This is to be used for subclassed models, which do not know at instantiation
+ time what their inputs look like.
+
+ Args:
+ input_shape: Single tuple, TensorShape, or list of shapes, where shapes
+ are tuples, integers, or TensorShapes.
+
+ Raises:
+ ValueError:
+ 1. In case of invalid user-provided data (not of type tuple,
+ list, or TensorShape).
+ 2. If the model requires call arguments that are agnostic
+ to the input shapes (positional or kwarg in call signature).
+ 3. If not all layers were properly built.
+ 4. If float type inputs are not supported within the layers.
+
+ In each of these cases, the user should build their model by calling it
+ on real tensor data.
+ """
+ if self._is_graph_network:
+ self.built = True
+ return
+
+ # If subclass network
+ if input_shape is None:
+ raise ValueError('Input shape must be defined when calling build on a '
+ 'model subclass network.')
+ valid_types = (tuple, list, tensor_shape.TensorShape)
+ if not isinstance(input_shape, valid_types):
+ raise ValueError('Specified input shape is not one of the valid types. '
+ 'Please specify a batch input shape of type tuple or '
+ 'list of input shapes. User provided '
+ 'input type: {}'.format(type(input_shape)))
+
+ if input_shape and not self.inputs:
+ if isinstance(input_shape, list):
+ # List of input shapes
+ x = [base_layer.generate_dummy_data_from_shape(shape)
+ for shape in input_shape]
+ else:
+ x = base_layer.generate_dummy_data_from_shape(input_shape)
+
+ kwargs = {}
+ num_call_args = len(tf_inspect.getargspec(self.call).args)
+ if self._expects_training_arg and num_call_args == 3:
+ # Has call signature of call(self, input, training)
+ kwargs['training'] = False
+ elif num_call_args > 2:
+ # Has invalid call signature of call(self, input, *args, **kwargs)
+ raise ValueError('Currently, you cannot build your model if it has '
+ 'positional or keyword arguments that are not '
+ 'inputs to the model, but are required for its '
+ '`call` method. Instead, in order to instantiate '
+ 'and build your model, `call` your model on real '
+ 'tensor data with all expected call arguments.')
+
+ try:
+ self.call(x, **kwargs)
+ except (errors.InvalidArgumentError, TypeError):
+ raise ValueError('You cannot build your model by calling `build` '
+ 'if your layers do not support float type inputs. '
+ 'Instead, in order to instantiate and build your '
+ 'model, `call` your model on real tensor data (of '
+ 'the correct dtype).')
+
+ if self._layers:
+ self._track_layers(self._layers)
+ if self.layers:
+ for layer in self.layers:
+ if not layer.built:
+ raise ValueError('Layer: {} was not built in your model. Calling '
+ '`build` manually on a subclassed model is only '
+ 'allowed for models with a static topology. '
+ 'In this case, you can build your model by '
+ 'calling it on real tensor data.'.format(layer))
+ self.built = True
+
def call(self, inputs, training=None, mask=None):
"""Calls the model on new inputs.
@@ -779,6 +855,8 @@ class Network(base_layer.Layer):
def compute_output_shape(self, input_shape):
if not self._is_graph_network:
+ if context.executing_eagerly():
+ return super(Network, self).compute_output_shape(input_shape)
raise NotImplementedError
if isinstance(input_shape, list):
@@ -889,7 +967,7 @@ class Network(base_layer.Layer):
mask: List of masks (tensors or None).
Returns:
- Three lists: output_tensors, output_masks, output_shapes
+ Two lists: output_tensors, output_masks
"""
# Note: masking support is relevant mainly for Keras.
# It cannot be factored out without having the fully reimplement the network
@@ -956,7 +1034,6 @@ class Network(base_layer.Layer):
else:
output_masks = [None for _ in output_tensors]
computed_tensors = [computed_tensor]
- computed_masks = [computed_mask]
else:
computed_tensors = [x[0] for x in computed_data]
computed_masks = [x[1] for x in computed_data]
@@ -1423,13 +1500,9 @@ class Network(base_layer.Layer):
'load_weights).')
if not context.executing_eagerly():
session = backend.get_session()
- finalizer = functools.partial(status.run_restore_ops, session=session)
- if self.built:
- finalizer()
- else:
- # Hold on to this status object until the network is built (for
- # subclassed Models). Then we'll run restore ops if necessary.
- self._in_progress_restore_finalizer = finalizer
+ # Restore existing variables (if any) immediately, and set up a
+ # streaming restore for any variables created in the future.
+ checkpointable_utils.streaming_restore(status=status, session=session)
return status
if h5py is None:
raise ImportError(
@@ -1447,14 +1520,6 @@ class Network(base_layer.Layer):
else:
saving.load_weights_from_hdf5_group(f, self.layers)
- def _post_build_cleanup(self):
- super(Network, self)._post_build_cleanup()
- if self._in_progress_restore_finalizer is not None:
- # Runs queued restore operations left over from load_weights when graph
- # building.
- self._in_progress_restore_finalizer()
- self._in_progress_restore_finalizer = None
-
def _updated_config(self):
"""Util shared between different serialization methods.
diff --git a/tensorflow/python/keras/engine/saving_test.py b/tensorflow/python/keras/engine/saving_test.py
index 030328f2a6..e029e614e0 100644
--- a/tensorflow/python/keras/engine/saving_test.py
+++ b/tensorflow/python/keras/engine/saving_test.py
@@ -722,18 +722,23 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
self.assertEqual(len(graph.get_operations()), op_count)
def _weight_loading_test_template(self, make_model_fn):
- with self.test_session() as session:
+ with self.test_session():
model = make_model_fn()
+ model.compile(
+ loss='mse',
+ optimizer=training_module.RMSPropOptimizer(0.1),
+ metrics=['acc'])
temp_dir = self.get_temp_dir()
prefix = os.path.join(temp_dir, 'ckpt')
+ train_x = np.random.random((3, 2))
+ train_y = np.random.random((3,))
+ x = constant_op.constant(train_x, dtype=dtypes.float32)
- x = constant_op.constant(np.random.random((3, 2)), dtype=dtypes.float32)
- executing_eagerly = context.executing_eagerly()
- ref_y_tensor = model(x)
- if not executing_eagerly:
- session.run([v.initializer for v in model.variables])
- ref_y = self.evaluate(ref_y_tensor)
+ model.train_on_batch(train_x, train_y)
model.save_weights(prefix, save_format='tf')
+ ref_y_before_train = model.predict(train_x)
+ model.train_on_batch(train_x, train_y)
+ ref_y_after_train = model.predict(train_x)
for v in model.variables:
self.evaluate(
v.assign(random_ops.random_normal(shape=array_ops.shape(v))))
@@ -741,16 +746,27 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
self.addCleanup(shutil.rmtree, temp_dir)
model.load_weights(prefix)
- y = self.evaluate(model(x))
- self.assertAllClose(ref_y, y)
+ self.assertAllClose(ref_y_before_train, self.evaluate(model(x)))
# Test restore-on-create if this is a subclassed Model (graph Networks
# will have already created their variables).
load_model = make_model_fn()
load_model.load_weights(prefix)
- restore_on_create_y_tensor = load_model(x)
- restore_on_create_y = self.evaluate(restore_on_create_y_tensor)
- self.assertAllClose(ref_y, restore_on_create_y)
+ self.assertAllClose(
+ ref_y_before_train,
+ self.evaluate(load_model(x)))
+ load_model = make_model_fn()
+ load_model.load_weights(prefix)
+ # We need to run some of the restore ops for predict(), but not all
+ # variables have been created yet (optimizer slot variables). Tests
+ # incremental restore.
+ load_model.predict(train_x)
+ load_model.compile(
+ loss='mse',
+ optimizer=training_module.RMSPropOptimizer(0.1),
+ metrics=['acc'])
+ load_model.train_on_batch(train_x, train_y)
+ self.assertAllClose(ref_y_after_train, self.evaluate(load_model(x)))
@test_util.run_in_graph_and_eager_modes
def test_weight_loading_graph_model(self):
@@ -858,5 +874,6 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
SubclassedModel, SubclassedModelRestore,
_restore_init_fn)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/engine/topology_test.py b/tensorflow/python/keras/engine/topology_test.py
index 3eb69bd7f3..34f74db6ef 100644
--- a/tensorflow/python/keras/engine/topology_test.py
+++ b/tensorflow/python/keras/engine/topology_test.py
@@ -110,7 +110,6 @@ class TopologyConstructionTest(test.TestCase):
layer = keras.layers.BatchNormalization()
_ = layer.apply(x1)
- print('BN updates', layer._updates)
self.assertEqual(len(layer.updates), 2)
self.assertEqual(len(layer.get_updates_for(x1)), 2)
self.assertEqual(len(layer.get_updates_for(None)), 0)
@@ -960,9 +959,6 @@ class DeferredModeTest(test.TestCase):
def call(self, inputs):
return inputs[0] + inputs[1]
- def compute_output_shape(self, input_shape):
- return input_shape[0]
-
c = AddLayer()([a, input_b]) # pylint: disable=not-callable
c = keras.layers.Dense(2)(c)
@@ -978,6 +974,101 @@ class DeferredModeTest(test.TestCase):
self.assertEqual(outputs[1].shape.as_list(), [10, 2])
+class DefaultShapeInferenceBehaviorTest(test.TestCase):
+
+ def _testShapeInference(self, model, input_shape, expected_output_shape):
+ input_value = np.random.random(input_shape)
+ output_value = model.predict(input_value)
+ self.assertEqual(output_value.shape, expected_output_shape)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testSingleInputCase(self):
+
+ class LayerWithOneInput(keras.layers.Layer):
+
+ def build(self, input_shape):
+ self.w = array_ops.ones(shape=(3, 4))
+
+ def call(self, inputs):
+ return keras.backend.dot(inputs, self.w)
+
+ inputs = input_layer_lib.Input(shape=(3,))
+ layer = LayerWithOneInput()
+
+ if context.executing_eagerly():
+ self.assertEqual(
+ layer.compute_output_shape((None, 3)).as_list(), [None, 4])
+ # As a side-effect, compute_output_shape builds the layer.
+ self.assertTrue(layer.built)
+ # We can still query the layer's compute_output_shape with compatible
+ # input shapes.
+ self.assertEqual(
+ layer.compute_output_shape((6, 3)).as_list(), [6, 4])
+
+ outputs = layer(inputs)
+ model = keras.Model(inputs, outputs)
+ self._testShapeInference(model, (2, 3), (2, 4))
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testMultiInputOutputCase(self):
+
+ class MultiInputOutputLayer(keras.layers.Layer):
+
+ def build(self, input_shape):
+ self.w = array_ops.ones(shape=(3, 4))
+
+ def call(self, inputs):
+ a = keras.backend.dot(inputs[0], self.w)
+ b = a + inputs[1]
+ return [a, b]
+
+ input_a = input_layer_lib.Input(shape=(3,))
+ input_b = input_layer_lib.Input(shape=(4,))
+ output_a, output_b = MultiInputOutputLayer()([input_a, input_b])
+ model = keras.Model([input_a, input_b], [output_a, output_b])
+ output_a_val, output_b_val = model.predict(
+ [np.random.random((2, 3)), np.random.random((2, 4))])
+ self.assertEqual(output_a_val.shape, (2, 4))
+ self.assertEqual(output_b_val.shape, (2, 4))
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testTrainingArgument(self):
+
+ class LayerWithTrainingArg(keras.layers.Layer):
+
+ def build(self, input_shape):
+ self.w = array_ops.ones(shape=(3, 4))
+
+ def call(self, inputs, training):
+ return keras.backend.dot(inputs, self.w)
+
+ inputs = input_layer_lib.Input(shape=(3,))
+ outputs = LayerWithTrainingArg()(inputs, training=False)
+ model = keras.Model(inputs, outputs)
+ self._testShapeInference(model, (2, 3), (2, 4))
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testUnsupportedSignature(self):
+
+ class LayerWithAdditionalArg(keras.layers.Layer):
+
+ def build(self, input_shape):
+ self.w = array_ops.ones(shape=(3, 4))
+
+ def call(self, inputs, some_arg):
+ return keras.backend.dot(inputs, self.w) + some_arg
+
+ inputs = input_layer_lib.Input(shape=(3,))
+ if context.executing_eagerly():
+ with self.assertRaises(NotImplementedError):
+ outputs = LayerWithAdditionalArg()(inputs, some_arg=0)
+ else:
+ # Works with graph mode because the graph of ops is built together with
+ # the graph of layers.
+ outputs = LayerWithAdditionalArg()(inputs, some_arg=0)
+ _ = keras.Model(inputs, outputs)
+
+
class GraphUtilsTest(test.TestCase):
def testGetReachableFromInputs(self):
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 4df739254b..0fe14e99e0 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -24,14 +24,11 @@ import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
-from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import losses
-from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import training_arrays
@@ -40,11 +37,9 @@ from tensorflow.python.keras.engine import training_generator
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.engine.network import Network
from tensorflow.python.keras.utils.generic_utils import slice_arrays
-from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import optimizer as tf_optimizer_module
from tensorflow.python.training.checkpointable import base as checkpointable
-from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
@@ -374,21 +369,14 @@ class Model(Network):
'sample_weight_mode dictionary: "' + name + '". '
'Only expected the following keys: ' + str(self.output_names))
for i, name in enumerate(self.output_names):
- if i in skip_target_weighing_indices:
- weight = None
- sample_weight_modes.append(None)
- else:
- if name not in sample_weight_mode:
- raise ValueError(
- 'Output "' + name + '" missing from sample_weight_modes '
- 'dictionary')
- if sample_weight_mode.get(name) == 'temporal':
- weight = K.placeholder(ndim=2, name=name + '_sample_weights')
- sample_weight_modes.append('temporal')
- else:
- weight = K.placeholder(ndim=1, name=name + 'sample_weights')
- sample_weight_modes.append(None)
+ if (i not in skip_target_weighing_indices and
+ name not in sample_weight_mode):
+ raise ValueError('Output "' + name +
+ '" missing from sample_weight_modes dictionary')
+ weight, mode = training_utils.get_output_sample_weight_and_mode(
+ skip_target_weighing_indices, sample_weight_mode.get(name), name, i)
sample_weights.append(weight)
+ sample_weight_modes.append(mode)
elif isinstance(sample_weight_mode, list):
if len(sample_weight_mode) != len(self.outputs):
raise ValueError('When passing a list as sample_weight_mode, '
@@ -396,36 +384,17 @@ class Model(Network):
'The model has ' + str(len(self.outputs)) +
' outputs, but you passed '
'sample_weight_mode=' + str(sample_weight_mode))
- for i in range(len(self.output_names)):
- if i in skip_target_weighing_indices:
- weight = None
- sample_weight_modes.append(None)
- else:
- mode = sample_weight_mode[i]
- name = self.output_names[i]
- if mode == 'temporal':
- weight = K.placeholder(ndim=2, name=name + '_sample_weights')
- sample_weight_modes.append('temporal')
- else:
- weight = K.placeholder(ndim=1, name=name + '_sample_weights')
- sample_weight_modes.append(None)
+ for i, name in enumerate(self.output_names):
+ weight, mode = training_utils.get_output_sample_weight_and_mode(
+ skip_target_weighing_indices, sample_weight_mode[i], name, i)
sample_weights.append(weight)
+ sample_weight_modes.append(mode)
else:
for i, name in enumerate(self.output_names):
- if i in skip_target_weighing_indices:
- sample_weight_modes.append(None)
- sample_weights.append(None)
- else:
- if sample_weight_mode == 'temporal':
- sample_weights.append(array_ops.placeholder_with_default(
- constant_op.constant([[1.]], dtype=K.floatx()),
- shape=[None, None], name=name + '_sample_weights'))
- sample_weight_modes.append('temporal')
- else:
- sample_weights.append(array_ops.placeholder_with_default(
- constant_op.constant([1.], dtype=K.floatx()),
- shape=[None], name=name + '_sample_weights'))
- sample_weight_modes.append(None)
+ weight, mode = training_utils.get_output_sample_weight_and_mode(
+ skip_target_weighing_indices, sample_weight_mode, name, i)
+ sample_weights.append(weight)
+ sample_weight_modes.append(mode)
self.sample_weight_modes = sample_weight_modes
self._feed_sample_weight_modes = []
for i in range(len(self.outputs)):
@@ -488,43 +457,21 @@ class Model(Network):
weights = sample_weights[i]
output_metrics = nested_metrics[i]
output_weighted_metrics = nested_weighted_metrics[i]
+ output_shape = self.outputs[i].get_shape().as_list()
+ loss_fn = self.loss_functions[i]
- def handle_metrics(metrics, weights=None):
+ def handle_metrics(metrics, output_shape, loss_fn, weights=None):
+ """Invokes metric functions for the output."""
for metric in metrics:
- if metric in ('accuracy', 'acc', 'crossentropy', 'ce'):
- # custom handling of accuracy/crossentropy
- # (because of class mode duality)
- output_shape = self.outputs[i].get_shape().as_list()
- if (output_shape[-1] == 1 or
- self.loss_functions[i] == losses.binary_crossentropy):
- # case: binary accuracy/crossentropy
- if metric in ('accuracy', 'acc'):
- metric_fn = metrics_module.binary_accuracy
- elif metric in ('crossentropy', 'ce'):
- metric_fn = metrics_module.binary_crossentropy
- elif self.loss_functions[
- i] == losses.sparse_categorical_crossentropy:
- # case: categorical accuracy/crossentropy with sparse targets
- if metric in ('accuracy', 'acc'):
- metric_fn = metrics_module.sparse_categorical_accuracy
- elif metric in ('crossentropy', 'ce'):
- metric_fn = metrics_module.sparse_categorical_crossentropy
- else:
- # case: categorical accuracy/crossentropy
- if metric in ('accuracy', 'acc'):
- metric_fn = metrics_module.categorical_accuracy
- elif metric in ('crossentropy', 'ce'):
- metric_fn = metrics_module.categorical_crossentropy
- weighted_metric_fn = training_utils.weighted_masked_objective(
- metric_fn)
- else:
- metric_fn = metrics_module.get(metric)
- weighted_metric_fn = training_utils.weighted_masked_objective(
- metric_fn)
- metric_name = training_utils.get_base_metric_name(
+ metric_fn = training_utils.get_metric_function(
+ metric, output_shape=output_shape, loss_fn=loss_fn)
+ metric_name = training_utils.get_metric_name(
metric, weighted=weights is not None)
+
with K.name_scope(metric_name):
+ weighted_metric_fn = training_utils.weighted_masked_objective(
+ metric_fn)
metric_result = weighted_metric_fn(
y_true, y_pred, weights=weights, mask=masks[i])
@@ -538,8 +485,9 @@ class Model(Network):
self.stateful_metric_functions.append(metric_fn)
self.metrics_updates += metric_fn.updates
- handle_metrics(output_metrics)
- handle_metrics(output_weighted_metrics, weights=weights)
+ handle_metrics(output_metrics, output_shape, loss_fn)
+ handle_metrics(
+ output_weighted_metrics, output_shape, loss_fn, weights=weights)
# Prepare gradient updates and state updates.
self.total_loss = total_loss
@@ -562,95 +510,6 @@ class Model(Network):
trainable_weights = self.trainable_weights
self._collected_trainable_weights = trainable_weights
- def build(self, input_shape):
- """Build the model based on input shapes received.
-
- This is to be used for subclassed models, which do not know at instantiation
- time what their inputs look like.
-
- Args:
- input_shape: Single tuple, TensorShape, or list of shapes, where shapes
- are tuples, integers, or TensorShapes.
-
- Raises:
- ValueError:
- 1. In case of invalid user-provided data (not of type tuple,
- list, or TensorShape).
- 2. If the model requires call arguments that are agnostic
- to the input shapes (positional or kwarg in call signature).
- 3. If not all layers were properly built.
- 4. If float type inputs are not supported within the layers.
-
- In each of these cases, the user should build their model by calling it
- on real tensor data.
- """
- if self._is_graph_network:
- self.built = True
- return
-
- # If subclass network
- if input_shape is None:
- raise ValueError('Input shape must be defined when calling build on a '
- 'model subclass network.')
- valid_types = (tuple, list, tensor_shape.TensorShape)
- if not isinstance(input_shape, valid_types):
- raise ValueError('Specified input shape is not one of the valid types. '
- 'Please specify a batch input shape of type tuple or '
- 'list of input shapes. User provided '
- 'input type: {}'.format(type(input_shape)))
-
- def _generate_dummy_data_from_shape(shape):
- if isinstance(shape, tensor_shape.TensorShape):
- shape = shape.as_list()
-
- # Replace Nones in input shape with dummy `1` value
- shape = [x.value if isinstance(x, tensor_shape.Dimension) else x
- for x in shape]
- shape = [1 if x is None else x for x in shape]
- return array_ops.ones(shape, dtype=K.floatx())
-
- if input_shape and not self.inputs:
- if isinstance(input_shape, list):
- # List of input shapes
- x = [_generate_dummy_data_from_shape(shape) for shape in input_shape]
- else:
- x = _generate_dummy_data_from_shape(input_shape)
-
- kwargs = {}
- num_call_args = len(tf_inspect.getargspec(self.call).args)
- if self._expects_training_arg and num_call_args == 3:
- # Has call signature of call(self, input, training)
- kwargs['training'] = False
- elif num_call_args > 2:
- # Has invalid call signature of call(self, input, *args, **kwargs)
- raise ValueError('Currently, you cannot build your model if it has '
- 'positional or keyword arguments that are not '
- 'inputs to the model, but are required for its '
- '`call` method. Instead, in order to instantiate '
- 'and build your model, `call` your model on real '
- 'tensor data with all expected call arguments.')
-
- try:
- self.call(x, **kwargs)
- except (errors.InvalidArgumentError, TypeError):
- raise ValueError('You cannot build your model by calling `build` '
- 'if your layers do not support float type inputs. '
- 'Instead, in order to instantiate and build your '
- 'model, `call` your model on real tensor data (of '
- 'the correct dtype).')
-
- if self._layers:
- self._track_layers(self._layers)
- if self.layers:
- for layer in self.layers:
- if not layer.built:
- raise ValueError('Layer: {} was not built in your model. Calling '
- '`build` manually on a subclassed model is only '
- 'allowed for models with a static topology. '
- 'In this case, you can build your model by '
- 'calling it on real tensor data.'.format(layer))
- self.built = True
-
def _check_trainable_weights_consistency(self):
"""Check trainable weights count consistency.
@@ -698,7 +557,6 @@ class Model(Network):
updates=updates,
name='train_function',
**self._function_kwargs)
- self._post_build_cleanup()
def _make_test_function(self):
if not hasattr(self, 'test_function'):
@@ -716,7 +574,6 @@ class Model(Network):
updates=self.state_updates + self.metrics_updates,
name='test_function',
**self._function_kwargs)
- self._post_build_cleanup()
def _make_predict_function(self):
if not hasattr(self, 'predict_function'):
@@ -735,7 +592,6 @@ class Model(Network):
updates=self.state_updates,
name='predict_function',
**kwargs)
- self._post_build_cleanup()
def _get_iterator_get_next_tensors(self, iterator):
get_next_op = self._iterator_get_next.get(iterator, None)
diff --git a/tensorflow/python/keras/engine/training_arrays.py b/tensorflow/python/keras/engine/training_arrays.py
index adefffab11..6572e2c344 100644
--- a/tensorflow/python/keras/engine/training_arrays.py
+++ b/tensorflow/python/keras/engine/training_arrays.py
@@ -50,7 +50,6 @@ def fit_loop(model,
val_targets=None,
val_sample_weights=None,
shuffle=True,
- callback_metrics=None,
initial_epoch=0,
steps_per_epoch=None,
validation_steps=None):
@@ -69,8 +68,6 @@ def fit_loop(model,
val_targets: List of target arrays.
val_sample_weights: Optional list of sample weight arrays.
shuffle: Whether to shuffle the data at the beginning of each epoch
- callback_metrics: List of strings, the display names of the metrics
- passed to the callbacks. They should be the
concatenation of list the display names of the outputs of
`f` and the list of display names of the outputs of `f_val`.
initial_epoch: Epoch at which to start training
@@ -121,9 +118,7 @@ def fit_loop(model,
out_labels = model.metrics_names
if do_validation:
- callback_metrics = copy.copy(out_labels) + [
- 'val_' + n for n in out_labels
- ]
+ callback_metrics = copy.copy(out_labels) + ['val_' + n for n in out_labels]
# need to create the test_function before start of the first epoch
# because TensorBoard callback on_epoch_begin adds summary to the
# list of fetches of the test_function
@@ -197,9 +192,7 @@ def fit_loop(model,
if steps_per_epoch is not None:
# Step-wise fit loop.
for step_index in range(steps_per_epoch):
- batch_logs = {}
- batch_logs['batch'] = step_index
- batch_logs['size'] = 1
+ batch_logs = {'batch': step_index, 'size': 1}
callbacks.on_batch_begin(step_index, batch_logs)
try:
outs = f(ins)
@@ -388,7 +381,9 @@ def predict_loop(model, inputs, batch_size=32, verbose=0, steps=None):
return outs
-def test_loop(model, inputs, targets,
+def test_loop(model,
+ inputs,
+ targets,
sample_weights=None,
batch_size=None,
verbose=0,
@@ -485,8 +480,7 @@ def test_loop(model, inputs, targets,
if isinstance(batch_outs, list):
if batch_index == 0:
- for batch_out in enumerate(batch_outs):
- outs.append(0.)
+ outs.extend([0.] * len(batch_outs))
for i, batch_out in enumerate(batch_outs):
if i in stateful_metric_indices:
outs[i] = batch_out
diff --git a/tensorflow/python/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py
index 397de42985..0b25b827ad 100644
--- a/tensorflow/python/keras/engine/training_eager.py
+++ b/tensorflow/python/keras/engine/training_eager.py
@@ -30,35 +30,11 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.keras import callbacks as cbks
-from tensorflow.python.keras import losses
-from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.platform import tf_logging as logging
-def _get_metrics_info(metric, internal_output_shapes=None, loss_func=None):
- if metric == 'accuracy' or metric == 'acc':
- # custom handling of accuracy
- # (because of class mode duality)
- output_shape = internal_output_shapes
- if output_shape[-1] == 1 or loss_func == losses.binary_crossentropy:
- # case: binary accuracy
- acc_fn = metrics_module.binary_accuracy
- elif loss_func == losses.sparse_categorical_crossentropy:
- # case: categorical accuracy with sparse targets
- acc_fn = metrics_module.sparse_categorical_accuracy
- else:
- acc_fn = metrics_module.categorical_accuracy
-
- metric_name = 'acc'
- return metric_name, acc_fn
- else:
- metric_fn = metrics_module.get(metric)
- metric_name = metric_fn.__name__
- return metric_name, metric_fn
-
-
def _eager_loss_fn(outputs, targets, loss_fn, output_name):
with backend.name_scope(output_name + '_loss'):
loss = loss_fn(targets, outputs)
@@ -74,9 +50,8 @@ def _eager_metrics_fn(model, outputs, targets):
targets: The predictions or targets of the given model.
Returns:
- Returns the metric names and metric results for each output of the model.
+ Returns the metric results for each output of the model.
"""
- metric_names = []
metric_results = []
if not isinstance(outputs, list):
outputs = [outputs]
@@ -87,18 +62,15 @@ def _eager_metrics_fn(model, outputs, targets):
for i in range(len(model.outputs)):
output_metrics = model.nested_metrics[i]
for nested_output_metric in output_metrics:
- metric_name, metric_fn = _get_metrics_info(
+ metric_fn = training_utils.get_metric_function(
nested_output_metric, backend.int_shape(model.outputs[i]),
model.loss_functions[i])
-
- if len(model.output_names) > 1:
- metric_name = model.output_names[i] + '_' + metric_name
- if metric_name not in model.metrics_names:
- model.metrics_names.append(metric_name)
+ # weighted metrics are not supported in eager mode
+ metric_name = training_utils.get_metric_name(
+ nested_output_metric, weighted=False)
with backend.name_scope(metric_name):
metric_result = metric_fn(targets[i], outputs[i])
- metric_names.append(metric_name)
metric_results.append(backend.mean(metric_result))
return metric_results
@@ -617,7 +589,6 @@ def fit_loop(model,
verbose=1,
callbacks=None,
shuffle=True,
- callback_metrics=None,
initial_epoch=0,
steps_per_epoch=None,
validation_steps=None):
@@ -639,10 +610,6 @@ def fit_loop(model,
verbose: Verbosity mode, 0, 1 or 2
callbacks: List of callbacks to be called during training
shuffle: Whether to shuffle the data at the beginning of each epoch
- callback_metrics: List of strings, the display names of the metrics
- passed to the callbacks. They should be the
- concatenation of list the display names of the outputs of
- `f` and the list of display names of the outputs of `f_val`.
initial_epoch: Epoch at which to start training
(useful for resuming a previous training run)
steps_per_epoch: Total number of steps (batches of samples)
@@ -674,6 +641,7 @@ def fit_loop(model,
num_train_samples = None
out_labels = None
+ callback_metrics = None
if model._is_compiled:
out_labels = model.metrics_names
if do_validation:
diff --git a/tensorflow/python/keras/engine/training_eager_test.py b/tensorflow/python/keras/engine/training_eager_test.py
index bdb3035129..b0f57f0770 100644
--- a/tensorflow/python/keras/engine/training_eager_test.py
+++ b/tensorflow/python/keras/engine/training_eager_test.py
@@ -31,284 +31,6 @@ from tensorflow.python.training.rmsprop import RMSPropOptimizer
class TrainingTest(test.TestCase):
- def test_fit_on_arrays(self):
- a = keras.layers.Input(shape=(3,), name='input_a')
- b = keras.layers.Input(shape=(3,), name='input_b')
-
- dense = keras.layers.Dense(4, name='dense')
- c = dense(a)
- d = dense(b)
- e = keras.layers.Dropout(0.5, name='dropout')(c)
-
- model = keras.models.Model([a, b], [d, e])
-
- optimizer = RMSPropOptimizer(learning_rate=0.001)
- loss = 'mse'
- loss_weights = [1., 0.5]
- metrics = ['mae']
- model.compile(optimizer, loss, metrics=metrics, loss_weights=loss_weights)
-
- input_a_np = np.random.random((10, 3))
- input_b_np = np.random.random((10, 3))
-
- output_d_np = np.random.random((10, 4))
- output_e_np = np.random.random((10, 4))
-
- # Test fit at different verbosity
- model.fit(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- epochs=1,
- batch_size=5,
- verbose=0)
- model.fit(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- epochs=1,
- batch_size=5,
- verbose=1)
- model.fit(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- epochs=2,
- batch_size=5,
- verbose=2)
-
- # Test with validation data
- model.fit(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- validation_data=([input_a_np, input_b_np], [output_d_np,
- output_e_np]),
- epochs=1,
- batch_size=5,
- verbose=0)
- model.fit(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- validation_data=([input_a_np, input_b_np], [output_d_np,
- output_e_np]),
- epochs=2,
- batch_size=5,
- verbose=1)
- model.fit(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- validation_data=([input_a_np, input_b_np], [output_d_np,
- output_e_np]),
- epochs=2,
- batch_size=5,
- verbose=2)
- model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np])
-
- # Test with validation split
- model.fit(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- epochs=2,
- batch_size=5,
- verbose=0,
- validation_split=0.2)
-
- # Test with dictionary inputs
- model.fit(
- {
- 'input_a': input_a_np,
- 'input_b': input_b_np
- }, {'dense': output_d_np,
- 'dropout': output_e_np},
- epochs=1,
- batch_size=5,
- verbose=0)
- model.fit(
- {
- 'input_a': input_a_np,
- 'input_b': input_b_np
- }, {'dense': output_d_np,
- 'dropout': output_e_np},
- epochs=1,
- batch_size=5,
- verbose=1)
- model.fit(
- {
- 'input_a': input_a_np,
- 'input_b': input_b_np
- }, {'dense': output_d_np,
- 'dropout': output_e_np},
- validation_data=({'input_a': input_a_np,
- 'input_b': input_b_np
- },
- {
- 'dense': output_d_np,
- 'dropout': output_e_np
- }),
- epochs=1,
- batch_size=5,
- verbose=0)
- model.train_on_batch({
- 'input_a': input_a_np,
- 'input_b': input_b_np
- }, {'dense': output_d_np,
- 'dropout': output_e_np})
- # Test with lists for loss, metrics
- loss = ['mae', 'mse']
- metrics = ['acc', 'mae']
- model.compile(optimizer, loss, metrics=metrics)
- model.fit(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- epochs=1,
- batch_size=5,
- verbose=0)
-
- # Test with dictionaries for loss, metrics, loss weights
- loss = {'dense': 'mse', 'dropout': 'mae'}
- loss_weights = {'dense': 1., 'dropout': 0.5}
- metrics = {'dense': 'mse', 'dropout': 'mae'}
- model.compile(optimizer, loss, metrics=metrics, loss_weights=loss_weights)
- model.fit(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- epochs=1,
- batch_size=5,
- verbose=0)
-
- # Invalid use cases
- with self.assertRaises(AttributeError):
- model.fit(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- epochs=1,
- validation_data=([input_a_np, input_b_np], 0, 0),
- verbose=0)
- with self.assertRaises(ValueError):
- model.train_on_batch({'input_a': input_a_np},
- [output_d_np, output_e_np])
- with self.assertRaises(ValueError):
- model.train_on_batch([input_a_np], [output_d_np, output_e_np])
- with self.assertRaises(AttributeError):
- model.train_on_batch(1, [output_d_np, output_e_np])
- with self.assertRaises(ValueError):
- model.train_on_batch(input_a_np, [output_d_np, output_e_np])
- with self.assertRaises(ValueError):
- bad_input = np.random.random((11, 3))
- model.train_on_batch([bad_input, input_b_np],
- [output_d_np, output_e_np])
- with self.assertRaises(ValueError):
- bad_target = np.random.random((11, 4))
- model.train_on_batch([input_a_np, input_b_np],
- [bad_target, output_e_np])
-
- # Build single-input model
- x = keras.layers.Input(shape=(3,), name='input_a')
- y = keras.layers.Dense(4)(x)
- model = keras.models.Model(x, y)
- model.compile(optimizer=RMSPropOptimizer(learning_rate=0.001), loss='mse')
- # This will work
- model.fit([input_a_np], output_d_np, epochs=1)
- with self.assertRaises(ValueError):
- model.fit([input_a_np, input_a_np], output_d_np, epochs=1)
-
- def test_evaluate_predict_on_arrays(self):
- a = keras.layers.Input(shape=(3,), name='input_a')
- b = keras.layers.Input(shape=(3,), name='input_b')
-
- dense = keras.layers.Dense(4, name='dense')
- c = dense(a)
- d = dense(b)
- e = keras.layers.Dropout(0.5, name='dropout')(c)
-
- model = keras.models.Model([a, b], [d, e])
-
- optimizer = RMSPropOptimizer(learning_rate=0.001)
- loss = 'mse'
- loss_weights = [1., 0.5]
- metrics = ['acc', 'mae']
- model.compile(
- optimizer,
- loss,
- metrics=metrics,
- loss_weights=loss_weights,
- sample_weight_mode=None)
-
- input_a_np = np.random.random((10, 3))
- input_b_np = np.random.random((10, 3))
-
- output_d_np = np.random.random((10, 4))
- output_e_np = np.random.random((10, 4))
-
- # Test evaluate at different verbosity
- out = model.evaluate(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- batch_size=5,
- verbose=0)
- self.assertEqual(len(out), 7)
- out = model.evaluate(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- batch_size=5,
- verbose=1)
- self.assertEqual(len(out), 7)
- out = model.evaluate(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- batch_size=5,
- verbose=2)
- self.assertEqual(len(out), 7)
- out = model.test_on_batch([input_a_np, input_b_np],
- [output_d_np, output_e_np])
- self.assertEqual(len(out), 7)
-
- # Test evaluate with dictionary inputs
- model.evaluate(
- {
- 'input_a': input_a_np,
- 'input_b': input_b_np
- }, {'dense': output_d_np,
- 'dropout': output_e_np},
- batch_size=5,
- verbose=0)
- model.evaluate(
- {
- 'input_a': input_a_np,
- 'input_b': input_b_np
- }, {'dense': output_d_np,
- 'dropout': output_e_np},
- batch_size=5,
- verbose=1)
-
- # Test predict
- out = model.predict([input_a_np, input_b_np], batch_size=5)
- self.assertEqual(len(out), 2)
- out = model.predict({'input_a': input_a_np, 'input_b': input_b_np})
- self.assertEqual(len(out), 2)
- out = model.predict_on_batch({
- 'input_a': input_a_np,
- 'input_b': input_b_np
- })
- self.assertEqual(len(out), 2)
-
- def test_invalid_loss_or_metrics(self):
- num_classes = 5
- train_samples = 1000
- test_samples = 1000
- input_dim = 5
-
- model = keras.models.Sequential()
- model.add(keras.layers.Dense(10, input_shape=(input_dim,)))
- model.add(keras.layers.Activation('relu'))
- model.add(keras.layers.Dense(num_classes))
- model.add(keras.layers.Activation('softmax'))
- model.compile(loss='categorical_crossentropy',
- optimizer=RMSPropOptimizer(learning_rate=0.001))
- np.random.seed(1337)
-
- (x_train, y_train), (_, _) = testing_utils.get_test_data(
- train_samples=train_samples,
- test_samples=test_samples,
- input_shape=(input_dim,),
- num_classes=num_classes)
-
- with self.assertRaises(ValueError):
- model.fit(x_train, np.concatenate([y_train, y_train], axis=-1))
-
- with self.assertRaises(TypeError):
- model.compile(loss='categorical_crossentropy',
- optimizer=RMSPropOptimizer(learning_rate=0.001),
- metrics=set(0))
-
- with self.assertRaises(ValueError):
- model.compile(loss=None,
- optimizer='rms')
-
def test_model_methods_with_eager_tensors_multi_io(self):
a = keras.layers.Input(shape=(3,), name='input_a')
b = keras.layers.Input(shape=(3,), name='input_b')
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index 301a6ca866..be9b0a21d7 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -26,6 +26,7 @@ import numpy as np
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util as tf_test_util
@@ -45,6 +46,7 @@ except ImportError:
class TrainingTest(test.TestCase):
+ @tf_test_util.run_in_graph_and_eager_modes
def test_fit_on_arrays(self):
with self.test_session():
a = keras.layers.Input(shape=(3,), name='input_a')
@@ -57,7 +59,7 @@ class TrainingTest(test.TestCase):
model = keras.models.Model([a, b], [d, e])
- optimizer = 'rmsprop'
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
loss = 'mse'
loss_weights = [1., 0.5]
metrics = ['mae']
@@ -224,7 +226,7 @@ class TrainingTest(test.TestCase):
x = keras.layers.Input(shape=(3,), name='input_a')
y = keras.layers.Dense(4)(x)
model = keras.models.Model(x, y)
- model.compile(optimizer='rmsprop', loss='mse')
+ model.compile(optimizer, loss='mse')
# This will work
model.fit([input_a_np], output_d_np, epochs=1)
with self.assertRaises(ValueError):
@@ -240,6 +242,7 @@ class TrainingTest(test.TestCase):
batch_size=5,
verbose=2)
+ @tf_test_util.run_in_graph_and_eager_modes
def test_evaluate_predict_on_arrays(self):
with self.test_session():
a = keras.layers.Input(shape=(3,), name='input_a')
@@ -252,7 +255,7 @@ class TrainingTest(test.TestCase):
model = keras.models.Model([a, b], [d, e])
- optimizer = 'rmsprop'
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
loss = 'mse'
loss_weights = [1., 0.5]
metrics = ['mae']
@@ -322,6 +325,7 @@ class TrainingTest(test.TestCase):
})
self.assertEqual(len(out), 2)
+ @tf_test_util.run_in_graph_and_eager_modes
def test_invalid_loss_or_metrics(self):
num_classes = 5
train_samples = 1000
@@ -334,27 +338,29 @@ class TrainingTest(test.TestCase):
model.add(keras.layers.Activation('relu'))
model.add(keras.layers.Dense(num_classes))
model.add(keras.layers.Activation('softmax'))
- model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ model.compile(optimizer, loss='categorical_crossentropy')
np.random.seed(1337)
(x_train, y_train), (_, _) = testing_utils.get_test_data(
train_samples=train_samples,
test_samples=test_samples,
input_shape=(input_dim,),
num_classes=num_classes)
- with self.assertRaises(ValueError):
- model.fit(x_train, y_train)
with self.assertRaises(ValueError):
model.fit(x_train, np.concatenate([y_train, y_train], axis=-1))
with self.assertRaises(TypeError):
- model.compile(loss='categorical_crossentropy',
- optimizer='rmsprop',
- metrics=set(0))
+ model.compile(
+ optimizer, loss='categorical_crossentropy', metrics=set(0))
- with self.assertRaises(ValueError):
- model.compile(loss=None,
- optimizer='rmsprop')
+ if not context.executing_eagerly():
+ # TODO(psv): Investigate these use cases in eager mode.
+ with self.assertRaises(ValueError):
+ model.fit(x_train, y_train)
+
+ with self.assertRaises(ValueError):
+ model.compile(optimizer, loss=None)
def test_training_on_sparse_data_with_dense_placeholders(self):
if scipy_sparse is None:
@@ -731,6 +737,54 @@ class LossWeightingTest(test.TestCase):
model.fit(x_np, [y_np, y_np], epochs=1,
sample_weight={'1': bad_w_np})
+ def test_default_sample_weight(self):
+ """Verifies that fit works without having to set sample_weight."""
+
+ num_classes = 5
+ input_dim = 5
+ timesteps = 3
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(
+ keras.layers.TimeDistributed(
+ keras.layers.Dense(num_classes),
+ input_shape=(timesteps, input_dim)))
+
+ x = np.random.random((10, timesteps, input_dim))
+ y = np.random.random((10, timesteps, num_classes))
+
+ # sample_weight_mode is a list and mode value is None
+ model.compile(loss='mse', optimizer='rmsprop', sample_weight_mode=[None])
+ model.fit(x, y, epochs=1, batch_size=10)
+
+ # sample_weight_mode is a list and mode value is `temporal`
+ model.compile(
+ loss='mse', optimizer='rmsprop', sample_weight_mode=['temporal'])
+ model.fit(x, y, epochs=1, batch_size=10)
+
+ # sample_weight_mode is a dict and mode value is None
+ model.compile(
+ loss='mse',
+ optimizer='rmsprop',
+ sample_weight_mode={'time_distributed': None})
+ model.fit(x, y, epochs=1, batch_size=10)
+
+ # sample_weight_mode is a dict and mode value is `temporal`
+ model.compile(
+ loss='mse',
+ optimizer='rmsprop',
+ sample_weight_mode={'time_distributed': 'temporal'})
+ model.fit(x, y, epochs=1, batch_size=10)
+
+ # sample_weight_mode is a not a list/dict and mode value is None
+ model.compile(loss='mse', optimizer='rmsprop', sample_weight_mode=None)
+ model.fit(x, y, epochs=1, batch_size=10)
+
+ # sample_weight_mode is a not a list/dict and mode value is `temporal`
+ model.compile(
+ loss='mse', optimizer='rmsprop', sample_weight_mode='temporal')
+ model.fit(x, y, epochs=1, batch_size=10)
+
class LossMaskingTest(test.TestCase):
diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py
index dbbc87daf9..f2cd9c89da 100644
--- a/tensorflow/python/keras/engine/training_utils.py
+++ b/tensorflow/python/keras/engine/training_utils.py
@@ -26,10 +26,12 @@ import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import losses
from tensorflow.python.keras import metrics as metrics_module
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@@ -700,17 +702,16 @@ def populate_metric_names(model):
for i in range(len(model.outputs)):
metrics = model.nested_metrics[i]
for metric in metrics:
- base_metric_name = get_base_metric_name(metric)
+ base_metric_name = get_metric_name(metric)
add_metric_name(model, base_metric_name, i)
-def get_base_metric_name(metric, weighted=False):
- """Returns the metric name given the metric function.
+def get_metric_name(metric, weighted=False):
+ """Returns the metric name corresponding to the given metric input.
Arguments:
metric: Metric function name or reference.
- weighted: Boolean indicating if the metric for which we are adding
- names is weighted.
+ weighted: Boolean indicating if the given metric is weighted.
Returns:
a metric name.
@@ -734,6 +735,36 @@ def get_base_metric_name(metric, weighted=False):
return metric_name
+def get_metric_function(metric, output_shape=None, loss_fn=None):
+ """Returns the metric function corresponding to the given metric input.
+
+ Arguments:
+ metric: Metric function name or reference.
+ output_shape: The shape of the output that this metric
+ will be calculated for.
+ loss_fn: The loss function used.
+
+ Returns:
+ The metric function.
+ """
+ if metric in ['accuracy', 'acc']:
+ if output_shape[-1] == 1 or loss_fn == losses.binary_crossentropy:
+ return metrics_module.binary_accuracy # case: binary accuracy
+ elif loss_fn == losses.sparse_categorical_crossentropy:
+ # case: categorical accuracy with sparse targets
+ return metrics_module.sparse_categorical_accuracy
+ return metrics_module.categorical_accuracy # case: categorical accuracy
+ elif metric in ['crossentropy', 'ce']:
+ if output_shape[-1] == 1 or loss_fn == losses.binary_crossentropy:
+ return metrics_module.binary_crossentropy # case: binary cross-entropy
+ elif loss_fn == losses.sparse_categorical_crossentropy:
+ # case: categorical cross-entropy with sparse targets
+ return metrics_module.sparse_categorical_crossentropy
+ # case: categorical cross-entropy
+ return metrics_module.categorical_crossentropy
+ return metrics_module.get(metric)
+
+
def add_metric_name(model, metric_name, index):
"""Makes the metric name unique and adds it to the model's metric name list.
@@ -856,3 +887,25 @@ def cast_if_floating_dtype(x):
for val in x
]
return math_ops.cast(x, dtype=K.floatx()) if x.dtype.is_floating else x
+
+
+def get_output_sample_weight_and_mode(skip_target_weighing_indices,
+ sample_weight_mode, output_name,
+ output_index):
+ """Returns the sample weight and weight mode for a single output."""
+ if output_index in skip_target_weighing_indices:
+ return None, None
+
+ if sample_weight_mode == 'temporal':
+ default_value = [[1.]]
+ shape = [None, None]
+ mode = 'temporal'
+ else:
+ default_value = [1.]
+ shape = [None]
+ mode = None
+ weight = array_ops.placeholder_with_default(
+ constant_op.constant(default_value, dtype=K.floatx()),
+ shape=shape,
+ name=output_name + '_sample_weights')
+ return weight, mode
diff --git a/tensorflow/python/keras/model_subclassing_test.py b/tensorflow/python/keras/model_subclassing_test.py
index 5fbc191e78..3a153573f8 100644
--- a/tensorflow/python/keras/model_subclassing_test.py
+++ b/tensorflow/python/keras/model_subclassing_test.py
@@ -180,9 +180,6 @@ def get_nested_model_3(input_dim, num_classes):
x = self.dense2(x)
return self.bn(x)
- def compute_output_shape(self, input_shape):
- return tensor_shape.TensorShape((input_shape[0], 5))
-
test_model = Inner()
x = test_model(x)
outputs = keras.layers.Dense(num_classes)(x)
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 68873df97e..b567b71424 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -734,11 +734,11 @@ class ControlFlowTest(test.TestCase):
def body_fn(i):
with ops.control_dependencies([increment]):
- return i + i
+ return i + 1
- result = control_flow_ops.while_loop(cond=lambda i: i < 1,
+ result = control_flow_ops.while_loop(cond=lambda i: i < 2,
body=body_fn, loop_vars=[1])
- result.eval()
+ self.assertAllEqual(result.eval(), 2)
self.assertAllEqual(v.eval(), 1.0)
def testWhileExternalControlDependenciesNoInput(self):
diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py
index 474d06b8f3..00de94f004 100644
--- a/tensorflow/python/kernel_tests/conv_ops_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_test.py
@@ -1706,7 +1706,7 @@ class SeparableConv2DTest(test.TestCase):
def testSeparableConv2D(self):
self._testSeparableConv2D("NHWC")
- def testSeparableConv2DNCHW(self):
+ def disabledtestSeparableConv2DNCHW(self):
if not test.is_gpu_available():
return
self._testSeparableConv2D("NCHW")
diff --git a/tensorflow/python/kernel_tests/decode_jpeg_op_test.py b/tensorflow/python/kernel_tests/decode_jpeg_op_test.py
index 510daf79dc..66b3e0f22f 100644
--- a/tensorflow/python/kernel_tests/decode_jpeg_op_test.py
+++ b/tensorflow/python/kernel_tests/decode_jpeg_op_test.py
@@ -110,7 +110,8 @@ class DecodeJpegBenchmark(test.Benchmark):
start_time = time.time()
for _ in xrange(num_iters):
sess.run(r)
- return time.time() - start_time
+ end_time = time.time()
+ return end_time - start_time
def benchmarkDecodeJpegSmall(self):
"""Evaluate single DecodeImageOp for small size image."""
diff --git a/tensorflow/python/kernel_tests/random/random_ops_test.py b/tensorflow/python/kernel_tests/random/random_ops_test.py
index e4b5c3832a..0ef6a95cfc 100644
--- a/tensorflow/python/kernel_tests/random/random_ops_test.py
+++ b/tensorflow/python/kernel_tests/random/random_ops_test.py
@@ -24,13 +24,42 @@ from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
-class RandomNormalTest(test.TestCase):
+class RandomOpTestCommon(test.TestCase):
+
+ # Checks that executing the same rng_func multiple times rarely produces the
+ # same result.
+ def _testSingleSessionNotConstant(self,
+ rng_func,
+ num,
+ dtype,
+ min_or_mean,
+ max_or_stddev,
+ use_gpu,
+ op_seed=None,
+ graph_seed=None):
+ with self.test_session(use_gpu=use_gpu, graph=ops.Graph()) as sess:
+ if graph_seed is not None:
+ random_seed.set_random_seed(graph_seed)
+ x = rng_func([num], min_or_mean, max_or_stddev, dtype=dtype, seed=op_seed)
+
+ y = sess.run(x)
+ z = sess.run(x)
+ w = sess.run(x)
+
+ # We use exact equality here. If the random-number generator is producing
+ # the same output, all three outputs will be bitwise identical.
+ self.assertTrue((not np.array_equal(y, z)) or
+ (not np.array_equal(z, w)) or (not np.array_equal(y, w)))
+
+
+class RandomNormalTest(RandomOpTestCommon):
def _Sampler(self, num, mu, sigma, dtype, use_gpu, seed=None):
@@ -90,6 +119,36 @@ class RandomNormalTest(test.TestCase):
diff = rnd2 - rnd1
self.assertTrue(np.linalg.norm(diff.eval()) > 0.1)
+ def testSingleSessionNotConstant(self):
+ for use_gpu in [False, True]:
+ for dt in dtypes.float16, dtypes.float32, dtypes.float64:
+ self._testSingleSessionNotConstant(
+ random_ops.random_normal, 100, dt, 0.0, 1.0, use_gpu=use_gpu)
+
+ def testSingleSessionOpSeedNotConstant(self):
+ for use_gpu in [False, True]:
+ for dt in dtypes.float16, dtypes.float32, dtypes.float64:
+ self._testSingleSessionNotConstant(
+ random_ops.random_normal,
+ 100,
+ dt,
+ 0.0,
+ 1.0,
+ use_gpu=use_gpu,
+ op_seed=1345)
+
+ def testSingleSessionGraphSeedNotConstant(self):
+ for use_gpu in [False, True]:
+ for dt in dtypes.float16, dtypes.float32, dtypes.float64:
+ self._testSingleSessionNotConstant(
+ random_ops.random_normal,
+ 100,
+ dt,
+ 0.0,
+ 1.0,
+ use_gpu=use_gpu,
+ graph_seed=965)
+
class TruncatedNormalTest(test.TestCase):
@@ -187,7 +246,7 @@ class TruncatedNormalTest(test.TestCase):
self.assertAllEqual(rnd1, rnd2)
-class RandomUniformTest(test.TestCase):
+class RandomUniformTest(RandomOpTestCommon):
def _Sampler(self, num, minv, maxv, dtype, use_gpu, seed=None):
@@ -291,6 +350,39 @@ class RandomUniformTest(test.TestCase):
diff = (rnd2 - rnd1).eval()
self.assertTrue(np.linalg.norm(diff) > 0.1)
+ def testSingleSessionNotConstant(self):
+ for use_gpu in [False, True]:
+ for dt in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32,
+ dtypes.int64):
+ self._testSingleSessionNotConstant(
+ random_ops.random_uniform, 100, dt, 0, 17, use_gpu=use_gpu)
+
+ def testSingleSessionOpSeedNotConstant(self):
+ for use_gpu in [False, True]:
+ for dt in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32,
+ dtypes.int64):
+ self._testSingleSessionNotConstant(
+ random_ops.random_uniform,
+ 100,
+ dt,
+ 10,
+ 20,
+ use_gpu=use_gpu,
+ op_seed=1345)
+
+ def testSingleSessionGraphSeedNotConstant(self):
+ for use_gpu in [False, True]:
+ for dt in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32,
+ dtypes.int64):
+ self._testSingleSessionNotConstant(
+ random_ops.random_uniform,
+ 100,
+ dt,
+ 20,
+ 200,
+ use_gpu=use_gpu,
+ graph_seed=965)
+
class RandomShapeTest(test.TestCase):
diff --git a/tensorflow/python/ops/boosted_trees_ops.py b/tensorflow/python/ops/boosted_trees_ops.py
index 868a4f6b84..f7cbfe0312 100644
--- a/tensorflow/python/ops/boosted_trees_ops.py
+++ b/tensorflow/python/ops/boosted_trees_ops.py
@@ -37,8 +37,19 @@ from tensorflow.python.training import saver
class PruningMode(object):
+ """Class for working with Pruning modes."""
NO_PRUNING, PRE_PRUNING, POST_PRUNING = range(0, 3)
+ _map = {'none': NO_PRUNING, 'pre': PRE_PRUNING, 'post': POST_PRUNING}
+
+ @classmethod
+ def from_str(cls, mode):
+ if mode in cls._map:
+ return cls._map[mode]
+ else:
+ raise ValueError('pruning_mode mode must be one of: {}'.format(', '.join(
+ sorted(cls._map))))
+
class _TreeEnsembleSavable(saver.BaseSaverBuilder.SaveableObject):
"""SaveableObject implementation for TreeEnsemble."""
diff --git a/tensorflow/python/ops/cond_v2_impl.py b/tensorflow/python/ops/cond_v2_impl.py
index 5cd0cb34de..44c5c050c0 100644
--- a/tensorflow/python/ops/cond_v2_impl.py
+++ b/tensorflow/python/ops/cond_v2_impl.py
@@ -58,12 +58,14 @@ def cond_v2(pred, true_fn, false_fn, name="cond"):
with ops.name_scope(name) as scope:
# Identify if there is a caller device, & get the innermost if possible.
- device_stack = ops.get_default_graph()._device_function_stack
- caller_device = device_stack[-1] if device_stack else None
+ # pylint: disable=protected-access
+ device_funcs = ops.get_default_graph()._device_functions_outer_to_inner
+ caller_device = device_funcs[-1] if device_funcs else None
caller_colocation_stack = ops.get_default_graph()._colocation_stack
caller_container = ops.get_default_graph()._container
caller_collection_ref = ops.get_default_graph()._collections
+ # pylint: enable=protected-access
func_name_prefix = scope.replace("/", "_")
@@ -106,7 +108,7 @@ def cond_v2(pred, true_fn, false_fn, name="cond"):
false_graph.outputs.extend(extra_false_outputs)
# Create the If op.
- tensors = gen_functional_ops._if(
+ tensors = gen_functional_ops._if( # pylint: disable=protected-access
pred, cond_inputs, [t.dtype for t in true_graph.outputs],
_create_new_tf_function(true_graph),
_create_new_tf_function(false_graph),
@@ -125,8 +127,10 @@ def cond_v2(pred, true_fn, false_fn, name="cond"):
# TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
if_op = tensors[0].op
if not control_flow_util.IsInXLAContext(if_op):
+ # pylint: disable=protected-access
if_op._set_attr("_lower_using_switch_merge",
attr_value_pb2.AttrValue(b=True))
+ # pylint: enable=protected-access
return tensors[:num_cond_outputs]
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index 9440bab9ee..855a4d0c33 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.compat import compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -2110,6 +2111,64 @@ def non_max_suppression(boxes,
iou_threshold, score_threshold)
+@tf_export('image.non_max_suppression_padded')
+def non_max_suppression_padded(boxes,
+ scores,
+ max_output_size,
+ iou_threshold=0.5,
+ score_threshold=float('-inf'),
+ pad_to_max_output_size=False,
+ name=None):
+ """Greedily selects a subset of bounding boxes in descending order of score.
+
+ Performs algorithmically equivalent operation to tf.image.non_max_suppression,
+ with the addition of an optional parameter which zero-pads the output to
+ be of size `max_output_size`.
+ The output of this operation is a tuple containing the set of integers
+ indexing into the input collection of bounding boxes representing the selected
+ boxes and the number of valid indices in the index set. The bounding box
+ coordinates corresponding to the selected indices can then be obtained using
+ the `tf.slice` and `tf.gather` operations. For example:
+ selected_indices_padded, num_valid = tf.image.non_max_suppression_padded(
+ boxes, scores, max_output_size, iou_threshold,
+ score_threshold, pad_to_max_output_size=True)
+ selected_indices = tf.slice(
+ selected_indices_padded, tf.constant([0]), num_valid)
+ selected_boxes = tf.gather(boxes, selected_indices)
+
+ Args:
+ boxes: A 2-D float `Tensor` of shape `[num_boxes, 4]`.
+ scores: A 1-D float `Tensor` of shape `[num_boxes]` representing a single
+ score corresponding to each box (each row of boxes).
+ max_output_size: A scalar integer `Tensor` representing the maximum number
+ of boxes to be selected by non max suppression.
+ iou_threshold: A float representing the threshold for deciding whether boxes
+ overlap too much with respect to IOU.
+ score_threshold: A float representing the threshold for deciding when to
+ remove boxes based on score.
+ pad_to_max_output_size: bool. If True, size of `selected_indices` output
+ is padded to `max_output_size`.
+ name: A name for the operation (optional).
+
+ Returns:
+ selected_indices: A 1-D integer `Tensor` of shape `[M]` representing the
+ selected indices from the boxes tensor, where `M <= max_output_size`.
+ valid_outputs: A scalar integer `Tensor` denoting how many elements in
+ `selected_indices` are valid. Valid elements occur first, then padding.
+ """
+ with ops.name_scope(name, 'non_max_suppression_padded'):
+ iou_threshold = ops.convert_to_tensor(iou_threshold, name='iou_threshold')
+ score_threshold = ops.convert_to_tensor(
+ score_threshold, name='score_threshold')
+ if compat.forward_compatible(2018, 8, 7) or pad_to_max_output_size:
+ return gen_image_ops.non_max_suppression_v4(
+ boxes, scores, max_output_size, iou_threshold, score_threshold,
+ pad_to_max_output_size)
+ else:
+ return gen_image_ops.non_max_suppression_v3(
+ boxes, scores, max_output_size, iou_threshold, score_threshold)
+
+
@tf_export('image.non_max_suppression_overlaps')
def non_max_suppression_with_overlaps(overlaps,
scores,
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index 7096e0dd84..7b6ab20975 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -432,9 +432,15 @@ def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
return array_ops.reverse(input_, axis=[seq_axis])
with vs.variable_scope("bw") as bw_scope:
- inputs_reverse = _reverse(
- inputs, seq_lengths=sequence_length,
- seq_axis=time_axis, batch_axis=batch_axis)
+
+ def _map_reverse(inp):
+ return _reverse(
+ inp,
+ seq_lengths=sequence_length,
+ seq_axis=time_axis,
+ batch_axis=batch_axis)
+
+ inputs_reverse = nest.map_structure(_map_reverse, inputs)
tmp, output_state_bw = dynamic_rnn(
cell=cell_bw, inputs=inputs_reverse, sequence_length=sequence_length,
initial_state=initial_state_bw, dtype=dtype,
diff --git a/tensorflow/python/saved_model/constants.py b/tensorflow/python/saved_model/constants.py
index 61c6ffbd0d..cb251f08bb 100644
--- a/tensorflow/python/saved_model/constants.py
+++ b/tensorflow/python/saved_model/constants.py
@@ -60,6 +60,10 @@ SAVED_MODEL_FILENAME_PBTXT = "saved_model.pbtxt"
tf_export("saved_model.constants.SAVED_MODEL_FILENAME_PBTXT").export_constant(
__name__, "SAVED_MODEL_FILENAME_PBTXT")
+# File name for json format of SavedModel.
+# Not exported while keras_saved_model is in contrib.
+SAVED_MODEL_FILENAME_JSON = "saved_model.json"
+
# Subdirectory name containing the variables/checkpoint files.
VARIABLES_DIRECTORY = "variables"
tf_export("saved_model.constants.VARIABLES_DIRECTORY").export_constant(
@@ -69,5 +73,3 @@ tf_export("saved_model.constants.VARIABLES_DIRECTORY").export_constant(
VARIABLES_FILENAME = "variables"
tf_export("saved_model.constants.VARIABLES_FILENAME").export_constant(
__name__, "VARIABLES_FILENAME")
-
-
diff --git a/tensorflow/python/summary/writer/writer.py b/tensorflow/python/summary/writer/writer.py
index aca084fc91..60e96ee947 100644
--- a/tensorflow/python/summary/writer/writer.py
+++ b/tensorflow/python/summary/writer/writer.py
@@ -325,7 +325,7 @@ class FileWriter(SummaryToEventTransformer):
```
The `session` argument to the constructor makes the returned `FileWriter` a
- a compatibility layer over new graph-based summaries (`tf.contrib.summary`).
+ compatibility layer over new graph-based summaries (`tf.contrib.summary`).
Crucially, this means the underlying writer resource and events file will
be shared with any other `FileWriter` using the same `session` and `logdir`,
and with any `tf.contrib.summary.SummaryWriter` in this session using the
diff --git a/tensorflow/python/tools/freeze_graph.py b/tensorflow/python/tools/freeze_graph.py
index e9f1def48c..4349699a94 100644
--- a/tensorflow/python/tools/freeze_graph.py
+++ b/tensorflow/python/tools/freeze_graph.py
@@ -38,6 +38,7 @@ from __future__ import division
from __future__ import print_function
import argparse
+import re
import sys
from google.protobuf import text_format
@@ -116,16 +117,43 @@ def freeze_graph_with_def_protos(input_graph_def,
var_list = {}
reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
var_to_shape_map = reader.get_variable_to_shape_map()
+
+ # List of all partition variables. Because the condition is heuristic
+ # based, the list could include false positives.
+ all_parition_variable_names = [
+ tensor.name.split(":")[0]
+ for op in sess.graph.get_operations()
+ for tensor in op.values()
+ if re.search(r"/part_\d+/", tensor.name)
+ ]
+ has_partition_var = False
+
for key in var_to_shape_map:
try:
tensor = sess.graph.get_tensor_by_name(key + ":0")
+ if any(key in name for name in all_parition_variable_names):
+ has_partition_var = True
except KeyError:
# This tensor doesn't exist in the graph (for example it's
# 'global_step' or a similar housekeeping element) so skip it.
continue
var_list[key] = tensor
- saver = saver_lib.Saver(
- var_list=var_list, write_version=checkpoint_version)
+
+ try:
+ saver = saver_lib.Saver(
+ var_list=var_list, write_version=checkpoint_version)
+ except TypeError as e:
+ # `var_list` is required to be a map of variable names to Variable
+ # tensors. Partition variables are Identity tensors that cannot be
+ # handled by Saver.
+ if has_partition_var:
+ print("Models containing partition variables cannot be converted "
+ "from checkpoint files. Please pass in a SavedModel using "
+ "the flag --input_saved_model_dir.")
+ return -1
+ else:
+ raise e
+
saver.restore(sess, input_checkpoint)
if initializer_nodes:
sess.run(initializer_nodes.replace(" ", "").split(","))
diff --git a/tensorflow/python/tools/freeze_graph_test.py b/tensorflow/python/tools/freeze_graph_test.py
index 91f0061ebc..e38945fabc 100644
--- a/tensorflow/python/tools/freeze_graph_test.py
+++ b/tensorflow/python/tools/freeze_graph_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import os
+import re
from tensorflow.core.example import example_pb2
from tensorflow.core.framework import graph_pb2
@@ -31,7 +32,10 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.saved_model import builder as saved_model_builder
@@ -262,6 +266,69 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
output = sess.run(output_node, feed_dict={input_node: [example]})
self.assertNear(feature_value, output, 0.00001)
+ def testSinglePartitionedVariable(self):
+ """Ensures partitioned variables fail cleanly with freeze graph."""
+ checkpoint_prefix = os.path.join(self.get_temp_dir(), "saved_checkpoint")
+ checkpoint_state_name = "checkpoint_state"
+ input_graph_name = "input_graph.pb"
+ output_graph_name = "output_graph.pb"
+
+ # Create a graph with partition variables. When weights are partitioned into
+ # a single partition, the weights variable is followed by a identity ->
+ # identity (an additional identity node).
+ partitioner = partitioned_variables.fixed_size_partitioner(1)
+ with ops.Graph().as_default():
+ with variable_scope.variable_scope("part", partitioner=partitioner):
+ batch_size, height, width, depth = 5, 128, 128, 3
+ input1 = array_ops.zeros(
+ (batch_size, height, width, depth), name="input1")
+ input2 = array_ops.zeros(
+ (batch_size, height, width, depth), name="input2")
+
+ num_nodes = depth
+ filter1 = variable_scope.get_variable("filter", [num_nodes, num_nodes])
+ filter2 = array_ops.reshape(filter1, [1, 1, num_nodes, num_nodes])
+ conv = nn.conv2d(
+ input=input1, filter=filter2, strides=[1, 1, 1, 1], padding="SAME")
+ node = math_ops.add(conv, input2, name="test/add")
+ node = nn.relu6(node, name="test/relu6")
+
+ # Save graph and checkpoints.
+ sess = session.Session()
+ sess.run(variables.global_variables_initializer())
+
+ saver = saver_lib.Saver()
+ checkpoint_path = saver.save(
+ sess,
+ checkpoint_prefix,
+ global_step=0,
+ latest_filename=checkpoint_state_name)
+ graph_io.write_graph(sess.graph, self.get_temp_dir(), input_graph_name)
+
+ # Ensure this graph has partition variables.
+ self.assertTrue([
+ tensor.name.split(":")[0]
+ for op in sess.graph.get_operations()
+ for tensor in op.values()
+ if re.search(r"/part_\d+/", tensor.name)
+ ])
+
+ # Test freezing graph doesn't make it crash.
+ output_node_names = "save/restore_all"
+ output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)
+
+ return_value = freeze_graph.freeze_graph_with_def_protos(
+ input_graph_def=sess.graph_def,
+ input_saver_def=None,
+ input_checkpoint=checkpoint_path,
+ output_node_names=output_node_names,
+ restore_op_name="save/restore_all", # default value
+ filename_tensor_name="save/Const:0", # default value
+ output_graph=output_graph_path,
+ clear_devices=False,
+ initializer_nodes="")
+ self.assertTrue(return_value, -1)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py
index 883f4fd910..a052081630 100644
--- a/tensorflow/python/training/checkpoint_utils.py
+++ b/tensorflow/python/training/checkpoint_utils.py
@@ -24,7 +24,6 @@ from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
@@ -308,32 +307,19 @@ def _set_checkpoint_initializer(variable,
restore_op = io_ops.restore_v2(
ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]
- # TODO(priyag, allenl): Use `SaveableObject.restore` instead here.
- if resource_variable_ops.is_resource_variable(variable):
- init_op = variable.assign(restore_op, read_value=False)
- else:
- init_op = state_ops.assign(variable, restore_op)
+ names_to_saveables = saver.BaseSaverBuilder.OpListToDict([variable])
+ saveable_objects = []
+ for name, op in names_to_saveables.items():
+ for s in saver.BaseSaverBuilder.SaveableObjectsForOp(op, name):
+ saveable_objects.append(s)
+
+ assert len(saveable_objects) == 1 # Should be only one variable.
+ init_op = saveable_objects[0].restore([restore_op], restored_shapes=None)
# pylint:disable=protected-access
- # We need special handling for `DistributedVariable`s as they contain
- # mutliple actual variables. `assign` on a `DistributedVariable` returns a
- # combined `init_op` which contains initializers for all the contained
- # variables. We then set each underlying variable's `_initializer_op` using
- # the corresponding `init_op`.
- # TODO(priyag): Use `isinstance` checks when `DistributedVariable` class
- # moves out of contrib.
- if any(base.__name__ == "DistributedVariable"
- for base in variable.__class__.__bases__):
- assert distribute_lib.get_cross_tower_context()
- assert hasattr(variable, "_index")
- for (d, v) in six.iteritems(variable._index):
- v._initializer_op = init_op._index[d]
- restore_op.set_shape(v.shape)
- v._initial_value = restore_op
- else:
- variable._initializer_op = init_op
- restore_op.set_shape(variable.shape)
- variable._initial_value = restore_op
+ variable._initializer_op = init_op
+ restore_op.set_shape(variable.shape)
+ variable._initial_value = restore_op
# pylint:enable=protected-access
diff --git a/tensorflow/python/training/checkpoint_utils_test.py b/tensorflow/python/training/checkpoint_utils_test.py
index 4e08a1c859..1c1f126ce9 100644
--- a/tensorflow/python/training/checkpoint_utils_test.py
+++ b/tensorflow/python/training/checkpoint_utils_test.py
@@ -386,7 +386,9 @@ class CheckpointsTest(test.TestCase):
op for op in g.get_operations()
if (op.name.startswith("init_from_checkpoint/") and
not op.name.startswith("init_from_checkpoint/checkpoint_initializer"
- ) and op.type != "AssignVariableOp")
+ ) and
+ op.type != "AssignVariableOp" and
+ op.type != "Identity")
]
self.assertEqual(ops_in_init_from_checkpoint_scope, [])
diff --git a/tensorflow/python/training/checkpointable/base.py b/tensorflow/python/training/checkpointable/base.py
index f0703c8af4..66837ee52f 100644
--- a/tensorflow/python/training/checkpointable/base.py
+++ b/tensorflow/python/training/checkpointable/base.py
@@ -144,7 +144,7 @@ class _CheckpointPosition(object):
# process deferred restorations for it and its dependencies.
restore_ops = checkpointable._restore_from_checkpoint_position(self) # pylint: disable=protected-access
if restore_ops:
- self._checkpoint.restore_ops.extend(restore_ops)
+ self._checkpoint.new_restore_ops(restore_ops)
def bind_object(self, checkpointable):
"""Set a checkpoint<->object correspondence and process slot variables.
diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py
index 5d26a817d4..664b2348c0 100644
--- a/tensorflow/python/training/checkpointable/util.py
+++ b/tensorflow/python/training/checkpointable/util.py
@@ -101,6 +101,7 @@ class _CheckpointRestoreCoordinator(object):
# this checkpoint.
self.restore_ops = []
self.restore_ops_by_name = {}
+ self.new_restore_ops_callback = None
# A mapping from optimizer proto ids to lists of slot variables to be
# restored when the optimizer is tracked. Only includes slot variables whose
# regular variables have already been created, and only for optimizer
@@ -121,6 +122,11 @@ class _CheckpointRestoreCoordinator(object):
slot_variable_id=slot_reference.slot_variable_node_id,
slot_name=slot_reference.slot_name))
+ def new_restore_ops(self, new_ops):
+ self.restore_ops.extend(new_ops)
+ if self.new_restore_ops_callback:
+ self.new_restore_ops_callback(new_ops) # pylint: disable=not-callable
+
class _NameBasedRestoreCoordinator(object):
"""Keeps the status of a name-based checkpoint restore."""
@@ -821,6 +827,31 @@ class _LoadStatus(object):
pass
+def streaming_restore(status, session=None):
+ """When graph building, runs restore ops as soon as they come in.
+
+ Args:
+ status: A _LoadStatus objects from an object-based saver's
+ restore(). Streaming restore from name-based checkpoints is not currently
+ supported.
+ session: A session to run new restore ops in.
+ """
+ if context.executing_eagerly():
+ # Streaming restore is the default/only behavior when executing eagerly.
+ return
+ if session is None:
+ session = ops.get_default_session()
+ if isinstance(status, NameBasedSaverStatus):
+ raise NotImplementedError(
+ "Streaming restore not supported from name-based checkpoints. File a "
+ "feature request if this limitation bothers you.")
+ status.run_restore_ops(session=session)
+ # pylint: disable=protected-access
+ status._checkpoint.new_restore_ops_callback = (
+ lambda ops: session.run(ops, feed_dict=status._feed_dict))
+ # pylint: enable=protected-access
+
+
class CheckpointLoadStatus(_LoadStatus):
"""Checks the status of checkpoint loading and manages restore ops.
@@ -992,11 +1023,13 @@ _DEPRECATED_RESTORE_INSTRUCTIONS = (
"one this message is coming from) and use that checkpoint in the future.")
-@deprecation.deprecated(
- date=None, instructions=_DEPRECATED_RESTORE_INSTRUCTIONS)
class NameBasedSaverStatus(_LoadStatus):
"""Status for loading a name-based training checkpoint."""
+ # Ideally this deprecation decorator would be on the class, but that
+ # interferes with isinstance checks.
+ @deprecation.deprecated(
+ date=None, instructions=_DEPRECATED_RESTORE_INSTRUCTIONS)
def __init__(self, checkpoint, root_checkpointable):
self._checkpoint = checkpoint
self._root_checkpointable = root_checkpointable
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index ecce8ae6bd..204e81dda0 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -786,6 +786,37 @@ class SaverTest(test.TestCase):
self.assertEqual(20.0, v1.eval())
save.save(sess, save_path)
+ # Test restoring large tensors (triggers a thread pool)
+ def testRestoreLargeTensors(self):
+ save_dir = self.get_temp_dir()
+ def _model():
+ small_v = [variable_scope.get_variable(
+ "small%d" % i, shape=[10, 2], use_resource=True) for i in range(5)]
+ large_v = [variable_scope.get_variable(
+ "large%d" % i, shape=[32000, 1000], use_resource=True)
+ for i in range(3)]
+ return small_v + large_v
+
+ save_graph = ops_lib.Graph()
+ with save_graph.as_default(), self.test_session(graph=save_graph) as sess:
+ orig_vars = _model()
+ sess.run(variables.global_variables_initializer())
+ save = saver_module.Saver(max_to_keep=1)
+ variables.global_variables_initializer().run()
+ save.save(sess, save_dir)
+ orig_vals = sess.run(orig_vars)
+
+ restore_graph = ops_lib.Graph()
+ with restore_graph.as_default(), self.test_session(
+ graph=restore_graph) as sess:
+ restored_vars = _model()
+ save = saver_module.Saver(max_to_keep=1)
+ save.restore(sess, save_dir)
+ restored_vals = sess.run(restored_vars)
+
+ for orig, restored in zip(orig_vals, restored_vals):
+ self.assertAllEqual(orig, restored)
+
class SaveRestoreShardedTest(test.TestCase):
diff --git a/tensorflow/python/util/function_utils.py b/tensorflow/python/util/function_utils.py
index 7bbbde3cd2..4e9b07e20a 100644
--- a/tensorflow/python/util/function_utils.py
+++ b/tensorflow/python/util/function_utils.py
@@ -20,6 +20,8 @@ from __future__ import print_function
import functools
+import six
+
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
@@ -55,3 +57,36 @@ def fn_args(fn):
if _is_bounded_method(fn):
args.remove('self')
return tuple(args)
+
+
+def get_func_name(func):
+ """Returns name of passed callable."""
+ _, func = tf_decorator.unwrap(func)
+ if callable(func):
+ if tf_inspect.isfunction(func):
+ return func.__name__
+ elif tf_inspect.ismethod(func):
+ return '%s.%s' % (six.get_method_self(func).__class__.__name__,
+ six.get_method_function(func).__name__)
+ else: # Probably a class instance with __call__
+ return str(type(func))
+ else:
+ raise ValueError('Argument must be callable')
+
+
+def get_func_code(func):
+ """Returns func_code of passed callable, or None if not available."""
+ _, func = tf_decorator.unwrap(func)
+ if callable(func):
+ if tf_inspect.isfunction(func) or tf_inspect.ismethod(func):
+ return six.get_function_code(func)
+ # Since the object is not a function or method, but is a callable, we will
+ # try to access the __call__method as a function. This works with callable
+ # classes but fails with functool.partial objects despite their __call__
+ # attribute.
+ try:
+ return six.get_function_code(func.__call__)
+ except AttributeError:
+ return None
+ else:
+ raise ValueError('Argument must be callable')
diff --git a/tensorflow/python/util/function_utils_test.py b/tensorflow/python/util/function_utils_test.py
index e78cf6a5b0..1588328c26 100644
--- a/tensorflow/python/util/function_utils_test.py
+++ b/tensorflow/python/util/function_utils_test.py
@@ -24,6 +24,16 @@ from tensorflow.python.platform import test
from tensorflow.python.util import function_utils
+def silly_example_function():
+ pass
+
+
+class SillyCallableClass(object):
+
+ def __call__(self):
+ pass
+
+
class FnArgsTest(test.TestCase):
def test_simple_function(self):
@@ -124,5 +134,73 @@ class FnArgsTest(test.TestCase):
self.assertEqual(3, double_wrapped_fn(3))
self.assertEqual(3, double_wrapped_fn(a=3))
+
+class GetFuncNameTest(test.TestCase):
+
+ def testWithSimpleFunction(self):
+ self.assertEqual(
+ 'silly_example_function',
+ function_utils.get_func_name(silly_example_function))
+
+ def testWithClassMethod(self):
+ self.assertEqual(
+ 'GetFuncNameTest.testWithClassMethod',
+ function_utils.get_func_name(self.testWithClassMethod))
+
+ def testWithCallableClass(self):
+ callable_instance = SillyCallableClass()
+ self.assertRegexpMatches(
+ function_utils.get_func_name(callable_instance),
+ '<.*SillyCallableClass.*>')
+
+ def testWithFunctoolsPartial(self):
+ partial = functools.partial(silly_example_function)
+ self.assertRegexpMatches(
+ function_utils.get_func_name(partial),
+ '<.*functools.partial.*>')
+
+ def testWithLambda(self):
+ anon_fn = lambda x: x
+ self.assertEqual('<lambda>', function_utils.get_func_name(anon_fn))
+
+ def testRaisesWithNonCallableObject(self):
+ with self.assertRaises(ValueError):
+ function_utils.get_func_name(None)
+
+
+class GetFuncCodeTest(test.TestCase):
+
+ def testWithSimpleFunction(self):
+ code = function_utils.get_func_code(silly_example_function)
+ self.assertIsNotNone(code)
+ self.assertRegexpMatches(code.co_filename, 'function_utils_test.py')
+
+ def testWithClassMethod(self):
+ code = function_utils.get_func_code(self.testWithClassMethod)
+ self.assertIsNotNone(code)
+ self.assertRegexpMatches(code.co_filename, 'function_utils_test.py')
+
+ def testWithCallableClass(self):
+ callable_instance = SillyCallableClass()
+ code = function_utils.get_func_code(callable_instance)
+ self.assertIsNotNone(code)
+ self.assertRegexpMatches(code.co_filename, 'function_utils_test.py')
+
+ def testWithLambda(self):
+ anon_fn = lambda x: x
+ code = function_utils.get_func_code(anon_fn)
+ self.assertIsNotNone(code)
+ self.assertRegexpMatches(code.co_filename, 'function_utils_test.py')
+
+ def testWithFunctoolsPartial(self):
+ partial = functools.partial(silly_example_function)
+ code = function_utils.get_func_code(partial)
+ self.assertIsNone(code)
+
+ def testRaisesWithNonCallableObject(self):
+ with self.assertRaises(ValueError):
+ function_utils.get_func_code(None)
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index 5aac559b9b..faae0d89c3 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -377,6 +377,62 @@ def map_structure(func, *structure, **check_types_dict):
structure[0], [func(*x) for x in entries])
+def map_structure_with_paths(func, *structure, **kwargs):
+ """Applies `func` to each entry in `structure` and returns a new structure.
+
+ Applies `func(path, x[0], x[1], ..., **kwargs)` where x[i] is an entry in
+ `structure[i]` and `path` is the common path to x[i] in the structures. All
+ structures in `structure` must have the same arity, and the return value will
+ contain the results in the same structure. Special kwarg `check_types`
+ determines whether the types of iterables within the structure must be the
+ same-- see **kwargs definition below.
+
+ Args:
+ func: A callable with the signature func(path, *values, **kwargs) that is
+ evaluated on the leaves of the structure.
+ *structure: A variable number of compatible structures to process.
+ **kwargs: Optional kwargs to be passed through to func. Special kwarg
+ `check_types` is not passed to func, but instead determines whether the
+ types of iterables within the structures have to be same (e.g.,
+ `map_structure(func, [1], (1,))` raises a `TypeError` exception). By
+ default, the types must match. To allow iteration over structures of
+ different types (but common arity), set this kwarg to `False`.
+
+ Returns:
+ A structure of the same form as the input structures whose leaves are the
+ result of evaluating func on corresponding leaves of the input structures.
+
+ Raises:
+ TypeError: If `func` is not callable or if the structures do not match
+ each other by depth tree.
+ TypeError: If `check_types` is not `False` and the two structures differ in
+ the type of sequence in any of their substructures.
+ ValueError: If no structures are provided.
+ """
+ if not callable(func):
+ raise TypeError("func must be callable, got: %s" % func)
+ if not structure:
+ raise ValueError("Must provide at least one structure")
+
+ check_types = kwargs.pop("check_types", True)
+ for other in structure[1:]:
+ assert_same_structure(structure[0], other, check_types=check_types)
+
+ # First set paths_and_values to:
+ # [[(p11, v11), ... (p1n, v1n)], ... [(pm1, vm1), ... (pmn, vmn)]]
+ paths_and_values = [flatten_with_joined_string_paths(s) for s in structure]
+
+ # Now zip(*paths_and_values) would be:
+ # [((p11, v11), ... (pm1, vm1)), ... ((p1n, v1n), ... (pmn, vmn))]
+ # so grouped_by_path is set to:
+ # [[(p11, ... pm1), (v11, ... vm1)], ... [(p1n, ... pmn), (v1n, ... vmn)]]
+ # Note that p1i, ... pmi must all be equal since the structures are the same.
+ grouped_by_path = [zip(*p_v) for p_v in zip(*paths_and_values)]
+
+ return pack_sequence_as(structure[0], [
+ func(paths[0], *values, **kwargs) for paths, values in grouped_by_path])
+
+
def _yield_flat_up_to(shallow_tree, input_tree):
"""Yields elements `input_tree` partially flattened up to `shallow_tree`."""
if is_sequence(shallow_tree):
diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py
index 26c6ea4b01..fd75c6885a 100644
--- a/tensorflow/python/util/nest_test.py
+++ b/tensorflow/python/util/nest_test.py
@@ -746,6 +746,35 @@ class NestTest(parameterized.TestCase, test.TestCase):
self.assertEqual(
list(nest.flatten_with_joined_string_paths(inputs)), expected)
+ @parameterized.named_parameters(
+ ("tuples", (1, 2), (3, 4), True, (("0", 4), ("1", 6))),
+ ("dicts", {"a": 1, "b": 2}, {"b": 4, "a": 3}, True,
+ {"a": ("a", 4), "b": ("b", 6)}),
+ ("mixed", (1, 2), [3, 4], False, (("0", 4), ("1", 6))),
+ ("nested",
+ {"a": [2, 3], "b": [1, 2, 3]}, {"b": [5, 6, 7], "a": [8, 9]}, True,
+ {"a": [("a/0", 10), ("a/1", 12)],
+ "b": [("b/0", 6), ("b/1", 8), ("b/2", 10)]}))
+ def testMapWithPathsCompatibleStructures(self, s1, s2, check_types, expected):
+ def format_sum(path, *values):
+ return (path, sum(values))
+ result = nest.map_structure_with_paths(format_sum, s1, s2,
+ check_types=check_types)
+ self.assertEqual(expected, result)
+
+ @parameterized.named_parameters(
+ ("tuples", (1, 2), (3, 4, 5), ValueError),
+ ("dicts", {"a": 1}, {"b": 2}, ValueError),
+ ("mixed", (1, 2), [3, 4], TypeError),
+ ("nested",
+ {"a": [2, 3], "b": [1, 3]},
+ {"b": [5, 6, 7], "a": [8, 9]},
+ ValueError
+ ))
+ def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type):
+ with self.assertRaises(error_type):
+ nest.map_structure_with_paths(lambda path, *s: 0, s1, s2)
+
class NestBenchmark(test.Benchmark):
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 766a0dafb5..1c3940e92c 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -322,6 +322,7 @@ port::Status GetLoadedCudnnVersion(CudnnVersion* version) {
CudnnSupport::CudnnSupport(CUDAExecutor* parent) : parent_(parent) {}
port::Status CudnnSupport::Init() {
+ ScopedActivateExecutorContext context(parent_);
cudnnHandle_t cudnn_handle = nullptr;
auto status = cudnnCreate(&cudnn_handle);
if (status == CUDNN_STATUS_SUCCESS) {
diff --git a/tensorflow/stream_executor/cuda/cuda_driver.cc b/tensorflow/stream_executor/cuda/cuda_driver.cc
index d508f6594a..dbece3adf9 100644
--- a/tensorflow/stream_executor/cuda/cuda_driver.cc
+++ b/tensorflow/stream_executor/cuda/cuda_driver.cc
@@ -102,117 +102,16 @@ class CreatedContexts {
/* static */ int64 CreatedContexts::next_id_ = 1; // 0 means "no context"
// Formats CUresult to output prettified values into a log stream.
-// Error summaries taken from:
-// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html#group__CUDA__TYPES_1gc6c391505e117393cc2558fff6bfc2e9
-//
-// TODO(leary) switch to cuGetErrorName when updated cuda.h is available.
string ToString(CUresult result) {
-#define OSTREAM_CUDA_ERROR(__name) \
- case CUDA_ERROR_##__name: \
- return "CUDA_ERROR_" #__name;
-
-///////////////
-// NOTE: here we specify return code values outside of the enum explicitly
-// because our in-tree cuda.h is from the CUDA 5.5 SDK, but CUDA 6.0+ driver
-// libraries are deployed in the fleet these error codes are backwards
-// compatible, but if we see a "new" one, we want to be able to identify it in
-// the logs.
-//
-// Once we get a cuda.h that has cuGetErrorName (TODO is above) we can
-// eliminate this function and just rely on the driver to provide us these
-// strings.
-//
-// NOTE: "Must reboot all context" below is shorthand for, "must
-// destroy/recreate the offending context and any allocation which come from
-// it if you are to continue using CUDA."
-#pragma GCC diagnostic push
-#pragma GCC diagnostic ignored "-Wswitch"
- switch (result) {
- OSTREAM_CUDA_ERROR(INVALID_VALUE)
- OSTREAM_CUDA_ERROR(OUT_OF_MEMORY)
- OSTREAM_CUDA_ERROR(NOT_INITIALIZED)
- OSTREAM_CUDA_ERROR(DEINITIALIZED)
- OSTREAM_CUDA_ERROR(NO_DEVICE)
- OSTREAM_CUDA_ERROR(INVALID_DEVICE)
- OSTREAM_CUDA_ERROR(INVALID_IMAGE)
- OSTREAM_CUDA_ERROR(INVALID_CONTEXT)
- OSTREAM_CUDA_ERROR(INVALID_HANDLE)
- OSTREAM_CUDA_ERROR(NOT_FOUND)
- OSTREAM_CUDA_ERROR(NOT_READY)
- OSTREAM_CUDA_ERROR(NO_BINARY_FOR_GPU)
-
- // Encountered an uncorrectable ECC error during execution.
- OSTREAM_CUDA_ERROR(ECC_UNCORRECTABLE)
-
- // Load/store on an invalid address. Must reboot all context.
- case 700:
- return "CUDA_ERROR_ILLEGAL_ADDRESS";
- // Passed too many / wrong arguments, too many threads for register count.
- case 701:
- return "CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES";
- // Kernel took too long to execute.
- case 702:
- return "CUDA_ERROR_LAUNCH_TIMEOUT";
- // Kernel launch uses an incompatible texturing mode.
- case 703:
- return "CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING";
- // Trying to re-enable peer access that already has it enabled.
- case 704:
- return "CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED";
- // Trying to disable peer access that has not yet been enabled.
- case 705:
- return "CUDA_ERROR_PEER_ACCESS_NOT_ENABLED";
- // Primary context for the specified device has already been initialized.
- case 708:
- return "CUDA_ERROR_PRIMARY_CONTEXT_ACTIVE";
- // Context current to calling thread has been destroyed or is a primary
- // context that has not yet been initialized.
- case 709:
- return "CUDA_ERROR_CONTEXT_IS_DESTROYED";
- // Device-side assert triggered during kernel execution. Must reboot all
- // context.
- case 710:
- return "CUDA_ERROR_ASSERT";
- // Hardware resources to enable peer access have been exhausted.
- case 711:
- return "CUDA_ERROR_TOO_MANY_PEERS";
- // Memory range has already been registered.
- case 712:
- return "CUDA_ERROR_HOST_MEMORY_ALREADY_REGISTERED";
- // Pointer does not correspond to any currently registered memory region.
- case 713:
- return "CUDA_ERROR_HOST_MEMORY_NOT_REGISTERED";
- // Due to stack corruption or exceeding stack size limit. Must reboot all
- // context.
- case 714:
- return "CUDA_ERROR_HARDWARE_STACK_ERROR";
- case 715:
- return "CUDA_ERROR_ILLEGAL_INSTRUCTION";
- // Load/store on an unaligned memory address. Must reboot all context.
- case 716:
- return "CUDA_ERROR_MISALIGNED_ADDRESS";
- // Device instruction with specific address space given address not
- // belonging to allowed address space. Must reboot all context.
- case 717:
- return "CUDA_ERROR_INVALID_ADDRESS_SPACE";
- // Device program counter wrapped its address space. Must reboot all
- // context.
- case 718:
- return "CUDA_ERROR_INVALID_PC";
- // Exception on device while executing a kernel; e.g. deref invalid device
- // pointer, accessing OOB shared memory. Must reboot all context.
- case 719:
- return "CUDA_ERROR_LAUNCH_FAILED";
-
- OSTREAM_CUDA_ERROR(CONTEXT_ALREADY_IN_USE)
- OSTREAM_CUDA_ERROR(PEER_ACCESS_UNSUPPORTED)
- OSTREAM_CUDA_ERROR(NOT_PERMITTED)
- OSTREAM_CUDA_ERROR(NOT_SUPPORTED)
- OSTREAM_CUDA_ERROR(UNKNOWN) // Unknown internal error to CUDA.
- default:
- return port::StrCat("CUresult(", static_cast<int>(result), ")");
+ const char *error_name;
+ if (cuGetErrorName(result, &error_name)) {
+ return port::StrCat("UNKNOWN ERROR (", static_cast<int>(result), ")");
+ }
+ const char *error_string;
+ if (cuGetErrorString(result, &error_string)) {
+ return error_name;
}
-#pragma GCC diagnostic pop
+ return port::StrCat(error_name, ": ", error_string);
}
// Returns the current context and checks that it is in the set of CUDA contexts
diff --git a/tensorflow/stream_executor/module_spec.h b/tensorflow/stream_executor/module_spec.h
index 212ae7ba9c..75bdfed2d7 100644
--- a/tensorflow/stream_executor/module_spec.h
+++ b/tensorflow/stream_executor/module_spec.h
@@ -43,6 +43,7 @@ class MultiModuleLoaderSpec {
}
void AddCudaCubinInMemory(port::ArraySlice<const uint8> cubin_bytes) {
+ CHECK(!cubin_bytes.empty());
has_cuda_cubin_in_memory_ = true;
cuda_cubin_in_memory_ = cubin_bytes;
}
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 2c495c99e1..b0c061fd74 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -267,13 +267,13 @@ Stream::Stream(StreamExecutor *parent,
Stream::~Stream() {
VLOG_CALL();
- temporary_memory_manager_.ForceDeallocateAll();
// Ensure the stream is completed.
auto status = BlockHostUntilDone();
if (!status.ok()) {
LOG(WARNING) << "Error blocking host until done in stream destructor: "
<< status;
}
+ temporary_memory_manager_.ForceDeallocateAll();
if (allocated_) {
parent_->DeallocateStream(this);
@@ -1941,7 +1941,14 @@ void Stream::ReturnSubStream(Stream *sub_stream) {
mutex_lock lock(mu_);
for (auto &stream : sub_streams_) {
if (stream.first.get() == sub_stream) {
- stream.second = true;
+ // Streams have a monotonic state machine; if a stream
+ // encounters an error, it will remain in an error state
+ // forever. Only allow re-use of ok streams.
+ //
+ // TODO(toddw): Improve this mechanism, if necessary, to drop
+ // failed streams completely.
+ const bool ready_to_reuse = sub_stream->ok();
+ stream.second = ready_to_reuse;
return;
}
}
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index 63d64947c8..706442a666 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -125,7 +125,7 @@ class Stream {
Stream *GetOrCreateSubStream() LOCKS_EXCLUDED(mu_);
// Return the sub-stream back to the host stream so that it can be reused
- // later.
+ // later. Sub-streams that are !ok() will not be reused.
void ReturnSubStream(Stream *sub_stream) LOCKS_EXCLUDED(mu_);
// Allocate temporary memories. The stream will deallocate them when blocked
diff --git a/tensorflow/stream_executor/stream_test.cc b/tensorflow/stream_executor/stream_test.cc
new file mode 100644
index 0000000000..47dd675834
--- /dev/null
+++ b/tensorflow/stream_executor/stream_test.cc
@@ -0,0 +1,139 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/stream_executor/stream_executor.h"
+
+#include "tensorflow/core/platform/test.h"
+
+namespace stream_executor {
+namespace {
+
+class StreamTest : public ::testing::Test {
+ protected:
+ std::unique_ptr<StreamExecutor> NewStreamExecutor() {
+ Platform* platform =
+ MultiPlatformManager::PlatformWithName("Host").ConsumeValueOrDie();
+ StreamExecutorConfig config(/*ordinal=*/0);
+ return platform->GetUncachedExecutor(config).ConsumeValueOrDie();
+ }
+};
+
+TEST_F(StreamTest, NoInitNotOk) {
+ std::unique_ptr<StreamExecutor> executor = NewStreamExecutor();
+ Stream stream(executor.get());
+ EXPECT_FALSE(stream.ok());
+}
+
+TEST_F(StreamTest, InitOk) {
+ std::unique_ptr<StreamExecutor> executor = NewStreamExecutor();
+ Stream stream(executor.get());
+ stream.Init();
+ EXPECT_TRUE(stream.ok());
+}
+
+TEST_F(StreamTest, OneSubStream) {
+ std::unique_ptr<StreamExecutor> executor = NewStreamExecutor();
+ Stream stream(executor.get());
+ stream.Init();
+ EXPECT_TRUE(stream.ok());
+
+ // Get and return a sub-stream. Sub-streams are always initialized.
+ Stream* sub_stream1 = stream.GetOrCreateSubStream();
+ EXPECT_TRUE(sub_stream1->ok());
+ stream.ReturnSubStream(sub_stream1);
+
+ // Get and return another sub-stream.
+ Stream* sub_stream2 = stream.GetOrCreateSubStream();
+ EXPECT_TRUE(sub_stream2->ok());
+ stream.ReturnSubStream(sub_stream1);
+
+ // The underlying sub-streams should be the same, since sub_stream1
+ // was returned before we tried to get sub_stream2.
+ EXPECT_EQ(sub_stream1, sub_stream2);
+}
+
+TEST_F(StreamTest, TwoSubStreams) {
+ std::unique_ptr<StreamExecutor> executor = NewStreamExecutor();
+ Stream stream(executor.get());
+ stream.Init();
+ EXPECT_TRUE(stream.ok());
+
+ // Get two sub-streams.
+ Stream* sub_stream1 = stream.GetOrCreateSubStream();
+ EXPECT_TRUE(sub_stream1->ok());
+ Stream* sub_stream2 = stream.GetOrCreateSubStream();
+ EXPECT_TRUE(sub_stream2->ok());
+
+ // The underlying sub-streams should be different, since neither
+ // sub-stream has been returned.
+ EXPECT_NE(sub_stream1, sub_stream2);
+
+ // Return sub_stream1 and get sub_stream3, which should be the same.
+ stream.ReturnSubStream(sub_stream1);
+ Stream* sub_stream3 = stream.GetOrCreateSubStream();
+ EXPECT_TRUE(sub_stream3->ok());
+ EXPECT_EQ(sub_stream1, sub_stream3);
+ EXPECT_NE(sub_stream2, sub_stream3);
+
+ // Return sub_stream2 and get sub_stream4, which should be the same.
+ stream.ReturnSubStream(sub_stream2);
+ Stream* sub_stream4 = stream.GetOrCreateSubStream();
+ EXPECT_TRUE(sub_stream4->ok());
+ EXPECT_EQ(sub_stream2, sub_stream4);
+ EXPECT_NE(sub_stream3, sub_stream4);
+}
+
+TEST_F(StreamTest, FailedSubStreamNotReused) {
+ std::unique_ptr<StreamExecutor> executor = NewStreamExecutor();
+ Stream stream(executor.get());
+ stream.Init();
+ EXPECT_TRUE(stream.ok());
+
+ // Get a sub-stream.
+ Stream* sub_stream1 = stream.GetOrCreateSubStream();
+ EXPECT_TRUE(sub_stream1->ok());
+
+ // Force an error on the stream; here we call a method that requires
+ // DNN support, which we know the Host platform doesn't support.
+ sub_stream1->ThenDepthConcatenate({}, {}, nullptr);
+ EXPECT_FALSE(sub_stream1->ok());
+
+ // Return sub_stream1 and get sub_stream2.
+ stream.ReturnSubStream(sub_stream1);
+ Stream* sub_stream2 = stream.GetOrCreateSubStream();
+ EXPECT_TRUE(sub_stream2->ok());
+
+ // The underlying streams should be different. They would have been
+ // the same, but since we forced an error on sub_stream1, it will
+ // not be re-used. Sadly we can't just check:
+ // EXPECT_NE(sub_stream1, sub_stream2);
+ //
+ // The above should hold logically, but it may fail if the new
+ // stream instance allocated for sub_stream2 happens to reside in
+ // the same memory address as sub_stream1.
+ //
+ // The check that sub_stream2->ok() serves as a good-enough check.
+
+ // Return sub_stream2 and get sub_stream3. The previous error on
+ // sub_stream1 has no effect on these streams, and they are the
+ // same.
+ stream.ReturnSubStream(sub_stream2);
+ Stream* sub_stream3 = stream.GetOrCreateSubStream();
+ EXPECT_TRUE(sub_stream3->ok());
+ EXPECT_EQ(sub_stream2, sub_stream3);
+}
+
+} // namespace
+} // namespace stream_executor
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 340d3f393c..58282ec1c7 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -9,6 +9,7 @@ load(
"tf_additional_grpc_deps_py",
"tf_additional_xla_deps_py",
"if_static",
+ "if_dynamic_kernels",
)
load(
"@local_config_tensorrt//:build_defs.bzl",
@@ -318,18 +319,36 @@ def tf_binary_additional_srcs():
clean_dep("//tensorflow:libtensorflow_framework.so"),
])
+
+# Helper functions to add kernel dependencies to tf binaries when using dynamic
+# kernel linking.
+def tf_binary_dynamic_kernel_dsos(kernels):
+ return if_dynamic_kernels(
+ extra_deps=["libtfkernel_%s.so" % clean_dep(k) for k in kernels],
+ otherwise=[])
+
+# Helper functions to add kernel dependencies to tf binaries when using static
+# kernel linking.
+def tf_binary_dynamic_kernel_deps(kernels):
+ return if_dynamic_kernels(
+ extra_deps=[],
+ otherwise=kernels)
+
def tf_cc_shared_object(
name,
srcs=[],
deps=[],
+ data=[],
linkopts=[],
framework_so=tf_binary_additional_srcs(),
+ kernels=[],
**kwargs):
native.cc_binary(
name=name,
srcs=srcs + framework_so,
- deps=deps,
+ deps=deps + tf_binary_dynamic_kernel_deps(kernels),
linkshared = 1,
+ data = data + tf_binary_dynamic_kernel_dsos(kernels),
linkopts=linkopts + _rpath_linkopts(name) + select({
clean_dep("//tensorflow:darwin"): [
"-Wl,-install_name,@rpath/" + name.split("/")[-1],
@@ -353,18 +372,21 @@ register_extension_info(
def tf_cc_binary(name,
srcs=[],
deps=[],
+ data=[],
linkopts=[],
copts=tf_copts(),
+ kernels=[],
**kwargs):
native.cc_binary(
name=name,
copts=copts,
srcs=srcs + tf_binary_additional_srcs(),
- deps=deps + if_mkl(
+ deps=deps + tf_binary_dynamic_kernel_deps(kernels) + if_mkl(
[
"//third_party/mkl:intel_binary_blob",
],
),
+ data=data + tf_binary_dynamic_kernel_dsos(kernels),
linkopts=linkopts + _rpath_linkopts(name),
**kwargs)
@@ -549,9 +571,6 @@ def tf_gen_op_wrappers_cc(name,
# is invalid to specify both "hidden" and "op_whitelist".
# cc_linkopts: Optional linkopts to be added to tf_cc_binary that contains the
# specified ops.
-# gen_locally: if True, the genrule to generate the Python library will be run
-# without sandboxing. This would help when the genrule depends on symlinks
-# which may not be supported in the sandbox.
def tf_gen_op_wrapper_py(name,
out=None,
hidden=None,
@@ -562,8 +581,7 @@ def tf_gen_op_wrapper_py(name,
generated_target_name=None,
op_whitelist=[],
cc_linkopts=[],
- api_def_srcs=[],
- gen_locally=False):
+ api_def_srcs=[]):
if (hidden or hidden_file) and op_whitelist:
fail('Cannot pass specify both hidden and op_whitelist.')
@@ -618,7 +636,6 @@ def tf_gen_op_wrapper_py(name,
outs=[out],
srcs=api_def_srcs + [hidden_file],
tools=[tool_name] + tf_binary_additional_srcs(),
- local = (1 if gen_locally else 0),
cmd=("$(location " + tool_name + ") " + api_def_args_str +
" @$(location " + hidden_file + ") " +
("1" if require_shape_functions else "0") + " > $@"))
@@ -628,7 +645,6 @@ def tf_gen_op_wrapper_py(name,
outs=[out],
srcs=api_def_srcs,
tools=[tool_name] + tf_binary_additional_srcs(),
- local = (1 if gen_locally else 0),
cmd=("$(location " + tool_name + ") " + api_def_args_str + " " +
op_list_arg + " " +
("1" if require_shape_functions else "0") + " " +
@@ -658,11 +674,13 @@ def tf_gen_op_wrapper_py(name,
def tf_cc_test(name,
srcs,
deps,
+ data=[],
linkstatic=0,
extra_copts=[],
suffix="",
linkopts=[],
nocopts=None,
+ kernels=[],
**kwargs):
native.cc_test(
name="%s%s" % (name, suffix),
@@ -682,11 +700,12 @@ def tf_cc_test(name,
"-lm"
],
}) + linkopts + _rpath_linkopts(name),
- deps=deps + if_mkl(
+ deps=deps + tf_binary_dynamic_kernel_deps(kernels) + if_mkl(
[
"//third_party/mkl:intel_binary_blob",
],
),
+ data=data + tf_binary_dynamic_kernel_dsos(kernels),
# Nested select() statements seem not to be supported when passed to
# linkstatic, and we already have a cuda select() passed in to this
# function.
@@ -787,6 +806,7 @@ def tf_cuda_only_cc_test(name,
size="medium",
linkstatic=0,
args=[],
+ kernels=[],
linkopts=[]):
native.cc_test(
name="%s%s" % (name, "_gpu"),
@@ -794,8 +814,8 @@ def tf_cuda_only_cc_test(name,
size=size,
args=args,
copts= _cuda_copts() + tf_copts(),
- data=data,
- deps=deps + if_cuda([
+ data=data + tf_binary_dynamic_kernel_dsos(kernels),
+ deps=deps + tf_binary_dynamic_kernel_deps(kernels) + if_cuda([
clean_dep("//tensorflow/core:cuda"),
clean_dep("//tensorflow/core:gpu_lib")]),
linkopts=if_not_windows(["-lpthread", "-lm"]) + linkopts + _rpath_linkopts(name),
@@ -838,9 +858,11 @@ def tf_cc_tests(srcs,
def tf_cc_test_mkl(srcs,
deps,
name="",
+ data=[],
linkstatic=0,
tags=[],
size="medium",
+ kernels=[],
args=None):
# -fno-exceptions in nocopts breaks compilation if header modules are enabled.
disable_header_modules = ["-use_header_modules"]
@@ -861,11 +883,12 @@ def tf_cc_test_mkl(srcs,
"-lm"
],
}) + _rpath_linkopts(src_to_test_name(src)),
- deps=deps + if_mkl(
+ deps=deps + tf_binary_dynamic_kernel_deps(kernels) + if_mkl(
[
"//third_party/mkl:intel_binary_blob",
],
),
+ data=data + tf_binary_dynamic_kernel_dsos(kernels),
linkstatic=linkstatic,
tags=tags,
size=size,
@@ -905,12 +928,13 @@ def tf_cuda_cc_tests(srcs,
def tf_java_test(name,
srcs=[],
deps=[],
+ kernels=[],
*args,
**kwargs):
native.java_test(
name=name,
srcs=srcs,
- deps=deps + tf_binary_additional_srcs(),
+ deps=deps + tf_binary_additional_srcs() + tf_binary_dynamic_kernel_dsos(kernels) + tf_binary_dynamic_kernel_deps(kernels),
*args,
**kwargs)
@@ -1084,6 +1108,15 @@ def tf_kernel_library(
deps=deps,
**kwargs)
+ # TODO(gunan): CUDA dependency not clear here. Fix it.
+ tf_cc_shared_object(
+ name="libtfkernel_%s.so" % name,
+ srcs=srcs + hdrs,
+ copts=copts,
+ deps=deps,
+ tags=["manual", "notap"])
+
+
register_extension_info(
extension_name = "tf_kernel_library",
label_regex_for_dep = "{extension_name}(_gpu)?",
@@ -1456,7 +1489,7 @@ def tf_py_wrap_cc(name,
srcs=srcs,
swig_includes=swig_includes,
deps=deps + extra_deps,
- toolchain_deps=["//tools/defaults:crosstool"],
+ toolchain_deps=["@bazel_tools//tools/cpp:current_cc_toolchain"],
module_name=module_name,
py_module_name=name)
vscriptname=name+"_versionscript"
diff --git a/tensorflow/tools/api/golden/tensorflow.-config-proto.-experimental.pbtxt b/tensorflow/tools/api/golden/tensorflow.-config-proto.-experimental.pbtxt
index ef9fe096a1..eb41deee13 100644
--- a/tensorflow/tools/api/golden/tensorflow.-config-proto.-experimental.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-config-proto.-experimental.pbtxt
@@ -14,5 +14,11 @@ tf_proto {
label: LABEL_OPTIONAL
type: TYPE_BOOL
}
+ field {
+ name: "executor_type"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt
index eeef15515d..e565b903d2 100644
--- a/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt
@@ -137,6 +137,12 @@ tf_proto {
label: LABEL_OPTIONAL
type: TYPE_BOOL
}
+ field {
+ name: "executor_type"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
}
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt
index 9dbb5d16a4..c23b04b4ef 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt
@@ -21,7 +21,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\'], "
+ argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\', \'pruning_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\', \'none\'], "
}
member_method {
name: "eval_dir"
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt
index 34a30c2874..6878d28fff 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt
@@ -21,7 +21,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'label_dimension\', \'weight_column\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\'], "
+ argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'label_dimension\', \'weight_column\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\', \'pruning_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\', \'none\'], "
}
member_method {
name: "eval_dir"
diff --git a/tensorflow/tools/api/golden/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/tensorflow.image.pbtxt
index 6ec3aba775..5c46dc5ee7 100644
--- a/tensorflow/tools/api/golden/tensorflow.image.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.image.pbtxt
@@ -125,6 +125,10 @@ tf_module {
argspec: "args=[\'overlaps\', \'scores\', \'max_output_size\', \'overlap_threshold\', \'score_threshold\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'None\'], "
}
member_method {
+ name: "non_max_suppression_padded"
+ argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'score_threshold\', \'pad_to_max_output_size\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'False\', \'None\'], "
+ }
+ member_method {
name: "pad_to_bounding_box"
argspec: "args=[\'image\', \'offset_height\', \'offset_width\', \'target_height\', \'target_width\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/ci_build/ci_build.sh b/tensorflow/tools/ci_build/ci_build.sh
index f6a50d3d4c..77265e0f50 100755
--- a/tensorflow/tools/ci_build/ci_build.sh
+++ b/tensorflow/tools/ci_build/ci_build.sh
@@ -115,6 +115,7 @@ DOCKER_IMG_NAME=$(echo "${DOCKER_IMG_NAME}" | tr '[:upper:]' '[:lower:]')
# Print arguments.
echo "WORKSPACE: ${WORKSPACE}"
+echo "CI_DOCKER_BUILD_EXTRA_PARAMS: ${CI_DOCKER_BUILD_EXTRA_PARAMS[*]}"
echo "CI_DOCKER_EXTRA_PARAMS: ${CI_DOCKER_EXTRA_PARAMS[*]}"
echo "COMMAND: ${COMMAND[*]}"
echo "CI_COMMAND_PREFIX: ${CI_COMMAND_PREFIX[*]}"
@@ -126,7 +127,7 @@ echo ""
# Build the docker container.
echo "Building container (${DOCKER_IMG_NAME})..."
-docker build -t ${DOCKER_IMG_NAME} \
+docker build -t ${DOCKER_IMG_NAME} ${CI_DOCKER_BUILD_EXTRA_PARAMS[@]} \
-f "${DOCKERFILE_PATH}" "${DOCKER_CONTEXT_PATH}"
# Check docker build status
diff --git a/tensorflow/tools/docker/Dockerfile.devel-cpu-mkl b/tensorflow/tools/docker/Dockerfile.devel-cpu-mkl
deleted file mode 100644
index 6796ad70e5..0000000000
--- a/tensorflow/tools/docker/Dockerfile.devel-cpu-mkl
+++ /dev/null
@@ -1,83 +0,0 @@
-FROM tensorflow/tensorflow:latest-devel
-
-LABEL maintainer="Clayne Robison<clayne.b.robison@intel.com>"
-
-# These arguments are parameterized. Use --build-args to override.
-ARG TF_BRANCH=r1.9
-ARG WHL_DIR=/whl
-
-RUN apt-get update && apt-get install -y --no-install-recommends \
- golang \
- vim \
- emacs \
- && \
- apt-get clean && \
- rm -rf /var/lib/apt/lists/*
-
-RUN pip --no-cache-dir install --upgrade \
- pip setuptools
-
-RUN pip --no-cache-dir install wheel
-
-# Download and build TensorFlow.
-WORKDIR /
-RUN rm -rf tensorflow && \
- git clone https://github.com/tensorflow/tensorflow.git && \
- cd tensorflow && \
- git checkout ${TF_BRANCH}
-WORKDIR /tensorflow
-
-# Configure the build for CPU with MKL by accepting default build options and
-# setting library locations
-ENV CI_BUILD_PYTHON=python \
- LD_LIBRARY_PATH=${LD_LIBRARY_PATH} \
- PYTHON_BIN_PATH=/usr/bin/python \
- PYTHON_LIB_PATH=/usr/local/lib/python2.7/dist-packages \
- CC_OPT_FLAGS='-march=native' \
- TF_NEED_JEMALLOC=0 \
- TF_NEED_GCP=1 \
- TF_NEED_CUDA=0 \
- TF_NEED_HDFS=0 \
- TF_NEED_S3=1 \
- TF_NEED_OPENCL=0 \
- TF_NEED_GDR=0 \
- TF_ENABLE_XLA=0 \
- TF_NEED_VERBS=0 \
- TF_NEED_MPI=0
-RUN ./configure
-
-# Build and Install TensorFlow.
-# The 'mkl' option builds with Intel(R) Math Kernel Library (MKL), which detects
-# the platform it is currently running on and takes appropriately optimized
-# paths. The -march=native option is for code that is not in MKL, and assumes
-# this container will be run on the same architecture on which it is built.
-RUN LD_LIBRARY_PATH=${LD_LIBRARY_PATH} \
- bazel build --config=mkl \
- --config="opt" \
- --copt="-march=broadwell" \
- --copt="-O3" \
- //tensorflow/tools/pip_package:build_pip_package && \
- mkdir ${WHL_DIR} && \
- bazel-bin/tensorflow/tools/pip_package/build_pip_package ${WHL_DIR}
-
-# Clean up Bazel cache when done, but leave the whl.
-# This will upgrade the default Tensorflow version with the Intel MKL version
-RUN pip --no-cache-dir install --upgrade ${WHL_DIR}/tensorflow-*.whl && \
- rm -rf /root/.cache
-
-WORKDIR /root
-
-#add welcome message with instructions
-
-RUN echo '[ ! -z "$TERM" -a -r /etc/motd ] && cat /etc/issue && cat /etc/motd' \
- >> /etc/bash.bashrc \
- ; echo "\
-||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||\n\
-| \n\
-| Docker container running Ubuntu \n\
-| with TensorFlow ${TF_BRANCH} optimized for CPU \n\
-| with Intel(R) MKL \n\
-| \n\
-||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||\n\
-\n "\
- > /etc/motd
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 314169fc19..45b1abeb10 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -485,11 +485,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "llvm",
urls = [
- "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/a9364fc18506373b10922802983f76229cc1f371.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/a9364fc18506373b10922802983f76229cc1f371.tar.gz",
+ "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/7b3bfc8151f3a6bcd9642c49c1f86f66cc43a428.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/7b3bfc8151f3a6bcd9642c49c1f86f66cc43a428.tar.gz",
],
- sha256 = "5d727fedfbb805a44a671db8f3fbaa09dbe5177a5c1cc0635fd61c324e6409f2",
- strip_prefix = "llvm-a9364fc18506373b10922802983f76229cc1f371",
+ sha256 = "c6cbb21acd46e3e00faa8c379595ecffb99ef77622da17f29371db2bfad1d3d3",
+ strip_prefix = "llvm-7b3bfc8151f3a6bcd9642c49c1f86f66cc43a428",
build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"),
)